This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit d28b90db6aa189ae991eda5b71fd48d4df10e7fd Author: tqchen <[email protected]> AuthorDate: Sat May 3 08:08:42 2025 -0400 [FFI] Clarify and upgrade the function ffi --- ffi/include/tvm/ffi/c_api.h | 44 +++++++++++++++++------ ffi/include/tvm/ffi/container/container_details.h | 2 +- ffi/include/tvm/ffi/container/tuple.h | 4 +-- ffi/include/tvm/ffi/error.h | 2 +- ffi/include/tvm/ffi/function.h | 22 ++++++------ ffi/src/ffi/function.cc | 7 ++-- ffi/src/ffi/traceback.h | 5 +-- ffi/tests/cpp/test_array.cc | 3 +- ffi/tests/cpp/testing_object.h | 1 - include/tvm/ir/name_supply.h | 2 +- include/tvm/runtime/packed_func.h | 4 +-- python/tvm/ffi/cython/base.pxi | 12 +++---- python/tvm/ffi/cython/function.pxi | 18 ++++++---- src/arith/analyzer.cc | 21 +++++------ src/meta_schedule/utils.h | 2 +- src/relax/transform/tuning_api/database.cc | 4 +-- src/runtime/dso_library.cc | 9 ++--- src/runtime/library_module.cc | 4 +-- src/runtime/rpc/rpc_local_session.cc | 2 +- src/runtime/rpc/rpc_module.cc | 2 +- src/target/llvm/codegen_cpu.cc | 13 +++---- src/target/llvm/codegen_cpu.h | 2 +- src/target/source/codegen_c_host.cc | 2 +- 23 files changed, 106 insertions(+), 81 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 5d7e2ca679..6569050b9e 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -106,7 +106,7 @@ typedef enum { /*! \brief Error object. */ kTVMFFIError = 67, /*! \brief Function object. */ - kTVMFFIFunc = 68, + kTVMFFIFunction = 68, /*! \brief Array object. */ kTVMFFIArray = 69, /*! \brief Map object. */ @@ -232,11 +232,14 @@ typedef struct { * Safe call explicitly catches exception on function boundary. * * \param self The function handle - * \param num_args Number if input arguments + * \param num_args Number of input arguments * \param args The input arguments to the call. - * \param result Store output result, the result must not initially contain object value. + * \param result Store output result. 
* - * \return The call return 0 if call is successful. + * IMPORTANT: caller must initialize result->type_index to be kTVMFFINone, + * or any other value smaller than kTVMFFIStaticObjectBegin. + * + * \return The call returns 0 if call is successful. * It returns non-zero value if there is an error. * * Possible return error of the API functions: @@ -244,16 +247,28 @@ typedef struct { * * -1: error happens, can be retrieved by TVMFFIErrorMoveFromRaised * * -2: a frontend error occurred and recorded in the frontend. * - * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaisedByConsume + * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaised * for C function error propagation. This design choice, while * introducing a dependency for TLS runtime, simplifies error * propgation in chains of calls in compiler codegen. * As we do not need to propagate error through argument but simply * set them in the runtime environment. + * + * \sa TVMFFIErrorMoveFromRaised + * \sa TVMFFIErrorSetRaised + * \sa TVMFFIErrorSetRaisedByCStr */ typedef int (*TVMFFISafeCallType)(void* self, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result); +/*! + * \brief Object cell for function object. + */ +typedef struct { + /*! \brief A C API compatible call with exception catching. */ + TVMFFISafeCallType safe_call; +} TVMFFIFunctionCell; + /*! * \brief Getter that can take address of a field and set the result. * \param field The raw address of the field. @@ -401,11 +416,11 @@ TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* ou * \param func The resource handle of the C callback. * \param args The input arguments to the call. * \param num_args The number of input arguments. - * \param result The output result. + * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. 
* \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIFuncCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result); +TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result); /*! * \brief Register the function to runtime's global table. @@ -417,7 +432,7 @@ TVM_FFI_DLL int TVMFFIFuncCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t * \param override Whether allow override already registered function. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override); +TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const char* name, TVMFFIObjectHandle f, int override); /*! * \brief Get a global function. @@ -426,7 +441,7 @@ TVM_FFI_DLL int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int * \param out the result function pointer, NULL if it does not exist. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out); +TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const char* name, TVMFFIObjectHandle* out); /*! * \brief Move the last error from the environment to result. @@ -659,6 +674,15 @@ inline TVMFFIErrorInfo* TVMFFIErrorGetErrorInfoPtr(TVMFFIObjectHandle obj) { return reinterpret_cast<TVMFFIErrorInfo*>(reinterpret_cast<char*>(obj) + sizeof(TVMFFIObject)); } +/*! + * \brief Get the data pointer of a function cell from a function object. + * \param obj The object handle. + * \return The data pointer. + */ +inline TVMFFIFunctionCell* TVMFFIFunctionGetFunctionCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast<TVMFFIFunctionCell*>(reinterpret_cast<char*>(obj) + sizeof(TVMFFIObject)); +} + /*! * \brief Get the data pointer of a shape array from a shape object. * \param obj The object handle. 
diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h index c842218815..51e130f373 100644 --- a/ffi/include/tvm/ffi/container/container_details.h +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -281,7 +281,7 @@ inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T * \tparam T The type to check. * \return True if T is compatible with Any, false otherwise. */ -template <typename ...T> +template <typename... T> inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...); /** diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 1fff225aed..260237a08c 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -33,7 +33,6 @@ namespace tvm { namespace ffi { - /*! * \brief Typed tuple like std::tuple backed by ArrayObj container. * @@ -44,7 +43,8 @@ namespace ffi { template <typename... Types> class Tuple : public ObjectRef { public: - static_assert(details::all_storage_enabled_v<Types...>, "All types used in Tuple<...> must be compatible with Any"); + static_assert(details::all_storage_enabled_v<Types...>, + "All types used in Tuple<...> must be compatible with Any"); Tuple() : ObjectRef(MakeDefaultTupleNode()) {} Tuple(const Tuple<Types...>& other) : ObjectRef(other) {} diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index c39b0ff64e..4e8603a5e3 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -134,7 +134,7 @@ class ErrorBuilder { public: explicit ErrorBuilder(std::string kind, std::string traceback, bool log_before_throw) : kind_(kind), traceback_(traceback), log_before_throw_(log_before_throw) {} - + // MSVC disable warning in error builder as it is exepected #ifdef _MSC_VER #pragma disagnostic push diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index f6baca5999..913237ec04 100644 --- 
a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -81,19 +81,18 @@ namespace ffi { * \brief Object container class that backs ffi::Function * \note Do not use this function directly, use ffi::Function */ -class FunctionObj : public Object { +class FunctionObj : public Object, public TVMFFIFunctionCell { public: typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*); - /*! \brief A C++ style call implementation */ + using TVMFFIFunctionCell::safe_call; + /*! \brief A C++ style call implementation, with exception propagation in c++ style. */ FCall call; - /*! \brief A C API compatible call with exception catching. */ - TVMFFISafeCallType safe_call; TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { this->call(this, args, num_args, result); } - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunc; + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; static constexpr const char* _type_key = "object.Function"; TVM_FFI_DECLARE_STATIC_OBJECT_INFO(FunctionObj, Object); @@ -105,8 +104,7 @@ class FunctionObj : public Object { // Implementing safe call style static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { TVM_FFI_SAFE_CALL_BEGIN(); - result->type_index = kTVMFFINone; - result->v_int64 = 0; + TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); FunctionObj* self = static_cast<FunctionObj*>(func); self->call(self, reinterpret_cast<const AnyView*>(args), num_args, reinterpret_cast<Any*>(result)); @@ -133,8 +131,8 @@ class FunctionObjImpl : public FunctionObj { * \param callable The type-erased callable object. 
*/ explicit FunctionObjImpl(TCallable callable) : callable_(callable) { - this->call = Call; this->safe_call = SafeCall; + this->call = Call; } private: @@ -378,7 +376,7 @@ class Function : public ObjectRef { */ static std::optional<Function> GetGlobal(const char* name) { TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFuncGetGlobal(name, &handle)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(name, &handle)); if (handle != nullptr) { return Function( details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle))); @@ -421,7 +419,7 @@ class Function : public ObjectRef { */ static void SetGlobal(const char* name, Function func, bool override = false) { TVM_FFI_CHECK_SAFE_CALL( - TVMFFIFuncSetGlobal(name, details::ObjectUnsafe::GetHeader(func.get()), override)); + TVMFFIFunctionSetGlobal(name, details::ObjectUnsafe::GetHeader(func.get()), override)); } /*! * \brief List all global names @@ -706,7 +704,7 @@ inline constexpr bool use_default_type_traits_v<TypedFunction<FType>> = false; template <typename FType> struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunc; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; static TVM_FFI_INLINE void CopyToAnyView(const TypedFunction<FType>& src, TVMFFIAny* result) { TypeTraits<Function>::CopyToAnyView(src.packed(), result); @@ -717,7 +715,7 @@ struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase { } static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIFunc; + return src->type_index == TypeIndex::kTVMFFIFunction; } static TVM_FFI_INLINE TypedFunction<FType> CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc index 6fc959ff60..483c022566 100644 --- a/ffi/src/ffi/function.cc +++ b/ffi/src/ffi/function.cc @@ -241,7 +241,7 @@ int 
TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override) { +int TVMFFIFunctionSetGlobal(const char* name, TVMFFIObjectHandle f, int override) { using namespace tvm::ffi; TVM_FFI_SAFE_CALL_BEGIN(); GlobalFunctionTable::Global()->Update(name, GetRef<Function>(static_cast<FunctionObj*>(f)), @@ -249,7 +249,7 @@ int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out) { +int TVMFFIFunctionGetGlobal(const char* name, TVMFFIObjectHandle* out) { using namespace tvm::ffi; TVM_FFI_SAFE_CALL_BEGIN(); const Function* fp = GlobalFunctionTable::Global()->Get(name); @@ -262,7 +262,8 @@ int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIFuncCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { +int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result) { using namespace tvm::ffi; // NOTE: this is a tail call return reinterpret_cast<FunctionObj*>(func)->safe_call(func, args, num_args, result); diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h index 1314e00fb8..0c07361fb5 100644 --- a/ffi/src/ffi/traceback.h +++ b/ffi/src/ffi/traceback.h @@ -106,10 +106,7 @@ inline bool ShouldExcludeFrame(const char* filename, const char* symbol) { */ inline bool ShouldStopTraceback(const char* filename, const char* symbol) { if (symbol != nullptr) { - if (strncmp(symbol, "TVMFFIFuncCall", 14) == 0) { - return true; - } - if (strncmp(symbol, "TVMFuncCall", 11) == 0) { + if (strncmp(symbol, "TVMFFIFunctionCall", 18) == 0) { return true; } // Python interpreter stack frames diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc index 5062f6dd2d..b268cf1039 100644 --- a/ffi/tests/cpp/test_array.cc +++
b/ffi/tests/cpp/test_array.cc @@ -72,7 +72,8 @@ TEST(Array, Map) { // Basic functionality TInt x(1), y(1); Array<TInt> var_arr{x, y}; - Array<TNumber> expr_arr = var_arr.Map([](TInt var) -> TNumber { return TFloat(static_cast<double>(var->value + 1)); }); + Array<TNumber> expr_arr = + var_arr.Map([](TInt var) -> TNumber { return TFloat(static_cast<double>(var->value + 1)); }); EXPECT_NE(var_arr.get(), expr_arr.get()); EXPECT_TRUE(expr_arr[0]->IsInstance<TFloatObj>()); diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index d0db5ca094..69a91efc46 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -67,7 +67,6 @@ class TIntObj : public TNumberObj { TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value); - class TInt : public TNumber { public: explicit TInt(int64_t value) { data_ = make_object<TIntObj>(value); } diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 136c95741e..df0fba16cb 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -25,10 +25,10 @@ #define TVM_IR_NAME_SUPPLY_H_ #include <algorithm> +#include <cctype> #include <string> #include <unordered_map> #include <utility> -#include <cctype> #include "tvm/ir/expr.h" diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 192c9a883c..228ce731e0 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -95,7 +95,7 @@ inline TVMFFIAny LegacyTVMArgValueToFFIAny(TVMValue value, int type_code) { return res; } case kTVMPackedFuncHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIFunc; + res.type_index = ffi::TypeIndex::kTVMFFIFunction; res.v_obj = static_cast<TVMFFIObject*>(value.v_handle); return res; } @@ -215,7 +215,7 @@ inline void AnyViewToLegacyTVMArgValue(TVMFFIAny src, TVMValue* value, int* type value[0].v_handle = src.v_obj; break; } - case ffi::TypeIndex::kTVMFFIFunc: { + case ffi::TypeIndex::kTVMFFIFunction: { 
type_code[0] = kTVMPackedFuncHandle; value[0].v_handle = src.v_obj; break; diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 7fc73aa112..21be77e86f 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -57,7 +57,7 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIStr = 65 kTVMFFIBytes = 66 kTVMFFIError = 67 - kTVMFFIFunc = 68 + kTVMFFIFunction = 68 kTVMFFIArray = 69 kTVMFFIMap = 70 kTVMFFIShape = 71 @@ -135,13 +135,13 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil - int TVMFFIFuncCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) nogil - int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call, + int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result) nogil + int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void*), TVMFFIObjectHandle* out) nogil int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) nogil - int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int override) nogil - int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out) nogil + int TVMFFIFunctionSetGlobal(const char* name, TVMFFIObjectHandle f, int override) nogil + int TVMFFIFunctionGetGlobal(const char* name, TVMFFIObjectHandle* out) nogil void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil void TVMFFIErrorSetRaisedCStr(const char* kind, const char* message) nogil diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 4a9a26a637..ea96d700c6 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -145,7 +145,7 @@ cdef inline int FuncCall3(void* chandle, temp_args = [] make_args(args, &packed_args[0], temp_args) with nogil: - c_api_ret_code[0] 
= TVMFFIFuncCall( + c_api_ret_code[0] = TVMFFIFunctionCall( chandle, &packed_args[0], nargs, result ) return 0 @@ -168,7 +168,7 @@ cdef inline int FuncCall(void* chandle, make_args(args, &packed_args[0], temp_args) with nogil: - c_api_ret_code[0] = TVMFFIFuncCall(chandle, &packed_args[0], nargs, result) + c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result) return 0 @@ -179,6 +179,9 @@ cdef inline int ConstructorCall(void* constructor_handle, """Call contructor of a handle function""" cdef TVMFFIAny result cdef int c_api_ret_code + # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone + result.type_index = kTVMFFINone + result.v_int64 = 0 FuncCall(constructor_handle, args, &result, &c_api_ret_code) CHECK_CALL(c_api_ret_code) handle[0] = result.v_ptr @@ -196,6 +199,9 @@ class Function(Object): def __call__(self, *args): cdef TVMFFIAny result cdef int c_api_ret_code + # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone + result.type_index = kTVMFFINone + result.v_int64 = 0 FuncCall((<Object>self).chandle, args, &result, &c_api_ret_code) # NOTE: logic is same as check_call # directly inline here to simplify traceback @@ -205,7 +211,7 @@ class Function(Object): raise_existing_error() raise move_from_last_error().py_error() -_register_object_by_index(kTVMFFIFunc, Function) +_register_object_by_index(kTVMFFIFunction, Function) def _register_global_func(name, pyfunc, override): @@ -216,14 +222,14 @@ def _register_global_func(name, pyfunc, override): if not isinstance(pyfunc, Function): pyfunc = _convert_to_ffi_func(pyfunc) - CHECK_CALL(TVMFFIFuncSetGlobal(c_str(name), (<Object>pyfunc).chandle, ioverride)) + CHECK_CALL(TVMFFIFunctionSetGlobal(c_str(name), (<Object>pyfunc).chandle, ioverride)) return pyfunc def _get_global_func(name, allow_missing): cdef TVMFFIObjectHandle chandle - CHECK_CALL(TVMFFIFuncGetGlobal(c_str(name), &chandle)) + CHECK_CALL(TVMFFIFunctionGetGlobal(c_str(name), &chandle)) if chandle 
!= NULL: ret = Function.__new__(Function) (<Object>ret).chandle = chandle @@ -273,7 +279,7 @@ def _convert_to_ffi_func(object pyfunc): """Convert a python function to TVM FFI function""" cdef TVMFFIObjectHandle chandle Py_INCREF(pyfunc) - CHECK_CALL(TVMFFIFuncCreate( + CHECK_CALL(TVMFFIFunctionCreate( <void*>(pyfunc), tvm_ffi_callback, tvm_ffi_callback_deleter, diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index fa1f100a5f..3304ce7959 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -297,10 +297,9 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body_packed([](TVMArgs args, TVM } }); } else if (name == "rewrite_simplify") { - return PackedFunc( - [self](TVMArgs args, TVMRetValue* ret) { - *ret = self->rewrite_simplify(args[0].cast<PrimExpr>()); - }); + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + *ret = self->rewrite_simplify(args[0].cast<PrimExpr>()); + }); } else if (name == "get_rewrite_simplify_stats") { return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify.GetStatsCounters(); @@ -309,10 +308,9 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body_packed([](TVMArgs args, TVM return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); }); } else if (name == "canonical_simplify") { - return PackedFunc( - [self](TVMArgs args, TVMRetValue* ret) { - *ret = self->canonical_simplify(args[0].cast<PrimExpr>()); - }); + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + *ret = self->canonical_simplify(args[0].cast<PrimExpr>()); + }); } else if (name == "int_set") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0].cast<Var>()); }); @@ -339,10 +337,9 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body_packed([](TVMArgs args, TVM *ret = ffi::Function::FromPacked(fexit); }); } else if (name == "can_prove_equal") { - return PackedFunc( - [self](TVMArgs args, TVMRetValue* ret) { - *ret = 
self->CanProveEqual(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>()); - }); + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + *ret = self->CanProveEqual(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>()); + }); } else if (name == "get_enabled_extensions") { return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions()); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index aaebf3db7f..adf1334385 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -43,9 +43,9 @@ #include <tvm/tir/transform.h> #include <algorithm> +#include <sstream> #include <string> #include <unordered_set> -#include <sstream> #include <utility> #include <vector> diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index 55c97d0b85..87d9a76cfb 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -29,7 +29,6 @@ #include "../../../meta_schedule/utils.h" - namespace tvm { namespace relax { @@ -235,7 +234,8 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, n->workloads2idx_.reserve(n_objs); workloads.reserve(n_objs); for (int i = 0; i < n_objs; ++i) { - meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i].cast<ObjectRef>()); + meta_schedule::Workload workload = + meta_schedule::Workload::FromJSON(json_objs[i].cast<ObjectRef>()); n->workloads2idx_.emplace(workload, i); workloads.push_back(workload); } diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 2600f3fdb9..185b066e72 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -149,9 +149,10 @@ ObjectPtr<Library> CreateDSOLibraryObject(std::string library_path) { return n; } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body_typed([](std::string library_path) { - ObjectPtr<Library> n = 
CreateDSOLibraryObject(library_path); - return CreateModuleFromLibrary(n); -}); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so") + .set_body_typed([](std::string library_path, std::string) { + ObjectPtr<Library> n = CreateDSOLibraryObject(library_path); + return CreateModuleFromLibrary(n); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 13aca3df45..00c5d71996 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -70,7 +70,7 @@ class LibraryModuleNode final : public ModuleNode { PackedFunc WrapPackedFunc(TVMFFISafeCallType faddr, const ObjectPtr<Object>& sptr_to_self) { return ffi::Function::FromPacked([faddr, sptr_to_self](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_EQ(rv->type_index(), ffi::TypeIndex::kTVMFFINone); + ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast<const TVMFFIAny*>(args.data()), args.size(), reinterpret_cast<TVMFFIAny*>(rv))); }); @@ -82,7 +82,7 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) { *fp = FuncName; \ } // Initialize the functions - TVM_INIT_CONTEXT_FUNC(TVMFFIFuncCall); + TVM_INIT_CONTEXT_FUNC(TVMFFIFunctionCall); TVM_INIT_CONTEXT_FUNC(TVMFFIErrorSetRaisedByCStr); TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 45b1ad2a7b..d924be327d 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -83,7 +83,7 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu packed_args[1] = opaque_handle; if (ret_any.type_index == ffi::TypeIndex::kTVMFFIModule) { packed_args[0] = static_cast<int32_t>(kTVMModuleHandle); - } else if (ret_any.type_index == ffi::TypeIndex::kTVMFFIFunc) { + } else if (ret_any.type_index == 
ffi::TypeIndex::kTVMFFIFunction) { packed_args[0] = static_cast<int32_t>(kTVMPackedFuncHandle); } else { packed_args[0] = static_cast<int32_t>(kTVMObjectHandle); diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 23b982127c..8f96061004 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -117,7 +117,7 @@ class RPCWrappedFunc : public Object { packed_args[i] = RemoveSessMask(args[i].cast<DLDevice>()); break; } - case ffi::TypeIndex::kTVMFFIFunc: + case ffi::TypeIndex::kTVMFFIFunction: case ffi::TypeIndex::kTVMFFIModule: { packed_args[i] = UnwrapRemoteValueToHandle(args[i]); // hack, need to force set the type index to the correct one diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index d16d139f7e..332650b7fe 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -117,7 +117,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, // Runtime functions. 
// Defined in include/tvm/ffi/c_api.h: - // int TVMFFIFuncCall(TVMFunctionHandle func, TVMFFIAny* args, int32_t num_args, + // int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, // TVMFFIAny* result); ftype_tvm_ffi_func_call_ = ftype_tvm_ffi_c_func_; // Defined in include/tvm/ffi/c_api.h: @@ -155,8 +155,9 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, f_tvm_register_system_symbol_ = nullptr; } if (dynamic_lookup || system_lib_prefix_.defined()) { - f_tvm_ffi_func_call_ = llvm::Function::Create( - ftype_tvm_ffi_func_call_, llvm::Function::ExternalLinkage, "TVMFFIFuncCall", module_.get()); + f_tvm_ffi_func_call_ = + llvm::Function::Create(ftype_tvm_ffi_func_call_, llvm::Function::ExternalLinkage, + "TVMFFIFunctionCall", module_.get()); f_tvm_ffi_set_raised_by_c_str_ = llvm::Function::Create( ftype_tvm_ffi_error_set_raised_by_c_str_, llvm::Function::ExternalLinkage, "TVMFFIErrorSetRaisedByCStr", module_.get()); @@ -444,7 +445,7 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { } else { if (!dynamic_lookup) { gv_tvm_ffi_func_call_ = - InitContextPtr(llvmGetPointerTo(ftype_tvm_ffi_func_call_, 0), "__TVMFFIFuncCall"); + InitContextPtr(llvmGetPointerTo(ftype_tvm_ffi_func_call_, 0), "__TVMFFIFunctionCall"); gv_tvm_get_func_from_env_ = InitContextPtr(llvmGetPointerTo(ftype_tvm_get_func_from_env_, 0), "__TVMBackendGetFuncFromEnv"); gv_tvm_ffi_set_last_error_c_str_ = @@ -834,7 +835,7 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& if (use_env_lookup) { callee_ftype = ftype_tvm_ffi_func_call_; - callee_value = RuntimeTVMFFIFuncCall(); + callee_value = RuntimeTVMFFIFunctionCall(); call_args.push_back(GetPackedFuncHandle(func_name)); call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result}); } else { @@ -928,7 +929,7 @@ llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { return phi_rvalue; } -llvm::Value*
CodeGenCPU::RuntimeTVMFFIFuncCall() { +llvm::Value* CodeGenCPU::RuntimeTVMFFIFunctionCall() { if (f_tvm_ffi_func_call_ != nullptr) return f_tvm_ffi_func_call_; return GetContextPtr(gv_tvm_ffi_func_call_); } diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index e7a87599c0..2bc4bb320a 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -115,7 +115,7 @@ class CodeGenCPU : public CodeGenLLVM { void InitGlobalContext(bool dynamic_lookup); llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name); llvm::Value* GetContextPtr(llvm::GlobalVariable* gv); - llvm::Value* RuntimeTVMFFIFuncCall(); + llvm::Value* RuntimeTVMFFIFunctionCall(); llvm::Value* RuntimeTVMGetFuncFromEnv(); llvm::Value* RuntimeTVMFFIErrorSetRaisedByCStr(); llvm::Value* RuntimeTVMParallelLaunch(); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index c34d1206ab..815904c946 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -244,7 +244,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { this->PrintIndent(); if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - this->stream << "if (TVMFFIFuncCall(" << packed_func_name << ", "; + this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", "; } else { this->stream << "if (" << packed_func_name << "(NULL, "; }
