This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s3
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 0ddf789068eae441a2cdee9abec35be641d17d76
Author: tqchen <[email protected]>
AuthorDate: Sun May 4 20:49:32 2025 -0400

    [FFI] Simplify unpack traits
---
 ffi/include/tvm/ffi/container/variant.h     |  4 +-
 ffi/include/tvm/ffi/function.h              | 27 +++++----
 ffi/include/tvm/ffi/function_details.h      | 93 ++++-------------------------
 ffi/include/tvm/ffi/reflection/reflection.h |  9 +--
 include/tvm/ir/env_func.h                   | 11 +++-
 include/tvm/runtime/packed_func.h           | 19 +++---
 src/runtime/metal/metal_module.mm           |  2 +-
 src/tir/ir/expr.cc                          | 55 ++++++++---------
 src/tir/schedule/instruction_traits.h       |  9 +--
 9 files changed, 90 insertions(+), 139 deletions(-)

diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h
index 1455a5b34a..c8b58ba49e 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -94,12 +94,12 @@ class Variant {
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE T get() const& {
-    return data_.cast<T>();
+    return data_.template cast<T>();
   }
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE T get() && {
-    return std::move(data_).cast<T>();
+    return std::move(data_).template cast<T>();
   }
 
   TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); }
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index 913237ec04..e4b67e4a76 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -453,8 +453,8 @@ class Function : public ObjectRef {
   static Function FromUnpacked(TCallable callable) {
     using FuncInfo = details::FunctionInfo<TCallable>;
     auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* 
rv) mutable -> void {
-      details::unpack_call<typename FuncInfo::RetType, 
FuncInfo::num_args>(nullptr, callable, args,
-                                                                           
num_args, rv);
+      details::unpack_call<typename FuncInfo::RetType>(
+          std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable, 
args, num_args, rv);
     };
     return FromPackedInternal(call_packed);
   }
@@ -469,8 +469,8 @@ class Function : public ObjectRef {
     using FuncInfo = details::FunctionInfo<TCallable>;
     auto call_packed = [callable, name](const AnyView* args, int32_t num_args,
                                         Any* rv) mutable -> void {
-      details::unpack_call<typename FuncInfo::RetType, 
FuncInfo::num_args>(&name, callable, args,
-                                                                           
num_args, rv);
+      details::unpack_call<typename FuncInfo::RetType>(
+          std::make_index_sequence<FuncInfo::num_args>{}, &name, callable, 
args, num_args, rv);
     };
     return FromPackedInternal(call_packed);
   }
@@ -674,7 +674,16 @@ class TypedFunction<R(Args...)> {
    * \returns The return value.
    */
   TVM_FFI_INLINE R operator()(Args... args) const {
-    return details::typed_packed_call_dispatcher<R>::run(packed_, 
std::forward<Args>(args)...);
+    if constexpr (std::is_same_v<R, void>) {
+      packed_(std::forward<Args>(args)...);
+    } else {
+      Any res = packed_(std::forward<Args>(args)...);
+      if constexpr (std::is_same_v<R, Any>) {
+        return res;
+      } else {
+        return std::move(res).cast<R>();
+      }
+    }
   }
   /*!
    * \brief convert to PackedFunc
@@ -850,8 +859,6 @@ class Function::Registry {
     return *this;
   }
 
-  operator details::EmptyStruct() const { return details::EmptyStruct(); }
-
  protected:
   /*!
    * \brief set the body of the function to be f
@@ -875,8 +882,8 @@ inline int32_t TypeKeyToIndex(const char* type_key) {
   return type_index;
 }
 
-#define TVM_FFI_REG_VAR_DEF \
-  static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::EmptyStruct 
__mk_##TVMFFI
+#define TVM_FFI_FUNC_REG_VAR_DEF \
+  static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::Function::Registry& 
__##TVMFFIFuncReg
 
 /*!
  * \brief Register a function globally.
@@ -888,7 +895,7 @@ inline int32_t TypeKeyToIndex(const char* type_key) {
  * \endcode
  */
 #define TVM_FFI_REGISTER_GLOBAL(OpName) \
-  TVM_FFI_STR_CONCAT(TVM_FFI_REG_VAR_DEF, __COUNTER__) = 
::tvm::ffi::Function::Registry(OpName)
+  TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = 
::tvm::ffi::Function::Registry(OpName)
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_FUNCTION_H_
diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h
index 34e166428e..f47a253a58 100644
--- a/ffi/include/tvm/ffi/function_details.h
+++ b/ffi/include/tvm/ffi/function_details.h
@@ -36,14 +36,6 @@ namespace tvm {
 namespace ffi {
 namespace details {
 
-/*!
- * \brief Empty struct, used to reduce unused variable in global static 
initialization.
- */
-struct EmptyStruct {
-  /*! \brief at least one byte to ensure address */
-  uint8_t pad;
-};
-
 template <typename ArgType>
 struct Arg2Str {
   template <size_t i>
@@ -167,46 +159,10 @@ class ArgValueWithContext {
   FGetFuncSignature f_sig_;
 };
 
-template <typename R, int nleft, int index, typename F>
-struct unpack_call_dispatcher {
-  template <typename... Args>
-  TVM_FFI_INLINE static void run(const std::string* optional_name, 
FGetFuncSignature f_sig,
-                                 const F& f, const AnyView* args, int32_t 
num_args, Any* rv,
-                                 Args&&... unpacked_args) {
-    // construct a movable argument value
-    // which allows potential move of argument to the input of F.
-    unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
-        optional_name, f_sig, f, args, num_args, rv, 
std::forward<Args>(unpacked_args)...,
-        ArgValueWithContext(args, index, optional_name, f_sig));
-  }
-};
-
-template <typename R, int index, typename F>
-struct unpack_call_dispatcher<R, 0, index, F> {
-  template <typename... Args>
-  TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const 
F& f, const AnyView*,
-                                 int32_t, Any* rv, Args&&... unpacked_args) {
-    using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
-    if constexpr (std::is_same_v<RetType, R>) {
-      *rv = f(std::forward<Args>(unpacked_args)...);
-    } else {
-      *rv = R(f(std::forward<Args>(unpacked_args)...));
-    }
-  }
-};
-
-template <int index, typename F>
-struct unpack_call_dispatcher<void, 0, index, F> {
-  template <typename... Args>
-  TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const 
F& f, const AnyView*,
-                                 int32_t, Any*, Args&&... unpacked_args) {
-    f(std::forward<Args>(unpacked_args)...);
-  }
-};
-
-template <typename R, int nargs, typename F>
-TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, 
const AnyView* args,
-                                int32_t num_args, Any* rv) {
+template <typename R, std::size_t... Is, typename F>
+TVM_FFI_INLINE void unpack_call(std::index_sequence<Is...>, const std::string* 
optional_name,
+                                const F& f, [[maybe_unused]] const AnyView* 
args,
+                                [[maybe_unused]] int32_t num_args, 
[[maybe_unused]] Any* rv) {
   using FuncInfo = FunctionInfo<F>;
   FGetFuncSignature f_sig = FuncInfo::Sig;
 
@@ -214,47 +170,20 @@ TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, co
 #ifndef _MSC_VER
   static_assert(FuncInfo::unpacked_supported, "The function signature do not 
support unpacked");
 #endif
-
+  constexpr size_t nargs = sizeof...(Is);
   if (nargs != num_args) {
     TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: 
`"
                              << (optional_name == nullptr ? "" : 
*optional_name)
                              << (f_sig == nullptr ? "" : (*f_sig)()) << "`. 
Expected " << nargs
                              << " but got " << num_args << " arguments";
   }
-  unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, 
num_args, rv);
-}
-
-template <typename FType>
-struct unpack_call_by_signature {};
-
-template <typename R, typename... Args>
-struct unpack_call_by_signature<R(Args...)> {
-  template <typename F>
-  TVM_FFI_INLINE static void run(const F& f, const AnyView* args, int32_t 
num_args, Any* rv) {
-    unpack_call<R, sizeof...(Args)>(nullptr, f, args, num_args, rv);
+  // use index sequence to do recursive-less unpacking
+  if constexpr (std::is_same_v<R, void>) {
+    f(ArgValueWithContext(args, Is, optional_name, f_sig)...);
+  } else {
+    *rv = R(f(ArgValueWithContext(args, Is, optional_name, f_sig)...));
   }
-};
-
-template <typename R>
-struct typed_packed_call_dispatcher {
-  template <typename F, typename... Args>
-  TVM_FFI_INLINE static R run(const F& f, Args&&... args) {
-    Any res = f(std::forward<Args>(args)...);
-    if constexpr (std::is_same_v<R, Any>) {
-      return res;
-    } else {
-      return std::move(res).cast<R>();
-    }
-  }
-};
-
-template <>
-struct typed_packed_call_dispatcher<void> {
-  template <typename F, typename... Args>
-  TVM_FFI_INLINE static void run(const F& f, Args&&... args) {
-    f(std::forward<Args>(args)...);
-  }
-};
+}
 
 /*!
  * \brief Move the safe call raised error to the caller
diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h
index b1088c1eda..8ce2d22ddd 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -66,8 +66,6 @@ class ReflectionDef {
     return *this;
   }
 
-  operator details::EmptyStruct() const { return details::EmptyStruct(); }
-
  private:
   template <typename Class, typename T>
   void RegisterField(const char* name, T Class::*field_ptr, bool readonly) {
@@ -138,11 +136,14 @@ class ReflectionFieldGetter {
   const TVMFFIFieldInfo* field_info_;
 };
 
+#define TVM_FFI_REFLECTION_REG_VAR_DEF \
+  static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::ReflectionDef& 
__TVMFFIReflectionReg
+
 /*!
  * helper macro to define a reflection definition for an object
  */
-#define TVM_FFI_REFLECTION_DEF(TypeName)                 \
-  TVM_FFI_STR_CONCAT(TVM_FFI_REG_VAR_DEF, __COUNTER__) = \
+#define TVM_FFI_REFLECTION_DEF(TypeName)                            \
+  TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
       
::tvm::ffi::details::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex())
 }  // namespace details
 }  // namespace ffi
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index c83711756b..c44e102cca 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -139,7 +139,16 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
   R operator()(Args... args) const {
     const EnvFuncNode* n = operator->();
     ICHECK(n != nullptr);
-    return ffi::details::typed_packed_call_dispatcher<R>::run(n->func, 
std::forward<Args>(args)...);
+    if constexpr (std::is_same_v<R, void>) {
+      n->func(std::forward<Args>(args)...);
+    } else {
+      Any res = n->func(std::forward<Args>(args)...);
+      if constexpr (std::is_same_v<R, Any>) {
+        return res;
+      } else {
+        return std::move(res).cast<R>();
+      }
+    }
   }
   /*! \brief specify container node */
   using ContainerType = EnvFuncNode;
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 228ce731e0..3609987d55 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -343,7 +343,8 @@ struct ModuleVTableEntryHelper<R (T::*)(Args...) const> {
   using MemFnType = R (T::*)(Args...) const;
   static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, 
TVMArgs args) {
     auto wrapped = [self, f](Args... args) -> R { return 
(self->*f)(std::forward<Args>(args)...); };
-    ffi::details::unpack_call<R, sizeof...(Args)>(nullptr, wrapped, 
args.data(), args.size(), rv);
+    ffi::details::unpack_call<R>(std::make_index_sequence<sizeof...(Args)>{}, 
nullptr, wrapped,
+                                 args.data(), args.size(), rv);
   }
 };
 
@@ -352,7 +353,8 @@ struct ModuleVTableEntryHelper<R (T::*)(Args...)> {
   using MemFnType = R (T::*)(Args...);
   static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, 
TVMArgs args) {
     auto wrapped = [self, f](Args... args) -> R { return 
(self->*f)(std::forward<Args>(args)...); };
-    ffi::details::unpack_call<R, sizeof...(Args)>(nullptr, wrapped, 
args.data(), args.size(), rv);
+    ffi::details::unpack_call<R>(std::make_index_sequence<sizeof...(Args)>{}, 
nullptr, wrapped,
+                                 args.data(), args.size(), rv);
   }
 };
 
@@ -361,8 +363,8 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...) const> {
   using MemFnType = void (T::*)(Args...) const;
   static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, 
TVMArgs args) {
     auto wrapped = [self, f](Args... args) -> void { 
(self->*f)(std::forward<Args>(args)...); };
-    ffi::details::unpack_call<void, sizeof...(Args)>(nullptr, wrapped, 
args.data(), args.size(),
-                                                     rv);
+    
ffi::details::unpack_call<void>(std::make_index_sequence<sizeof...(Args)>{}, 
nullptr, wrapped,
+                                    args.data(), args.size(), rv);
   }
 };
 
@@ -371,8 +373,8 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
   using MemFnType = void (T::*)(Args...);
   static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, 
TVMArgs args) {
     auto wrapped = [self, f](Args... args) -> void { 
(self->*f)(std::forward<Args>(args)...); };
-    ffi::details::unpack_call<void, sizeof...(Args)>(nullptr, wrapped, 
args.data(), args.size(),
-                                                     rv);
+    
ffi::details::unpack_call<void>(std::make_index_sequence<sizeof...(Args)>{}, 
nullptr, wrapped,
+                                    args.data(), args.size(), rv);
   }
 };
 }  // namespace details
@@ -446,8 +448,9 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
     TVM_FFI_SAFE_CALL_BEGIN();                                                 
              \
     using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>;    
              \
     static std::string name = #ExportName;                                     
              \
-    ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType, 
FuncInfo::num_args>(        \
-        &name, Function, reinterpret_cast<const ::tvm::ffi::AnyView*>(args), 
num_args,       \
+    ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>(              
              \
+        std::make_index_sequence<FuncInfo::num_args>{}, &name, Function,       
              \
+        reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args,          
              \
         reinterpret_cast<::tvm::ffi::Any*>(result));                           
              \
     TVM_FFI_SAFE_CALL_END();                                                   
              \
   }                                                                            
              \
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index b338274231..a9a7faaefa 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -211,7 +211,7 @@ class MetalWrappedFunc {
       id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
       [encoder setComputePipelineState:scache_[device_id]];
       for (size_t i = 0; i < num_buffer_args_; ++i) {
-        void* buf = args[static_cast<int>(i)];
+        void* buf = args[static_cast<int>(i)].cast<void*>();
         [encoder setBuffer:(id<MTLBuffer>)(buf) offset:0 atIndex:i];
       }
       if (num_pack_args_ != 0) {
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 491488ba30..bec6c04085 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -558,36 +558,37 @@ Call::Call(DataType dtype, RelaxExpr op, Array<PrimExpr> args, Span span) {
 }
 
 TVM_REGISTER_GLOBAL("tir.Call")
-    .set_body_typed([](Optional<DataType> dtype, RelaxExpr op,
-                       Array<Variant<runtime::String, DLDataType, IterVar, 
BufferRegion, PrimExpr>> args,
-                       Span span) {
-      Array<PrimExpr> prim_expr_args;
-      for (const auto& it : args) {
-        if (auto opt_str = it.as<String>()) {
-          prim_expr_args.push_back(StringImm(opt_str.value()));
-        } else if (auto opt_dtype = it.as<DLDataType>()) {
-          
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value())));
-        } else if (const auto* iter_var = it.as<IterVarNode>()) {
-          prim_expr_args.push_back(iter_var->var);
-        } else if (const auto* br = it.as<BufferRegionNode>()) {
-          Array<PrimExpr> indices;
-          for (Range r : br->region) {
-            if (is_one(r->extent)) {
-              indices.push_back(r->min);
-            } else if (r->extent.as<IntImmNode>()) {
-              indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 
1), r->extent));
+    .set_body_typed(
+        [](Optional<DataType> dtype, RelaxExpr op,
+           Array<Variant<runtime::String, DLDataType, IterVar, BufferRegion, 
PrimExpr>> args,
+           Span span) {
+          Array<PrimExpr> prim_expr_args;
+          for (const auto& it : args) {
+            if (auto opt_str = it.as<String>()) {
+              prim_expr_args.push_back(StringImm(opt_str.value()));
+            } else if (auto opt_dtype = it.as<DLDataType>()) {
+              
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value())));
+            } else if (const auto* iter_var = it.as<IterVarNode>()) {
+              prim_expr_args.push_back(iter_var->var);
+            } else if (const auto* br = it.as<BufferRegionNode>()) {
+              Array<PrimExpr> indices;
+              for (Range r : br->region) {
+                if (is_one(r->extent)) {
+                  indices.push_back(r->min);
+                } else if (r->extent.as<IntImmNode>()) {
+                  indices.push_back(tir::Ramp(r->min, 
make_const(r->min->dtype, 1), r->extent));
+                } else {
+                  LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: "
+                             << GetRef<BufferRegion>(br);
+                }
+              }
+              prim_expr_args.push_back(BufferLoad(br->buffer, indices));
             } else {
-              LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: "
-                         << GetRef<BufferRegion>(br);
+              prim_expr_args.push_back(Downcast<PrimExpr>(it));
             }
           }
-          prim_expr_args.push_back(BufferLoad(br->buffer, indices));
-        } else {
-          prim_expr_args.push_back(Downcast<PrimExpr>(it));
-        }
-      }
-      return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span);
-    });
+          return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, 
span);
+        });
 
 TVM_REGISTER_NODE_TYPE(CallNode);
 
diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h
index d292bd3dc7..29165750e8 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -319,8 +319,9 @@ Array<Any> UnpackedInstTraits<TTraits>::ApplyToSchedule(const Schedule& sch,
   PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void {
     constexpr size_t kNumArgs = details::NumArgs<method_type>;
     ICHECK_EQ(args.size(), kNumArgs);
-    ffi::details::unpack_call<return_type, kNumArgs>(nullptr, 
TTraits::UnpackedApplyToSchedule,
-                                                     args.data(), args.size(), 
rv);
+    
ffi::details::unpack_call<return_type>(std::make_index_sequence<kNumArgs>{}, 
nullptr,
+                                           TTraits::UnpackedApplyToSchedule, 
args.data(),
+                                           args.size(), rv);
   });
   ffi::Any rv;
   pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv);
@@ -349,8 +350,8 @@ String UnpackedInstTraits<TTraits>::AsPython(const Array<Any>& inputs, const Arr
   PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void {
     constexpr size_t kNumArgs = details::NumArgs<method_type>;
     ICHECK_EQ(args.size(), kNumArgs);
-    ffi::details::unpack_call<return_type, kNumArgs>(nullptr, 
TTraits::UnpackedAsPython,
-                                                     args.data(), args.size(), 
rv);
+    
ffi::details::unpack_call<return_type>(std::make_index_sequence<kNumArgs>{}, 
nullptr,
+                                           TTraits::UnpackedAsPython, 
args.data(), args.size(), rv);
   });
   ffi::Any rv;
   pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv);

Reply via email to