This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit d7987f4217d69810e946baeea681f678c3938348
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 22 10:42:39 2025 -0400

    Upgrade FFI to accept ndarray as handle
---
 src/tir/transforms/arg_binder.cc      | 17 +++--------------
 src/tir/transforms/arg_binder.h       |  3 +--
 src/tir/transforms/make_packed_api.cc | 13 +++++++++----
 3 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 4eec34dbac..9270a14df9 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -150,27 +150,16 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, 
builtin::TVMStructFieldKind kin
 }
 
 void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
-                             const PrimExpr& device_id, const Var& ptr_handle,
-                             const PrimExpr& type_index, const std::string& 
arg_name) {
+                             const PrimExpr& device_id, const Var& handle,
+                             const std::string& arg_name) {
   const DataType tvm_shape_type = DataType::ShapeIndex();
   const DataType tvm_ndim_type = DataType::Int(32);
   const Stmt nop = Evaluate(0);
 
   init_nest_.emplace_back(AssertStmt(
-      !Call(DataType::Bool(), builtin::isnullptr(), {ptr_handle}),
+      !Call(DataType::Bool(), builtin::isnullptr(), {handle}),
       tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* 
pointer"), nop));
 
-  Var handle = Var(ptr_handle->name_hint + "_as_dltensor", DataType::Handle());
-
-  PrimExpr handle_from_ndarray = Call(DataType::Handle(), 
tir::builtin::handle_add_byte_offset(),
-                                      {ptr_handle, IntImm(DataType::Int(32), 
16)});
-  // if the type index is not DLTensorPtr, we need to add the offset of the 
DLTensor header
-  // which always equals 16 bytes;
-  init_nest_.emplace_back(LetStmt(
-      handle,
-      Select(type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr, ptr_handle, 
handle_from_ndarray),
-      nop));
-
   // dimension checks
   PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
 
diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h
index b831e797ea..68cbbb6773 100644
--- a/src/tir/transforms/arg_binder.h
+++ b/src/tir/transforms/arg_binder.h
@@ -97,11 +97,10 @@ class ArgBinder {
   * \param device_type The device type to be bound.
   * \param device_id The device id to be bound.
    * \param handle The DLTensor handle.
-   * \param type_index The type index of the DLTensor handle.
    * \param arg_name argument name.
    */
   void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const 
PrimExpr& device_id,
-                    const Var& handle, const PrimExpr& type_index, const 
std::string& arg_name);
+                    const Var& handle, const std::string& arg_name);
 
   /*! \return The defs generated in binding. */
   const std::vector<Var>& defs() const { return defs_; }
diff --git a/src/tir/transforms/make_packed_api.cc 
b/src/tir/transforms/make_packed_api.cc
index 13cff5276b..d1931ebced 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -269,7 +269,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   // Need to delay binding of the buffers, in case some arguments also
   // appear in the buffer.
   std::vector<std::pair<PrimExpr, Var>> var_def;
-  std::vector<std::tuple<Var, Buffer, Var>> buffer_def;
+  std::vector<std::pair<Var, Buffer>> buffer_def;
 
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
     Var param = func_ptr->params[i];
@@ -290,7 +290,12 @@ PrimFunc MakePackedAPI(PrimFunc func) {
                                            type_index == 
ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                                            type_index >= 
ffi::TypeIndex::kTVMFFIStaticObjectBegin,
                                        tvm::tir::StringImm(msg.str()), nop));
+      // if type_index is NDArray, we need to add the offset of the DLTensor
+      // header, which always equals 16 bytes; this ensures that T.handle always
+      // shows up as a DLTensor*
       arg_value = f_load_arg_value(param.dtype(), i);
+      PrimExpr handle_from_ndarray = Call(DataType::Handle(), 
tir::builtin::handle_add_byte_offset(),
+                                      {arg_value, IntImm(DataType::Int(32), 
16)});
+      arg_value = Select(type_index == ffi::TypeIndex::kTVMFFINDArray, 
handle_from_ndarray, arg_value);
     } else if (dtype.is_bool()) {
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be boolean";
@@ -324,7 +329,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     if (func_ptr->buffer_map.count(param)) {
       // buffer binding now depends on type index
       // if the index is NDArray handle, we need to offset to get the DLTensor*
-      buffer_def.emplace_back(param, func_ptr->buffer_map[param], type_index);
+      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
     }
   }
 
@@ -342,8 +347,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
   }
 
-  for (const auto& [var, buffer, type_index] : buffer_def) {
-    binder.BindDLTensor(buffer, device_type, device_id, var, type_index,
+  for (const auto& [var, buffer] : buffer_def) {
+    binder.BindDLTensor(buffer, device_type, device_id, var,
                         name_hint + "." + var->name_hint);
     arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
   }

Reply via email to