This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8ac4c9b587db73fd281c2c8a721478918bcd54e5 Author: tqchen <[email protected]> AuthorDate: Sun Apr 20 19:56:10 2025 -0400 Upgrade the caller side of ABI to make use of the latest FII --- include/tvm/runtime/packed_func.h | 30 ++---- include/tvm/tir/builtin.h | 44 +++----- include/tvm/tir/transform.h | 5 - python/tvm/script/ir_builder/tir/ir.py | 2 - python/tvm/tir/__init__.py | 1 - python/tvm/tir/op.py | 18 ---- python/tvm/tir/transform/transform.py | 11 -- src/relax/backend/vm/codegen_vm_tir.cc | 3 - src/runtime/library_module.cc | 2 +- src/runtime/relax_vm/builtin.cc | 21 ++-- src/target/llvm/codegen_cpu.cc | 159 +++++++++------------------- src/target/llvm/codegen_cpu.h | 13 +-- src/target/llvm/codegen_hexagon.cc | 88 --------------- src/target/source/codegen_c.cc | 23 ++-- src/target/source/codegen_c_host.cc | 117 ++++++-------------- src/target/source/codegen_c_host.h | 15 +-- src/tir/op/builtin.cc | 4 - src/tir/transforms/legalize_packed_calls.cc | 138 ------------------------ src/tir/transforms/lower_tvm_builtin.cc | 129 ++++++++++------------ 19 files changed, 196 insertions(+), 627 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d90017288b..0746edc195 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -458,25 +458,17 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...)> { * * \endcode */ -#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ - int* out_type_code, void* resource_handle) { \ - try { \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \ - std::vector<::tvm::ffi::AnyView> packed_args(num_args); \ - ::tvm::runtime::LegacyTVMArgsToPackedArgs(args, type_code, num_args, packed_args.data()); \ - ::tvm::ffi::Any rv; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>( \ - &name, Function, packed_args.data(), num_args, &rv); \ - ::tvm::runtime::MoveAnyToLegacyTVMValue(std::move(rv), out_value, out_type_code); \ - return 0; \ - } catch (const ::std::exception& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(void* self, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \ + static std::string name = #ExportName; \ + ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>( \ + &name, Function, reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \ + reinterpret_cast<::tvm::ffi::Any*>(result)); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ } } // namespace runtime // NOLINT(*) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 822763d0b2..6f7ce9de2b 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -356,14 +356,13 @@ TVM_DLL const Op& tvm_stack_make_array(); /*! * \brief See pesudo code * - * return_type tvm_call_packed(name, TVMValue* args) { - * TVMValue ret_value; - * int ret_code; + * return_type tvm_call_packed(name, TVMFFIAny* args) { + * TVMFFIAny result; * ModuleNode* env = GetCurrentEnv(); * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args), &ret_value, &ret_code); + * (*f)(args, args, len(args), &result); * // return type can be int, float, handle. - * return cast(return_type, ret_value.v_return_type); + * return cast(return_type, result); * } */ TVM_DLL const Op& tvm_call_packed(); @@ -371,11 +370,10 @@ TVM_DLL const Op& tvm_call_packed(); /*! * \brief See pesudo code * - * return_type tvm_call_packed(fname, TVMValue* args) { - * int ret_code; - * TVMValue ret_value; - * (*fname)(args, type_code_of(args), len(args), &ret_value, &ret_code); - * return cast(return_type, ret_value.v_return_type); + * return_type tvm_call_packed(fname, TVMFFIAny* args) { + * TVMFFIAny result; + * (*fname)(args, args, len(args), &result); + * return cast(return_type, result); * } */ TVM_DLL const Op& tvm_call_cpacked(); @@ -383,30 +381,16 @@ TVM_DLL const Op& tvm_call_cpacked(); /*! * \brief See pesudo code * - * return_type tvm_call_trace_packed(name, TVMValue* args) { + * return_type tvm_call_trace_packed(name, TVMFFIAny* args) { * ModuleNode* env = GetCurrentEnv(); * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args)); + * (*f)(args, args, len(args)); * // return type can be int, float, handle. - * return cast(return_type, ret_value.v_return_type); + * return cast(return_type, result); * } */ TVM_DLL const Op& tvm_call_trace_packed(); -/*! - * \brief Checks the return value of another call is correct or returns a given value. - * - * \note This is meant to serve a specific case for AOT code generator whilst this - * cannot be fully represented in TIR. - * - * Type tvm_check_return(expected, return_unexpected, nested_call) { - * if (nested_call() != expected) { - * return return_unexpected; - * } - * } - */ -TVM_DLL const Op& tvm_check_return(); - /*! * \brief See pesudo code * Mark the content as thread local context, can get optimized @@ -451,10 +435,10 @@ TVM_DLL const Op& tvm_call_packed_lowered(); * type codes are explicitly allocated. * * int tvm_call_packed_lowered(fname, - * TVMValue* value_stack, - * int* tcode_stack, + * TVMFFIAny* args_stack, * int begin, - * int end) { + * int end, + * void* self) { * fname(TVMArgs(value_stack[begin:end], tcode_stack[begin:end]), * TVMRetValue(value_stack + end, tcode_stack + end)); * } diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 21f7278b3a..17b85585b0 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -489,11 +489,6 @@ TVM_DLL Pass LiftThreadBinding(); */ TVM_DLL Pass CompactBufferAllocation(bool is_strict = true); -/*! - * This pass legalizes packed calls by wrapping their arguments into TVMValues - */ -TVM_DLL Pass LegalizePackedCalls(); - /*! * \brief Remove match buffers inside the block. Also, it will validate the binding. * \return The pass. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 3814f2df88..e270b91526 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1875,7 +1875,6 @@ tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) -tvm_check_return = _op_wrapper(_tir_op.tvm_check_return) call_packed = _op_wrapper(_tir_op.call_packed) call_cpacked = _op_wrapper(_tir_op.call_cpacked) call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) @@ -2168,7 +2167,6 @@ __all__ = [ "tvm_stack_alloca", "tvm_stack_make_shape", "tvm_stack_make_array", - "tvm_check_return", "call_packed", "call_cpacked", "call_packed_lowered", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 63d3cb8f31..2f8b1f45cb 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,7 +48,6 @@ from .function import PrimFunc, TensorIntrin, IndexMap from .op import call_packed_lowered, call_cpacked_lowered, call_tir from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace -from .op import tvm_check_return from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set from .op import address_of, lookup_param, assume, undef diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 46c634eeb4..23c10fc7be 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -319,24 +319,6 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): ) -def tvm_check_return(expected, return_unexpected, nested_call): - """Return new on stack dtype[num] - Parameters - ---------- - expected : int - The expected return code. - return_unexpected : int - The unexpected return code. - nested_call : PrimExpr - The call expression to check return. - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("int32", "tir.tvm_check_return", expected, return_unexpected, nested_call) - - def tvm_stack_alloca(dtype_str, num): """Return new on stack dtype[num] diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index eb38c5f775..38d3e384c4 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -513,17 +513,6 @@ def LowerTVMBuiltin(): return _ffi_api.LowerTVMBuiltin() # type: ignore -def LegalizePackedCalls(): - """Legalize packed calls to have its arguments wrapped in TVMValues - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LegalizePackedCalls() # type: ignore - - def LowerIntrin(): """Lower target specific intrinsic calls. diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index e3812ea8c1..0fead47194 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -136,9 +136,6 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> { for (PrimExpr arg : args) { all_args.push_back(arg); } - // push an empty handle to be compatible with current cpacked convention - // TODO(tqchen): revisit C Packed convention - all_args.push_back(tir::make_zero(DataType::Handle())); if (dst_anylist_slot >= 0) { this->EmitStmt(tir::Evaluate( tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args))); diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index e718255bcc..a8a66b21cb 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -82,7 +82,7 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) { *fp = FuncName; \ } // Initialize the functions - TVM_INIT_CONTEXT_FUNC(TVMFuncCall); + TVM_INIT_CONTEXT_FUNC(TVMFFIFuncCall); TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index c8c8dd4b7e..f7b08c097a 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -19,6 +19,7 @@ /*! * \file src/runtime/relax_vm/builtin.cc */ +#include <tvm/ffi/any.h> #include <tvm/runtime/container/array.h> #include <tvm/runtime/container/shape_tuple.h> #include <tvm/runtime/data_type.h> @@ -573,11 +574,10 @@ extern "C" { * \param anylist The handle to the anylist, backed by TVMRetValue* * \param int The index. * \param args The args stack. - * \param type_codes The type codes stack. * \param arg_offset The offset of argument. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMValue* args, int* type_codes, +TVM_DLL int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* args, int arg_offset); /*! * \brief Backend function to get anylist item and set into Packed Func call arg stack. @@ -597,15 +597,14 @@ TVM_DLL int TVMBackendAnyListResetItem(void* anylist, int index); * \param arg_offset The offset of argument. * \return 0 when no error is thrown, -1 when failure happens. */ -TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMValue* args, - int* type_codes, int ret_offset); +TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* args, + int ret_offset); -int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMValue* args, int* type_codes, - int arg_offset) { +int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* args, int arg_offset) { using namespace tvm::runtime; API_BEGIN(); auto* list = static_cast<TVMFFIAny*>(anylist); - AnyViewToLegacyTVMArgValue(list[index], args + arg_offset, type_codes + arg_offset); + args[arg_offset] = list[index]; API_END(); } @@ -617,16 +616,12 @@ int TVMBackendAnyListResetItem(void* anylist, int index) { API_END(); } -int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMValue* args, int* type_codes, +int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* args, int ret_offset) { using namespace tvm::runtime; API_BEGIN(); auto* list = static_cast<Any*>(anylist); - if (type_codes[ret_offset] == kTVMStr || type_codes[ret_offset] == kTVMBytes) { - list[index] = LegacyTVMArgValueToAnyView(args[ret_offset], type_codes[ret_offset]); - } else { - list[index] = MoveLegacyTVMArgValueToAny(args[ret_offset], type_codes[ret_offset]); - } + list[index] = tvm::ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(args[ret_offset])); API_END(); } } // extern "C" diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6c22ac198e..83b07b14e9 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -97,26 +97,18 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_device_, t_int_, t_tvm_type_, llvmGetPointerTo(t_tvm_shape_index_, 0), llvmGetPointerTo(t_tvm_shape_index_, 0), t_int64_}); - // Defined in include/tvm/runtime/c_runtime_api.h: - // typedef union { ... } TVMValue; - t_tvm_value_ = llvm::StructType::create({t_float64_}); // Defined in include/tvm/ffi/c_api.h: t_tvm_ffi_any_ = llvm::StructType::create({t_int32_, t_int32_, t_float64_}); // Defined in include/tvm/runtime/c_backend_api.h: // typedef struct { void* sync_handle; int32_t num_task; } TVMParallelGroupEnv; t_tvm_parallel_group_env_ = llvm::StructType::create({llvmGetPointerTo(t_int32_, 0), t_int32_}); - // Defined in include/tvm/runtime/c_backend_api.h: - // typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, - // TVMValue* out_ret_value, int* out_ret_tcode, - // void* resource_handle); - ftype_tvm_backend_packed_c_func_ = - llvm::FunctionType::get(t_int_, - {t_void_p_, llvmGetPointerTo(t_int_, 0), t_int_, t_void_p_, - llvmGetPointerTo(t_int_, 0), t_void_p_}, - false); - t_tvm_crt_func_registry_ = llvm::StructType::create( - {llvmGetPointerTo(t_char_, 0), llvmGetPointerTo(ftype_tvm_backend_packed_c_func_, 0)}); - t_tvm_crt_module_ = llvm::StructType::create({llvmGetPointerTo(t_tvm_crt_func_registry_, 0)}); + // Defined in include/tvm/ffi/c_api.h: + // typedef int (*)(void* self, const TVMFFIAny* args, int32_t num_args, + // TVMFFIAny* result); + ftype_tvm_ffi_c_func_ = llvm::FunctionType::get( + t_int_, + {t_void_p_, llvmGetPointerTo(t_tvm_ffi_any_, 0), t_int_, llvmGetPointerTo(t_tvm_ffi_any_, 0)}, + false); // Defined in include/tvm/runtime/c_backend_api.h: // typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata); ftype_tvm_parallel_lambda_ = llvm::FunctionType::get( @@ -124,15 +116,10 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_); // Runtime functions. - - // Defined in include/tvm/runtime/c_runtime_api.h: - // int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, - // TVMValue* ret_val, int* ret_type_code); - ftype_tvm_func_call_ = llvm::FunctionType::get( - t_int_, - {t_tvm_func_handle_, llvmGetPointerTo(t_tvm_value_, 0), llvmGetPointerTo(t_int_, 0), t_int_, - llvmGetPointerTo(t_tvm_value_, 0), llvmGetPointerTo(t_int_, 0)}, - false); + // Defined in include/tvm/ffi/c_api.h: + // int TVMFFIFuncCall(TVMFunctionHandle func, TVMFFIAny* args, int32_t num_args, + // TVMFFIAny* result); + ftype_tvm_ffi_func_call_ = ftype_tvm_ffi_c_func_; // Defined in include/tvm/runtime/c_backend_api.h: // int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); ftype_tvm_get_func_from_env_ = llvm::FunctionType::get( @@ -168,14 +155,14 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, f_tvm_register_system_symbol_ = nullptr; } if (dynamic_lookup || system_lib_prefix_.defined()) { - f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage, - "TVMFuncCall", module_.get()); - f_tvm_get_func_from_env_ = - llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage, - "TVMBackendGetFuncFromEnv", module_.get()); + f_tvm_ffi_func_call_ = llvm::Function::Create( + ftype_tvm_ffi_func_call_, llvm::Function::ExternalLinkage, "TVMFFIFuncCall", module_.get()); f_tvm_api_set_last_error_ = llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); + f_tvm_get_func_from_env_ = + llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage, + "TVMBackendGetFuncFromEnv", module_.get()); f_tvm_parallel_launch_ = llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get()); @@ -374,31 +361,6 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMFFIAny's value field"; } } - case builtin::kTVMValueContent: { - ICHECK_EQ(t.lanes(), 1); - if (t.is_bool()) { - // The stride between adjacent entries is still - // `sizeof(TVMValue)==64`, even if the enum currently holds a - // boolean. - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_int64_, 0)); - buf = builder_->CreateInBoundsGEP(t_int64_, buf, index); - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(DTypeToLLVMType(t), 0)); - return TypedPointer(t_int8_, buf); - } else if (t.is_int() && t.bits() == 64) { - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_int64_, 0)); - return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); - } else if (t.is_float() && t.bits() == 64) { - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_float64_, 0)); - return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); - } else if (t.is_handle()) { - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_value_, 0)); - buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); - return TypedPointer(t_void_p_, - builder_->CreatePointerCast(buf, llvmGetPointerTo(t_void_p_, 0))); - } else { - LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMValue"; - } - } default: LOG(FATAL) << "unknown field code"; } @@ -482,7 +444,7 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { } else { if (!dynamic_lookup) { gv_tvm_func_call_ = - InitContextPtr(llvmGetPointerTo(ftype_tvm_func_call_, 0), "__TVMFuncCall"); + InitContextPtr(llvmGetPointerTo(ftype_tvm_ffi_func_call_, 0), "__TVMFFIFuncCall"); gv_tvm_get_func_from_env_ = InitContextPtr(llvmGetPointerTo(ftype_tvm_get_func_from_env_, 0), "__TVMBackendGetFuncFromEnv"); gv_tvm_api_set_last_error_ = InitContextPtr( @@ -846,7 +808,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& args, const DataType& r_type, const int64_t begin, const int64_t end, - bool use_string_lookup) { + bool use_env_lookup) { std::string func_name = [&]() { auto ptr = args[0].as<StringImmNode>(); ICHECK(ptr) << "Expected first argument of tir::Call to be " @@ -857,49 +819,32 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); - llvm::Value* stack_value = MakeValue(args[1]); - llvm::Value* stack_tcode = MakeValue(args[2]); - llvm::Value* arg_value = builder_->CreateInBoundsGEP( - t_tvm_value_, builder_->CreatePointerCast(stack_value, llvmGetPointerTo(t_tvm_value_, 0)), + llvm::Value* stack_args = MakeValue(args[1]); + llvm::Value* packed_args = builder_->CreateInBoundsGEP( + t_tvm_ffi_any_, builder_->CreatePointerCast(stack_args, llvmGetPointerTo(t_tvm_ffi_any_, 0)), ConstInt32(begin)); - TypedPointer arg_tcode = - CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32)); - llvm::Value* ret_value = builder_->CreateInBoundsGEP( - t_tvm_value_, builder_->CreatePointerCast(stack_value, llvmGetPointerTo(t_tvm_value_, 0)), + llvm::Value* result = builder_->CreateInBoundsGEP( + t_tvm_ffi_any_, builder_->CreatePointerCast(stack_args, llvmGetPointerTo(t_tvm_ffi_any_, 0)), ConstInt32(end)); - TypedPointer ret_tcode = - CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32)); llvm::FunctionType* callee_ftype = nullptr; llvm::Value* callee_value = nullptr; std::vector<llvm::Value*> call_args; - if (use_string_lookup) { - callee_ftype = ftype_tvm_func_call_; + if (use_env_lookup) { + callee_ftype = ftype_tvm_ffi_func_call_; callee_value = RuntimeTVMFuncCall(); call_args.push_back(GetPackedFuncHandle(func_name)); - call_args.insert(call_args.end(), - {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result}); } else { - callee_ftype = ftype_tvm_backend_packed_c_func_; + callee_ftype = ftype_tvm_ffi_c_func_; callee_value = module_->getFunction(func_name); if (callee_value == nullptr) { - callee_value = - llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, - func_name, module_.get()); + callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage, + func_name, module_.get()); } - // NOTE: This is a bugfix to a previous coupled convention(in lower_tvm_builtin) - // The begin, end should correspond to the right location in cpacked excluding resource handle. - // TODO(tqchen): upstream the fix. - // nargs -= 1; - call_args.insert(call_args.end(), { - builder_->CreateBitCast(arg_value, t_void_p_), - arg_tcode.addr, - ConstInt32(nargs), - builder_->CreateBitCast(ret_value, t_void_p_), - ret_tcode.addr, - }); call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); + call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result}); } #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); @@ -917,8 +862,10 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* result_value = + builder_->CreateInBoundsGEP(t_tvm_ffi_any_, result, {ConstInt32(0), ConstInt32(2)}); llvm::Value* load_ptr = - builder_->CreatePointerCast(ret_value, llvmGetPointerTo(llvm_r_api_type, 0)); + builder_->CreatePointerCast(result_value, llvmGetPointerTo(llvm_r_api_type, 0)); #if TVM_LLVM_VERSION >= 110 llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); #elif TVM_LLVM_VERSION >= 80 @@ -931,11 +878,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& // Load the return type code. #if TVM_LLVM_VERSION >= 110 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); + llvm::Value* result_type_index = + builder_->CreateInBoundsGEP(t_tvm_ffi_any_, result, {ConstInt32(0), ConstInt32(0)}); + pc.ret_type_index = builder_->CreateAlignedLoad(t_int32_, result_type_index, llvm::Align(4)); #elif TVM_LLVM_VERSION >= 80 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); + pc.ret_type_index = builder_->CreateAlignedLoad(t_int32_, result_type_index, 8); #else - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); + pc.ret_type_index = builder_->CreateAlignedLoad(result_type_index, 8); #endif } @@ -943,21 +892,21 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& return pc; } -llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { - auto expected_num_args = use_string_lookup ? 5U : 6U; - ICHECK_EQ(op->args.size(), expected_num_args); - PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value, - op->args[4].as<IntImmNode>()->value, use_string_lookup); +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { + ICHECK_EQ(op->args.size(), 4U); + bool use_string_lookup = op->op.same_as(builtin::tvm_call_packed_lowered()); + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[2].as<IntImmNode>()->value, + op->args[3].as<IntImmNode>()->value, use_string_lookup); return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { - ICHECK_EQ(op->args.size(), 6U); - PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value, - op->args[4].as<IntImmNode>()->value, true); + ICHECK_EQ(op->args.size(), 5U); + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[2].as<IntImmNode>()->value, + op->args[3].as<IntImmNode>()->value, true); llvm::LLVMContext* ctx = llvm_target_->GetContext(); // Get traced value. - llvm::Value* traced_value = MakeValue(op->args[5]); + llvm::Value* traced_value = MakeValue(op->args[4]); // The update_block handles case when we need to update the return value. llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block", function_); // The continue_block handles case when we need to return original @@ -965,8 +914,8 @@ llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block", function_); // Check the ret_type_code and create cmp instruction. - llvm::Value* cmp = - builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + llvm::Value* cmp = builder_->CreateICmpNE( + pc.ret_type_index, llvm::ConstantInt::get(t_int_, ffi::TypeIndex::kTVMFFINone)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); @@ -979,7 +928,7 @@ llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { } llvm::Value* CodeGenCPU::RuntimeTVMFuncCall() { - if (f_tvm_func_call_ != nullptr) return f_tvm_func_call_; + if (f_tvm_ffi_func_call_ != nullptr) return f_tvm_ffi_func_call_; return GetContextPtr(gv_tvm_func_call_); } @@ -1022,11 +971,11 @@ void CodeGenCPU::AddStartupFunction() { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - return CreateCallPacked(op, true /* use_string_lookup */); + return CreateCallPacked(op); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { - return CreateCallPacked(op, false /* use_string_lookup */); + return CreateCallPacked(op); } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { @@ -1083,10 +1032,6 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { llvm::Value* num = ConstInt32(pval[0]); if (type == "shape") { return builder_->CreateAlloca(t_tvm_shape_index_, num); - } else if (type == "arg_value") { - return builder_->CreateAlloca(t_tvm_value_, num); - } else if (type == "arg_tcode") { - return builder_->CreateAlloca(t_int_, num); } else if (type == "tvm_ffi_any") { return builder_->CreateAlloca(t_tvm_ffi_any_, num); } else if (type == "array") { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 5187b3e3f2..03a0ad966e 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -87,14 +87,11 @@ class CodeGenCPU : public CodeGenLLVM { llvm::StructType* t_tvm_device_{nullptr}; llvm::StructType* t_tvm_type_{nullptr}; llvm::StructType* t_tvm_array_{nullptr}; - llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_ffi_any_{nullptr}; llvm::StructType* t_tvm_parallel_group_env_{nullptr}; - llvm::FunctionType* ftype_tvm_backend_packed_c_func_{nullptr}; - llvm::StructType* t_tvm_crt_func_registry_{nullptr}; - llvm::StructType* t_tvm_crt_module_{nullptr}; + llvm::FunctionType* ftype_tvm_ffi_c_func_{nullptr}; llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr}; - llvm::FunctionType* ftype_tvm_func_call_{nullptr}; + llvm::FunctionType* ftype_tvm_ffi_func_call_{nullptr}; llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr}; llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr}; llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr}; @@ -133,13 +130,13 @@ class CodeGenCPU : public CodeGenLLVM { // Make packed call. struct PackedCall { llvm::Value* ret_value; - llvm::Value* ret_tcode; + llvm::Value* ret_type_index; llvm::BasicBlock* end_block; }; PackedCall MakeCallPackedLowered(const Array<PrimExpr>& args, const DataType& r_type, const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. - llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); + llvm::Value* CreateCallPacked(const CallNode* op); // Create trace call into tvm packed function. llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization @@ -166,7 +163,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; std::unordered_map<String, llvm::GlobalVariable*> gv_func_map_; // context for direct dynamic lookup - llvm::Function* f_tvm_func_call_{nullptr}; + llvm::Function* f_tvm_ffi_func_call_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr}; llvm::Function* f_tvm_api_set_last_error_{nullptr}; llvm::Function* f_tvm_parallel_launch_{nullptr}; diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 1bfdcb6ac3..7c35b2ead6 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -101,7 +101,6 @@ class CodeGenHexagon final : public CodeGenCPU { private: TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype) final; - TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); bool IsQHLFunction(const std::string& func); @@ -292,93 +291,6 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_pt value_dtype); } -CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, - llvm::Value* index, int kind) { - static const std::map<int, int> field_index = { - {builtin::kArrData, 0}, {builtin::kArrDeviceType, 1}, {builtin::kArrDeviceId, 1}, - {builtin::kArrNDim, 2}, {builtin::kArrTypeCode, 3}, {builtin::kArrTypeBits, 3}, - {builtin::kArrTypeLanes, 3}, {builtin::kArrShape, 4}, {builtin::kArrStrides, 5}, - {builtin::kArrByteOffset, 6}}; - static const std::map<int, int> subfield_index = { - {builtin::kArrDeviceType, 0}, {builtin::kArrDeviceId, 1}, {builtin::kArrTypeCode, 0}, - {builtin::kArrTypeBits, 1}, {builtin::kArrTypeLanes, 2}, - }; - - if (kind < builtin::kArrKindBound_) { - if (buf->getType() == t_void_p_) { - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_array_, 0)); - } else { - ICHECK_EQ(buf->getType(), llvmGetPointerTo(t_tvm_array_, 0)); - } - /* The following "kinds" are accessing the members of DLTensor: - typedef struct { - void* data; kArrData - DLDevice device; kArrDeviceType (device.device_type) - kArrDeviceId (device.device_id) - int ndim; kArrNDim - DLDataType dtype; kArrTypeCode (dtype.code) - kArrTypeBits (dtype.bits) - kArrTypeLanes (dtype.lanes) - int64_t* shape; kArrShape - int64_t* strides; kArrStrides - uint64_t byte_offset; kArrByteOffset - } DLTensor; - */ - llvm::Value* base_gep = builder_->CreateInBoundsGEP(t_tvm_array_, buf, index, "base_gep"); - if (kind == builtin::kArrAddr) { - return TypedPointer(t_void_p_, base_gep); - } - llvm::Value* field_gep = builder_->CreateInBoundsGEP( - t_tvm_array_, base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep"); - llvm::Type* field_type = t_tvm_array_->getStructElementType(field_index.at(kind)); - switch (kind) { - // These fields have no sub-fields. - case builtin::kArrData: - case builtin::kArrNDim: - case builtin::kArrShape: - case builtin::kArrStrides: - case builtin::kArrByteOffset: - return TypedPointer(field_type, field_gep); - } - llvm::Value* subfield_gep = builder_->CreateInBoundsGEP( - field_type, field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, - "subfield_gep"); - llvm::Type* subfield_type = field_type->getStructElementType(subfield_index.at(kind)); - return TypedPointer(subfield_type, subfield_gep); - } - - if (kind == builtin::kTVMValueContent) { - /* TVMValue is a union: - typedef union { - int64_t v_int64; - double v_float64; - void* v_handle; - const char* v_str; - TVMType v_type; - DLDevice v_device; - } TVMValue; - */ - ICHECK_EQ(t.lanes(), 1); - ICHECK(t.is_handle() || t.bits() == 64); - if (t.is_int()) { - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_int64_, 0)); - return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); - } else if (t.is_float()) { - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_float64_, 0)); - return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); - } else { - ICHECK(t.is_handle()); - buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_value_, 0)); - buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); - return TypedPointer(t_void_p_, - builder_->CreatePointerCast(buf, llvmGetPointerTo(t_void_p_, 0))); - } - } - - assert(!"Unknown kind"); - return TypedPointer(); -} - llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID, llvm::ArrayRef<llvm::Value*> args) { #if TVM_LLVM_VERSION >= 200 diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index db5eaac50b..ad9456a1e9 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -601,15 +601,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (auto opt_call_op = op->op.as<Op>()) { auto call_op = opt_call_op.value(); - if (op->op.same_as(builtin::tvm_check_return())) { - const CallNode* call = op->args[2].as<CallNode>(); - os << "if ("; - VisitExpr_(call, os); - os << " != "; - PrintExpr(op->args[0], os); - os << " ) return "; - PrintExpr(op->args[1], os); - } else if (op->op.same_as(builtin::ret())) { + if (op->op.same_as(builtin::ret())) { os << "return "; PrintExpr(op->args[0], os); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { @@ -1182,9 +1174,20 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { } else if (call->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(call->args.size(), 4); int kind = call->args[2].as<IntImmNode>()->value; - std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind); + DataType store_dtype = call->args[3].dtype(); + std::string ref = GetStructRef(store_dtype, call->args[0], call->args[1], kind); std::string value = PrintExpr(call->args[3]); std::string cast; + + if (kind == builtin::kTVMFFIAnyUnionValue && + (store_dtype.bits() < 64 || store_dtype.is_handle())) { + this->PrintIndent(); + // when we set any union value, we need to be careful to + // clear off the union value to zero if the set size is less than 64 bits + this->stream << GetStructRef(DataType::Int(64), call->args[0], call->args[1], kind) + << " = 0;\n"; + } + if (kind == builtin::kArrStrides) { // cast void* to int64_t* cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : ""; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 4810fdf003..c1859fa436 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -82,8 +82,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, PrintType(func->ret_type, stream); stream << " " << tvm::runtime::symbol::tvm_module_main << "(void* self, void* args,int num_args, void* result) {\n"; - stream << " return " << global_symbol.value() - << "(self, args, num_args, result);\n"; + stream << " return " << global_symbol.value() << "(self, args, num_args, result);\n"; stream << "}\n"; } } @@ -219,47 +218,38 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name, this->stream << "}\n"; } -void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_args) { - this->PrintIndent(); - std::string ret_val = name_supply_->FreshName("ret_val"); - std::string ret_type_code = name_supply_->FreshName("ret_type_code"); - this->stream << "TVMValue " << ret_val << ";\n"; - this->PrintIndent(); - this->stream << "int " << ret_type_code << ";\n"; - this->PrintIndent(); - this->stream << "if (TVMFuncCall(" << packed_func_name << ", " - << "(TVMValue*) stack_value" - << ", " - << "(int*) stack_tcode" - << ", " << num_args << ", " - << "&" << ret_val << ", " - << "&" << ret_type_code << ") != 0) {\n"; - int func_call_scope = this->BeginScope(); - this->PrintIndent(); - this->stream << "return -1;\n"; - this->EndScope(func_call_scope); - this->PrintIndent(); - this->stream << "}\n"; -} +void CodeGenCHost::PrintCallPacked(const CallNode* op) { + const StringImmNode* func_name = op->args[0].as<StringImmNode>(); + ICHECK(func_name != nullptr) + << "tvm_call_[c]packed_lowered expects first argument as function name"; + int64_t begin = op->args[3].as<IntImmNode>()->value; + int64_t end = op->args[4].as<IntImmNode>()->value; + int64_t num_args = end - begin; + ICHECK_GE(num_args, 0); -void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args, - const std::string& resource_handle_name) { - this->PrintIndent(); - std::string ret_val = name_supply_->FreshName("ret_val"); - std::string ret_type_code = name_supply_->FreshName("ret_type_code"); - this->stream << "TVMValue " << ret_val << ";\n"; + std::string packed_func_name; + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + packed_func_name = GetPackedName(op); + this->PrintGetFuncFromBackend(func_name->value, packed_func_name); + } else { + // directly use the original symbol + ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); + packed_func_name = func_name->value; + } + + std::string args_stack = PrintExpr(op->args[1]); this->PrintIndent(); - this->stream << "int " << ret_type_code << ";\n"; + std::string result = name_supply_->FreshName("result"); + this->stream << "TVMFFIAny " << result << ";\n"; this->PrintIndent(); - this->stream << "if (" << packed_func_name << "( " - << "(TVMValue*) stack_value " - << ", " - << "(int*) stack_tcode" - << ", " << num_args << ", " - << "&" << ret_val << ", " - << "&" << ret_type_code << ", " << resource_handle_name << ") != 0){\n"; - + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + this->stream << "if (TVMFFIFuncCall(" << packed_func_name << ", "; + } else { + this->stream << "if (" << packed_func_name << "(NULL, "; + } + this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", " + << "&" << result << ") != 0) {\n"; int func_call_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -285,43 +275,6 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) { return unique_name; } -CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, - bool has_resource_handle) { - const StringImmNode* s = op->args[0].as<StringImmNode>(); - ICHECK(s != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name"; - int64_t begin = op->args[3].as<IntImmNode>()->value; - int64_t end = op->args[4].as<IntImmNode>()->value; - int64_t num_args = end - begin; - ICHECK_GE(num_args, 0); - std::string func_name = s->value; - - if (has_resource_handle) { - const StringImmNode* resource_handle_var = op->args[5].as<StringImmNode>(); - if (resource_handle_var != nullptr) { - std::string resource_handle_name = resource_handle_var->value; - return {func_name, num_args - 1, resource_handle_name}; - } else { - // The final arg should be "(void*) NULL" to indicate the empty resource_handle. - num_args--; - - const CallNode* reinterpret_call = op->args[5].as<CallNode>(); - ICHECK_NE(reinterpret_call, (void*)nullptr) - << "At CallNode to " << s - << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " - << "reinterpret(0); got: " << op->args[5]; - ICHECK_EQ(reinterpret_call->op, builtin::reinterpret()) - << "At CallNode to " << s - << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " - << "reinterpret(0); got: " << op->args[5]; - ICHECK(is_zero(reinterpret_call->args[0])) << "At CallNode to " << s - << " arg 5: Expect either StringImm naming the " - "resource_handle var from interface API, or " - << "zero; got " << op->args[5]; - } - } - return {func_name, num_args, "NULL"}; -} - void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::tvm_stack_alloca())) { std::string stack_name = name_supply_->FreshName("stack"); @@ -348,14 +301,9 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; os << stack_name; } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - auto function_info = GetFunctionInfo(op, false /* has_resource_handle */); - std::string func_name_packed = GetPackedName(op); - this->PrintGetFuncFromBackend(function_info.func_name, func_name_packed); - this->PrintFuncCall(func_name_packed, function_info.num_args); + this->PrintCallPacked(op); } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { - auto function_info = GetFunctionInfo(op, true /* has_resource_handle */); - this->PrintFuncCallC(function_info.func_name, function_info.num_args, - function_info.resource_handle_name); + this->PrintCallPacked(op); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PrintIndent(); this->stream << "return -1;\n"; @@ -371,7 +319,8 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*) stream << "if (!(" << cond << ")) {\n"; int assert_if_scope = this->BeginScope(); PrintIndent(); - stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n"; + stream << "TVMFFISetLastErrorCStr(\"RuntimeError\", \"" + << op->message.as<StringImmNode>()->value << "\");\n"; PrintIndent(); stream << "return -1;\n"; this->EndScope(assert_if_scope); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 3e013492ef..4a2f530e2f 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -74,15 +74,6 @@ class CodeGenCHost : public CodeGenC { Array<String> GetFunctionNames() { return function_names_; } private: - /* \brief Internal structure to store information about function calls */ - struct FunctionInfo { - /* \brief function name */ - std::string func_name; - /* number of arguments required by the function */ - int64_t num_args; - /* \brief name of resource_handle to pass */ - std::string resource_handle_name; - }; std::string module_name_; /* \brief mapping global packed func to the unique name */ std::unordered_map<std::string, std::string> declared_globals_; @@ -93,13 +84,9 @@ class CodeGenCHost : public CodeGenC { /*! \brief whether to emit forwared function declarations in the resulting C code */ bool emit_fwd_func_decl_; - FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle); std::string GetPackedName(const CallNode* op); void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); - void PrintFuncCall(const std::string& packed_func_name, int num_args); - void PrintFuncCallC(const std::string& packed_func_name, int num_args, - const std::string& resource_handle_name); - + void PrintCallPacked(const CallNode* op); /*! * \brief Print ternary conditional operator implementing binary `op` * Forces the operands to be in SSA form. diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index f0894acecf..a336688622 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -208,10 +208,6 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(tvm_check_return) - .set_num_inputs(3) - .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); - TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context) .set_num_inputs(1) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc deleted file mode 100644 index fed76876f6..0000000000 --- a/src/tir/transforms/legalize_packed_calls.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file make_packed_call.cc - * \brief Rewrite packed calls in AOT so that the arguments are packed - */ -#include <tvm/tir/builtin.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/function.h> -#include <tvm/tir/op.h> -#include <tvm/tir/stmt_functor.h> -#include <tvm/tir/transform.h> - -#include <unordered_map> - -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -using InputMap = - std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>; -/** - * This is a legalization pass only used in AOT. Traverse the TIR graph to legalize - * packed calls by making its argument wrapped in TVMValues (by using tvm_set_struct built-in) - */ -class PackedCallLegalizer : public StmtExprMutator { - public: - PackedCallLegalizer(IRModule m, const InputMap& inputs) : mod_{m}, inputs_{inputs} {} - - Stmt Legalize(tir::Stmt body) { return StmtExprMutator::VisitStmt(body); } - - Stmt VisitStmt_(const EvaluateNode* op) final { - if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); - const CallNode* call = op->value.as<CallNode>(); - // Given a packed call f(A,B,C), we need a set of new statements - // let A_packed = set_struct(tvm_value1, A) - // let B_packed = set_struct(tvm_value2, B) - // let C_packed = set_struct(tvm_value3, C) - // call_packed(f, A_packed, B_packed, C_packed) - if (call) { - if (call->op.same_as(builtin::tvm_call_cpacked())) { - Array<PrimExpr> packed_args{call->args[0]}; - VLOG(2) << "Legalize call:" << call; - BaseFunc base_func = mod_->Lookup(Downcast<StringImm>(call->args[0])->value); - const PrimFuncNode* prim_func = base_func.as<PrimFuncNode>(); - VLOG(2) << " to func " << base_func; - for (unsigned i = 1; i < call->args.size() - 1; i++) { - // No need to pack inputs of the prim_func - if (inputs_[call->args[i]] == true) { - packed_args.push_back(call->args[i]); - } else { - // Stack-allocate a DLTensor for this parameter. Note that LowerTVMBuiltin will collect - // all such stack-allocated tensors and minimize the storage needed by reusing - // DLTensors. - Array<PrimExpr> call_args{call->args[i]}; - tvm::runtime::Map<tvm::tir::Var, tvm::tir::Buffer>::iterator param_buf_it; - if (prim_func != nullptr) { - auto param_var = prim_func->params[i - 1]; - param_buf_it = prim_func->buffer_map.find(param_var); - } - if (prim_func != nullptr && param_buf_it != prim_func->buffer_map.end()) { - Buffer param = (*param_buf_it).second; - PrimExpr shape = tvm::tir::Call( - DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); - Cast var_type(param->dtype, IntImm(DataType::Int(32), 0)); - call_args.push_back(shape /* shape */); - call_args.push_back(make_zero(DataType::Handle()) /* strides */); - call_args.push_back(tvm::IntImm(DataType::UInt(32), param->shape.size()) /* ndim */); - call_args.push_back(var_type /* carries dtype */); - call_args.push_back(param->elem_offset /* elem_offset */); - } else { - // When the PrimFunc cannot be found, most DLTensor information cannot be populated. - PrimExpr shape = tvm::tir::Call( - DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), Array<PrimExpr>()); - Cast var_type(DataType::Handle(), IntImm(DataType::Int(32), 0)); - call_args.push_back(shape /* shape */); - call_args.push_back(make_zero(DataType::Handle()) /* strides */); - call_args.push_back(tvm::IntImm(DataType::UInt(32), 0) /* ndim */); - call_args.push_back(var_type /* carries dtype */); - call_args.push_back(tvm::IntImm(DataType::UInt(64), 0) /* elem_offset */); - } - packed_args.push_back(tvm::tir::Call( - DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), call_args)); - } - } - packed_args.push_back(call->args[call->args.size() - 1]); // push device_context - // Evaluate the packed call - return tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)); - } - } - return StmtExprMutator::VisitStmt_(op); - } - - private: - IRModule mod_; - InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. -}; - -namespace transform { - -Pass LegalizePackedCalls() { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - - // Note which Var are inputs and exclude them from packing. - InputMap inputs; - for (auto i : f->params) { - inputs[i] = true; - } - n->body = PackedCallLegalizer(m, inputs).Legalize(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.LegalizePackedCalls").set_body_typed(LegalizePackedCalls); -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 1cde4f2ebe..0931edcc2e 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -94,8 +94,7 @@ class BuiltinLower : public StmtExprMutator { struct AllocaScope { Buffer stack_shape; Var stack_array = Var("stack_array", DataType::Handle()); - Var stack_value = Var("stack_value", DataType::Handle()); - Buffer stack_tcode; + Var stack_ffi_any = Var("stack_ffi_any", DataType::Handle()); StackSizes max_sizes; StackSizes run_sizes; @@ -127,8 +126,6 @@ class BuiltinLower : public StmtExprMutator { auto& scope = precheck.alloca_scope_.back(); scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); - scope.stack_tcode = - decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); } precheck.VisitStmt(stmt); @@ -168,13 +165,7 @@ class BuiltinLower : public StmtExprMutator { } if (scope.max_sizes.arg_stack != 0) { - scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, - DataType::Int(32), "stack_tcode"); - stmt = - LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); - - stmt = DeclBuffer(scope.stack_tcode, stmt); - stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), + stmt = LetStmt(scope.stack_ffi_any, StackAlloca("tvm_ffi_any", scope.max_sizes.arg_stack), stmt); } } @@ -337,20 +328,17 @@ class BuiltinLower : public StmtExprMutator { PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(), - /* use_string_lookup */ true, /* use_last_value_as_traced_value*/ false); } else if (op->op.same_as(builtin::tvm_call_cpacked())) { return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(), - /* use_string_lookup */ false, /* use_last_value_as_traced_value*/ false); } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(), - /* use_string_lookup */ true, /* use_last_value_as_traced_value*/ true); } else if (op->op.same_as(builtin::anylist_setitem_call_packed())) { - return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true); + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered()); } else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) { - return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_cpacked_lowered(), false); + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_cpacked_lowered()); } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { return MakeShape(op); } else if (op->op.same_as(builtin::tvm_stack_make_array())) { @@ -490,57 +478,67 @@ class BuiltinLower : public StmtExprMutator { return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); } - void SetPackedArg(PrimExpr arg, const Var& value_stack, const Buffer& tcode_stack, - size_t stack_offset, std::vector<tir::Stmt>* prep_seq) { + void SetPackedArg(PrimExpr arg, const Var& args_stack, size_t stack_offset, + std::vector<tir::Stmt>* prep_seq) { auto* call_pattern = arg.as<CallNode>(); if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) { // call runtime function to set anylist - prep_seq->emplace_back( - Evaluate(Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"), - {call_pattern->args[0], call_pattern->args[1], value_stack, - tcode_stack->data, ConstInt32(stack_offset)}))); + prep_seq->emplace_back(Evaluate(Call( + DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"), + {call_pattern->args[0], call_pattern->args[1], args_stack, ConstInt32(stack_offset)}))); } else { - DataType api_type = APIType(arg.dtype()); - if (arg.dtype() != api_type) { - arg = Cast(api_type, arg); - } - prep_seq->emplace_back( - TVMStructSet(value_stack, stack_offset, builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - if (api_type.is_handle() && arg.as<StringImmNode>()) { - arg_tcode = kTVMStr; - } else if (IsArrayHandle(arg)) { - arg_tcode = kTVMDLTensorHandle; - } else if (arg.dtype().is_bool()) { - arg_tcode = kTVMArgBool; + DataType api_dtype = APIType(arg.dtype()); + if (arg.dtype() != api_dtype) { + arg = Cast(api_dtype, arg); } + + int arg_type_index = [&]() { + if (api_dtype.is_bool()) return ffi::TypeIndex::kTVMFFIBool; + if (api_dtype.is_int() || api_dtype.is_uint()) return ffi::TypeIndex::kTVMFFIInt; + if (api_dtype.is_float()) return ffi::TypeIndex::kTVMFFIFloat; + if (api_dtype.is_handle() && arg.as<StringImmNode>()) { + return ffi::TypeIndex::kTVMFFIRawStr; + } else if (IsArrayHandle(arg)) { + return ffi::TypeIndex::kTVMFFIDLTensorPtr; + } else if (api_dtype.is_handle()) { + return ffi::TypeIndex::kTVMFFIOpaquePtr; + } else { + LOG(FATAL) << "Unsupported type: " << api_dtype; + } + }(); + // opaque handle need to set the kind properly - if (arg_tcode == kTVMOpaqueHandle) { - prep_seq->emplace_back(IfThenElse( - Call(DataType::Bool(), builtin::isnullptr(), {arg}), - BufferStore(tcode_stack, ConstInt32(kTVMNullptr), {ConstInt32(stack_offset)}), - BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)}))); - } else { + if (arg_type_index == ffi::TypeIndex::kTVMFFIOpaquePtr) { prep_seq->emplace_back( - BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)})); + IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {arg}), + TVMStructSet(args_stack, stack_offset, builtin::kTVMFFIAnyTypeIndex, + ConstInt32(ffi::TypeIndex::kTVMFFINone)), + TVMStructSet(args_stack, stack_offset, builtin::kTVMFFIAnyTypeIndex, + ConstInt32(ffi::TypeIndex::kTVMFFIOpaquePtr)))); + } else { + prep_seq->emplace_back(TVMStructSet(args_stack, stack_offset, builtin::kTVMFFIAnyTypeIndex, + ConstInt32(arg_type_index))); } + // handle arg value + // NOTE: the intrinsic codegen will handle padding value clear for 32bit + // types or types that are smaller than 64 bits. + prep_seq->emplace_back( + TVMStructSet(args_stack, stack_offset, builtin::kTVMFFIAnyUnionValue, arg)); } } - PrimExpr MakeAnyListSetItemCallPacked(const CallNode* op, const Op& lowered_op, - bool use_string_lookup) { + PrimExpr MakeAnyListSetItemCallPacked(const CallNode* op, const Op& lowered_op) { PrimExpr list_handle = op->args[0]; PrimExpr list_index = op->args[1]; - Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup, false); - PrimExpr value_stack = call->args[1]; - PrimExpr tcode_stack = call->args[2]; + Call call = MakeCallPackedGeneric(op, 2, lowered_op, false); + PrimExpr args_stack = call->args[1]; // The stack offset of return value stack_end - PrimExpr ret_offset = call->args[4]; + PrimExpr ret_offset = call->args[3]; auto& prep_seq = prep_seq_stack_.back(); prep_seq.emplace_back(Evaluate(call)); return Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListMoveFromPackedReturn"), - {list_handle, list_index, value_stack, tcode_stack, ret_offset}); + {list_handle, list_index, args_stack, ret_offset}); } /*! * \brief Generic tool to make low-level @@ -549,11 +547,10 @@ class BuiltinLower : public StmtExprMutator { * \param op The call * \param name_offset The beginning of function name and call packed section. * \param lowered_packed_op The target lowered op. - * \param use_string_lookup Whether to lookup function by string. * \param pass_last_arg_as_traced_value Whether to pass last argument as traced value */ Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op, - bool use_string_lookup, bool pass_last_arg_as_traced_value) { + bool pass_last_arg_as_traced_value) { auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); @@ -564,10 +561,6 @@ class BuiltinLower : public StmtExprMutator { size_t args_begin = name_offset + 1; size_t args_end = op->args.size(); - // cpacked expects a resource_handle parameter - if (!use_string_lookup) { - --args_end; - } size_t num_args = args_end - args_begin; // The extra one slot is for return value. @@ -577,9 +570,14 @@ class BuiltinLower : public StmtExprMutator { op = expr.as<CallNode>(); for (size_t i = 0; i < num_args; ++i) { - this->SetPackedArg(op->args[args_begin + i], scope.stack_value, scope.stack_tcode, - arg_stack_begin + i, &prep_seq); + this->SetPackedArg(op->args[args_begin + i], scope.stack_ffi_any, arg_stack_begin + i, + &prep_seq); } + // explicitly set return value to None to avoid bad state interpretation + prep_seq.emplace_back(TVMStructSet(scope.stack_ffi_any, num_args, builtin::kTVMFFIAnyTypeIndex, + ConstInt32(ffi::TypeIndex::kTVMFFINone))); + prep_seq.emplace_back(TVMStructSet(scope.stack_ffi_any, num_args, builtin::kTVMFFIAnyUnionValue, + make_zero(DataType::Int(64)))); // Verify stack size matches earlier value. if (is_precheck_) { scope.UpdateMax(); @@ -589,21 +587,10 @@ class BuiltinLower : public StmtExprMutator { scope.run_sizes.shape_stack = restore_shape_stack; scope.run_sizes.array_stack = restore_array_stack; scope.run_sizes.arg_stack = arg_stack_begin; - Array<PrimExpr> packed_args = {op->args[name_offset], scope.stack_value, - scope.stack_tcode->data, ConstInt32(arg_stack_begin), + Array<PrimExpr> packed_args = {op->args[name_offset], scope.stack_ffi_any, + ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + num_args)}; - // cpacked call resource_handle - if (!use_string_lookup) { - ICHECK(!pass_last_arg_as_traced_value); - PrimExpr last_arg = op->args[args_end]; - const VarNode* var_node = last_arg.as<VarNode>(); - if (var_node != nullptr) { - tir::Var resource_handle = GetRef<Var>(var_node); - packed_args.push_back(StringImm(resource_handle->name_hint)); - } else { - packed_args.push_back(last_arg); - } - } else if (pass_last_arg_as_traced_value) { + if (pass_last_arg_as_traced_value) { // pass in last element as traced value // used by call_packed_traced packed_args.push_back(op->args[op->args.size() - 1]);
