This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s3 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 0ddf789068eae441a2cdee9abec35be641d17d76 Author: tqchen <[email protected]> AuthorDate: Sun May 4 20:49:32 2025 -0400 [FFI] Simplify unpack traits --- ffi/include/tvm/ffi/container/variant.h | 4 +- ffi/include/tvm/ffi/function.h | 27 +++++---- ffi/include/tvm/ffi/function_details.h | 93 ++++------------------------- ffi/include/tvm/ffi/reflection/reflection.h | 9 +-- include/tvm/ir/env_func.h | 11 +++- include/tvm/runtime/packed_func.h | 19 +++--- src/runtime/metal/metal_module.mm | 2 +- src/tir/ir/expr.cc | 55 ++++++++--------- src/tir/schedule/instruction_traits.h | 9 +-- 9 files changed, 90 insertions(+), 139 deletions(-) diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index 1455a5b34a..c8b58ba49e 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -94,12 +94,12 @@ class Variant { template <typename T, typename = enable_if_variant_contains_t<T>> TVM_FFI_INLINE T get() const& { - return data_.cast<T>(); + return data_.template cast<T>(); } template <typename T, typename = enable_if_variant_contains_t<T>> TVM_FFI_INLINE T get() && { - return std::move(data_).cast<T>(); + return std::move(data_).template cast<T>(); } TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); } diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 913237ec04..e4b67e4a76 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -453,8 +453,8 @@ class Function : public ObjectRef { static Function FromUnpacked(TCallable callable) { using FuncInfo = details::FunctionInfo<TCallable>; auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { - details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>(nullptr, callable, args, - num_args, rv); + details::unpack_call<typename FuncInfo::RetType>( + std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable, args, num_args, rv); }; return FromPackedInternal(call_packed); } @@ -469,8 +469,8 @@ class Function : public ObjectRef { using FuncInfo = details::FunctionInfo<TCallable>; auto call_packed = [callable, name](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { - details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>(&name, callable, args, - num_args, rv); + details::unpack_call<typename FuncInfo::RetType>( + std::make_index_sequence<FuncInfo::num_args>{}, &name, callable, args, num_args, rv); }; return FromPackedInternal(call_packed); } @@ -674,7 +674,16 @@ class TypedFunction<R(Args...)> { * \returns The return value. */ TVM_FFI_INLINE R operator()(Args... args) const { - return details::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...); + if constexpr (std::is_same_v<R, void>) { + packed_(std::forward<Args>(args)...); + } else { + Any res = packed_(std::forward<Args>(args)...); + if constexpr (std::is_same_v<R, Any>) { + return res; + } else { + return std::move(res).cast<R>(); + } + } } /*! * \brief convert to PackedFunc @@ -850,8 +859,6 @@ class Function::Registry { return *this; } - operator details::EmptyStruct() const { return details::EmptyStruct(); } - protected: /*! * \brief set the body of the function to be f @@ -875,8 +882,8 @@ inline int32_t TypeKeyToIndex(const char* type_key) { return type_index; } -#define TVM_FFI_REG_VAR_DEF \ - static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::EmptyStruct __mk_##TVMFFI +#define TVM_FFI_FUNC_REG_VAR_DEF \ + static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::Function::Registry& __##TVMFFIFuncReg /*! * \brief Register a function globally. @@ -888,7 +895,7 @@ inline int32_t TypeKeyToIndex(const char* type_key) { * \endcode */ #define TVM_FFI_REGISTER_GLOBAL(OpName) \ - TVM_FFI_STR_CONCAT(TVM_FFI_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName) + TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName) } // namespace ffi } // namespace tvm #endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index 34e166428e..f47a253a58 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -36,14 +36,6 @@ namespace tvm { namespace ffi { namespace details { -/*! - * \brief Empty struct, used to reduce unused variable in global static initialization. - */ -struct EmptyStruct { - /*! \brief at least one byte to ensure address */ - uint8_t pad; -}; - template <typename ArgType> struct Arg2Str { template <size_t i> @@ -167,46 +159,10 @@ class ArgValueWithContext { FGetFuncSignature f_sig_; }; -template <typename R, int nleft, int index, typename F> -struct unpack_call_dispatcher { - template <typename... Args> - TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, - const F& f, const AnyView* args, int32_t num_args, Any* rv, - Args&&... unpacked_args) { - // construct a movable argument value - // which allows potential move of argument to the input of F. - unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run( - optional_name, f_sig, f, args, num_args, rv, std::forward<Args>(unpacked_args)..., - ArgValueWithContext(args, index, optional_name, f_sig)); - } -}; - -template <typename R, int index, typename F> -struct unpack_call_dispatcher<R, 0, index, F> { - template <typename... Args> - TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, const AnyView*, - int32_t, Any* rv, Args&&... unpacked_args) { - using RetType = decltype(f(std::forward<Args>(unpacked_args)...)); - if constexpr (std::is_same_v<RetType, R>) { - *rv = f(std::forward<Args>(unpacked_args)...); - } else { - *rv = R(f(std::forward<Args>(unpacked_args)...)); - } - } -}; - -template <int index, typename F> -struct unpack_call_dispatcher<void, 0, index, F> { - template <typename... Args> - TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, const AnyView*, - int32_t, Any*, Args&&... unpacked_args) { - f(std::forward<Args>(unpacked_args)...); - } -}; - -template <typename R, int nargs, typename F> -TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, const AnyView* args, - int32_t num_args, Any* rv) { +template <typename R, std::size_t... Is, typename F> +TVM_FFI_INLINE void unpack_call(std::index_sequence<Is...>, const std::string* optional_name, + const F& f, [[maybe_unused]] const AnyView* args, + [[maybe_unused]] int32_t num_args, [[maybe_unused]] Any* rv) { using FuncInfo = FunctionInfo<F>; FGetFuncSignature f_sig = FuncInfo::Sig; @@ -214,47 +170,20 @@ TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, co #ifndef _MSC_VER static_assert(FuncInfo::unpacked_supported, "The function signature do not support unpacked"); #endif - + constexpr size_t nargs = sizeof...(Is); if (nargs != num_args) { TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" << (optional_name == nullptr ? "" : *optional_name) << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << nargs << " but got " << num_args << " arguments"; } - unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, num_args, rv); -} - -template <typename FType> -struct unpack_call_by_signature {}; - -template <typename R, typename... Args> -struct unpack_call_by_signature<R(Args...)> { - template <typename F> - TVM_FFI_INLINE static void run(const F& f, const AnyView* args, int32_t num_args, Any* rv) { - unpack_call<R, sizeof...(Args)>(nullptr, f, args, num_args, rv); + // use index sequence to do recursive-less unpacking + if constexpr (std::is_same_v<R, void>) { + f(ArgValueWithContext(args, Is, optional_name, f_sig)...); + } else { + *rv = R(f(ArgValueWithContext(args, Is, optional_name, f_sig)...)); } -}; - -template <typename R> -struct typed_packed_call_dispatcher { - template <typename F, typename... Args> - TVM_FFI_INLINE static R run(const F& f, Args&&... args) { - Any res = f(std::forward<Args>(args)...); - if constexpr (std::is_same_v<R, Any>) { - return res; - } else { - return std::move(res).cast<R>(); - } - } -}; - -template <> -struct typed_packed_call_dispatcher<void> { - template <typename F, typename... Args> - TVM_FFI_INLINE static void run(const F& f, Args&&... args) { - f(std::forward<Args>(args)...); - } -}; +} /*! * \brief Move the safe call raised error to the caller diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index b1088c1eda..8ce2d22ddd 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -66,8 +66,6 @@ class ReflectionDef { return *this; } - operator details::EmptyStruct() const { return details::EmptyStruct(); } - private: template <typename Class, typename T> void RegisterField(const char* name, T Class::*field_ptr, bool readonly) { @@ -138,11 +136,14 @@ class ReflectionFieldGetter { const TVMFFIFieldInfo* field_info_; }; +#define TVM_FFI_REFLECTION_REG_VAR_DEF \ + static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::ReflectionDef& __TVMFFIReflectionReg + /*! * helper macro to define a reflection definition for an object */ -#define TVM_FFI_REFLECTION_DEF(TypeName) \ - TVM_FFI_STR_CONCAT(TVM_FFI_REG_VAR_DEF, __COUNTER__) = \ +#define TVM_FFI_REFLECTION_DEF(TypeName) \ + TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ ::tvm::ffi::details::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex()) } // namespace details } // namespace ffi diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index c83711756b..c44e102cca 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -139,7 +139,16 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef { R operator()(Args... args) const { const EnvFuncNode* n = operator->(); ICHECK(n != nullptr); - return ffi::details::typed_packed_call_dispatcher<R>::run(n->func, std::forward<Args>(args)...); + if constexpr (std::is_same_v<R, void>) { + n->func(std::forward<Args>(args)...); + } else { + Any res = n->func(std::forward<Args>(args)...); + if constexpr (std::is_same_v<R, Any>) { + return res; + } else { + return std::move(res).cast<R>(); + } + } } /*! \brief specify container node */ using ContainerType = EnvFuncNode; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 228ce731e0..3609987d55 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -343,7 +343,8 @@ struct ModuleVTableEntryHelper<R (T::*)(Args...) const> { using MemFnType = R (T::*)(Args...) const; static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward<Args>(args)...); }; - ffi::details::unpack_call<R, sizeof...(Args)>(nullptr, wrapped, args.data(), args.size(), rv); + ffi::details::unpack_call<R>(std::make_index_sequence<sizeof...(Args)>{}, nullptr, wrapped, + args.data(), args.size(), rv); } }; @@ -352,7 +353,8 @@ struct ModuleVTableEntryHelper<R (T::*)(Args...)> { using MemFnType = R (T::*)(Args...); static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward<Args>(args)...); }; - ffi::details::unpack_call<R, sizeof...(Args)>(nullptr, wrapped, args.data(), args.size(), rv); + ffi::details::unpack_call<R>(std::make_index_sequence<sizeof...(Args)>{}, nullptr, wrapped, + args.data(), args.size(), rv); } }; @@ -361,8 +363,8 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...) const> { using MemFnType = void (T::*)(Args...) const; static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward<Args>(args)...); }; - ffi::details::unpack_call<void, sizeof...(Args)>(nullptr, wrapped, args.data(), args.size(), - rv); + ffi::details::unpack_call<void>(std::make_index_sequence<sizeof...(Args)>{}, nullptr, wrapped, + args.data(), args.size(), rv); } }; @@ -371,8 +373,8 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...)> { using MemFnType = void (T::*)(Args...); static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward<Args>(args)...); }; - ffi::details::unpack_call<void, sizeof...(Args)>(nullptr, wrapped, args.data(), args.size(), - rv); + ffi::details::unpack_call<void>(std::make_index_sequence<sizeof...(Args)>{}, nullptr, wrapped, + args.data(), args.size(), rv); } }; } // namespace details @@ -446,8 +448,9 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...)> { TVM_FFI_SAFE_CALL_BEGIN(); \ using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \ static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>( \ - &name, Function, reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \ + ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \ + std::make_index_sequence<FuncInfo::num_args>{}, &name, Function, \ + reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \ reinterpret_cast<::tvm::ffi::Any*>(result)); \ TVM_FFI_SAFE_CALL_END(); \ } \ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index b338274231..a9a7faaefa 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -211,7 +211,7 @@ class MetalWrappedFunc { id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder]; [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) { - void* buf = args[static_cast<int>(i)]; + void* buf = args[static_cast<int>(i)].cast<void*>(); [encoder setBuffer:(id<MTLBuffer>)(buf) offset:0 atIndex:i]; } if (num_pack_args_ != 0) { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 491488ba30..bec6c04085 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -558,36 +558,37 @@ Call::Call(DataType dtype, RelaxExpr op, Array<PrimExpr> args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](Optional<DataType> dtype, RelaxExpr op, - Array<Variant<runtime::String, DLDataType, IterVar, BufferRegion, PrimExpr>> args, - Span span) { - Array<PrimExpr> prim_expr_args; - for (const auto& it : args) { - if (auto opt_str = it.as<String>()) { - prim_expr_args.push_back(StringImm(opt_str.value())); - } else if (auto opt_dtype = it.as<DLDataType>()) { - prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); - } else if (const auto* iter_var = it.as<IterVarNode>()) { - prim_expr_args.push_back(iter_var->var); - } else if (const auto* br = it.as<BufferRegionNode>()) { - Array<PrimExpr> indices; - for (Range r : br->region) { - if (is_one(r->extent)) { - indices.push_back(r->min); - } else if (r->extent.as<IntImmNode>()) { - indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); + .set_body_typed( + [](Optional<DataType> dtype, RelaxExpr op, + Array<Variant<runtime::String, DLDataType, IterVar, BufferRegion, PrimExpr>> args, + Span span) { + Array<PrimExpr> prim_expr_args; + for (const auto& it : args) { + if (auto opt_str = it.as<String>()) { + prim_expr_args.push_back(StringImm(opt_str.value())); + } else if (auto opt_dtype = it.as<DLDataType>()) { + prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); + } else if (const auto* iter_var = it.as<IterVarNode>()) { + prim_expr_args.push_back(iter_var->var); + } else if (const auto* br = it.as<BufferRegionNode>()) { + Array<PrimExpr> indices; + for (Range r : br->region) { + if (is_one(r->extent)) { + indices.push_back(r->min); + } else if (r->extent.as<IntImmNode>()) { + indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); + } else { + LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " + << GetRef<BufferRegion>(br); + } + } + prim_expr_args.push_back(BufferLoad(br->buffer, indices)); } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " - << GetRef<BufferRegion>(br); + prim_expr_args.push_back(Downcast<PrimExpr>(it)); } } - prim_expr_args.push_back(BufferLoad(br->buffer, indices)); - } else { - prim_expr_args.push_back(Downcast<PrimExpr>(it)); - } - } - return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); - }); + return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); + }); TVM_REGISTER_NODE_TYPE(CallNode); diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index d292bd3dc7..29165750e8 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -319,8 +319,9 @@ Array<Any> UnpackedInstTraits<TTraits>::ApplyToSchedule(const Schedule& sch, PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { constexpr size_t kNumArgs = details::NumArgs<method_type>; ICHECK_EQ(args.size(), kNumArgs); - ffi::details::unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedApplyToSchedule, - args.data(), args.size(), rv); + ffi::details::unpack_call<return_type>(std::make_index_sequence<kNumArgs>{}, nullptr, + TTraits::UnpackedApplyToSchedule, args.data(), + args.size(), rv); }); ffi::Any rv; pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv); @@ -349,8 +350,8 @@ String UnpackedInstTraits<TTraits>::AsPython(const Array<Any>& inputs, const Arr PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { constexpr size_t kNumArgs = details::NumArgs<method_type>; ICHECK_EQ(args.size(), kNumArgs); - ffi::details::unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedAsPython, - args.data(), args.size(), rv); + ffi::details::unpack_call<return_type>(std::make_index_sequence<kNumArgs>{}, nullptr, + TTraits::UnpackedAsPython, args.data(), args.size(), rv); }); ffi::Any rv; pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv);
