This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit d7987f4217d69810e946baeea681f678c3938348 Author: tqchen <[email protected]> AuthorDate: Tue Apr 22 10:42:39 2025 -0400 Upgrade FFI to accept ndarray as handle --- src/tir/transforms/arg_binder.cc | 17 +++-------------- src/tir/transforms/arg_binder.h | 3 +-- src/tir/transforms/make_packed_api.cc | 13 +++++++++---- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 4eec34dbac..9270a14df9 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -150,27 +150,16 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kin } void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, - const PrimExpr& device_id, const Var& ptr_handle, - const PrimExpr& type_index, const std::string& arg_name) { + const PrimExpr& device_id, const Var& handle, + const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); init_nest_.emplace_back(AssertStmt( - !Call(DataType::Bool(), builtin::isnullptr(), {ptr_handle}), + !Call(DataType::Bool(), builtin::isnullptr(), {handle}), tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"), nop)); - Var handle = Var(ptr_handle->name_hint + "_as_dltensor", DataType::Handle()); - - PrimExpr handle_from_ndarray = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), - {ptr_handle, IntImm(DataType::Int(32), 16)}); - // if the type index is not DLTensorPtr, we need to add the offset of the DLTensor header - // which always equals 16 bytes; - init_nest_.emplace_back(LetStmt( - handle, - Select(type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr, ptr_handle, handle_from_ndarray), - nop)); - // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h index b831e797ea..68cbbb6773 100644 --- a/src/tir/transforms/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -97,11 +97,10 @@ class ArgBinder { * \param device_type The device id to be binded. * \param device_id The device id to be binded. * \param handle The DLTensor handle. - * \param type_index The type index of the DLTensor handle. * \param arg_name argument name. */ void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, - const Var& handle, const PrimExpr& type_index, const std::string& arg_name); + const Var& handle, const std::string& arg_name); /*! \return The defs generated in binding. */ const std::vector<Var>& defs() const { return defs_; } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 13cff5276b..d1931ebced 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -269,7 +269,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { // Need to delay binding of the buffers, in case some arguments also // appear in the buffer. std::vector<std::pair<PrimExpr, Var>> var_def; - std::vector<std::tuple<Var, Buffer, Var>> buffer_def; + std::vector<std::pair<Var, Buffer>> buffer_def; for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; @@ -290,7 +290,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr || type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin, tvm::tir::StringImm(msg.str()), nop)); + // if type_index is NDArray, we need to add the offset of the DLTensor header + // which always equals 16 bytes, this ensures that T.handle always shows up as a DLTensor* arg_value = f_load_arg_value(param.dtype(), i); + PrimExpr handle_from_ndarray = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), + {arg_value, IntImm(DataType::Int(32), 16)}); + arg_value = Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value); } else if (dtype.is_bool()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be boolean"; @@ -324,7 +329,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { if (func_ptr->buffer_map.count(param)) { // buffer binding now depends on type index // if the index is NDArray handle, we need to offset to get the DLTensor* - buffer_def.emplace_back(param, func_ptr->buffer_map[param], type_index); + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } @@ -342,8 +347,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { binder.Bind(param, expr, name_hint + "." + param->name_hint, true); } - for (const auto& [var, buffer, type_index] : buffer_def) { - binder.BindDLTensor(buffer, device_type, device_id, var, type_index, + for (const auto& [var, buffer] : buffer_def) { + binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); }
