This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed4c92c5ce [FFI] Introduce GlobalDef for function registration (#18111)
ed4c92c5ce is described below

commit ed4c92c5ce51b681cb52f6d88f34ee3f39f94b8e
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jul 3 21:56:43 2025 -0400

    [FFI] Introduce GlobalDef for function registration (#18111)
    
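    This PR introduces reflection::GlobalDef for global function
    registration, bringing the global function registration API more
    closely in line with the new reflection style used by
    reflection::ObjectDef. GlobalDef provides def, def_packed and
    def_method for registering typed functions, packed functions and
    class methods, respectively.
    
    We will send follow-up PRs to transition more of the existing
    registration mechanisms (such as TVM_FFI_REGISTER_GLOBAL) to the new API.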
---
 ffi/include/tvm/ffi/reflection/reflection.h | 133 ++++++++++++++++++++++++++--
 ffi/src/ffi/container.cc                    |  64 ++++++-------
 ffi/src/ffi/function.cc                     |  54 +++++------
 ffi/src/ffi/ndarray.cc                      |  26 +++---
 ffi/src/ffi/object.cc                       |   6 +-
 ffi/src/ffi/testing.cc                      |  60 ++++++-------
 ffi/tests/cpp/test_function.cc              |   9 +-
 ffi/tests/cpp/test_reflection.cc            |  11 +++
 8 files changed, 239 insertions(+), 124 deletions(-)

diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h
index 0a5e836e1a..ea079183b8 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -118,20 +118,47 @@ class ReflectionDefBase {
       info->doc = TVMFFIByteArray{value, 
std::char_traits<char>::length(value)};
     }
   }
+
   template <typename Class, typename R, typename... Args>
   static TVM_FFI_INLINE Function GetMethod(std::string name, R 
(Class::*func)(Args...)) {
-    auto fwrap = [func](const Class* target, Args... params) -> R {
-      return 
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
-    };
-    return ffi::Function::FromTyped(fwrap, name);
+    static_assert(std::is_base_of_v<ObjectRef, Class> || 
std::is_base_of_v<Object, Class>,
+                  "Class must be derived from ObjectRef or Object");
+    if constexpr (std::is_base_of_v<ObjectRef, Class>) {
+      auto fwrap = [func](Class target, Args... params) -> R {
+        // call method pointer
+        return (target.*func)(std::forward<Args>(params)...);
+      };
+      return ffi::Function::FromTyped(fwrap, name);
+    }
+
+    if constexpr (std::is_base_of_v<Object, Class>) {
+      auto fwrap = [func](const Class* target, Args... params) -> R {
+        // call method pointer
+        return 
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
+      };
+      return ffi::Function::FromTyped(fwrap, name);
+    }
   }
 
   template <typename Class, typename R, typename... Args>
   static TVM_FFI_INLINE Function GetMethod(std::string name, R 
(Class::*func)(Args...) const) {
-    auto fwrap = [func](const Class* target, Args... params) -> R {
-      return (target->*func)(std::forward<Args>(params)...);
-    };
-    return ffi::Function::FromTyped(fwrap, name);
+    static_assert(std::is_base_of_v<ObjectRef, Class> || 
std::is_base_of_v<Object, Class>,
+                  "Class must be derived from ObjectRef or Object");
+    if constexpr (std::is_base_of_v<ObjectRef, Class>) {
+      auto fwrap = [func](const Class target, Args... params) -> R {
+        // call method pointer
+        return (target.*func)(std::forward<Args>(params)...);
+      };
+      return ffi::Function::FromTyped(fwrap, name);
+    }
+
+    if constexpr (std::is_base_of_v<Object, Class>) {
+      auto fwrap = [func](const Class* target, Args... params) -> R {
+        // call method pointer
+        return (target->*func)(std::forward<Args>(params)...);
+      };
+      return ffi::Function::FromTyped(fwrap, name);
+    }
   }
 
   template <typename Class, typename Func>
@@ -140,6 +167,96 @@ class ReflectionDefBase {
   }
 };
 
+class GlobalDef : public ReflectionDefBase {
+ public:
+  /*
+   * \brief Define a global function.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the function.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
+    RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), 
std::string(name)),
+                 std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*
+   * \brief Define a global function in ffi::PackedArgs format.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the function.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
+    RegisterFunc(name, ffi::Function::FromPacked(func), 
std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*
+   * \brief Expose a class method as a global function.
+   *
+   * An argument will be added to the first position if the function is not 
static.
+   *
+   * \tparam Class The class type.
+   * \tparam Func The function type.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
+    RegisterFunc(name, GetMethod_(std::string(name), std::forward<Func>(func)),
+                 std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+ private:
+  template <typename Func>
+  static TVM_FFI_INLINE Function GetMethod_(std::string name, Func&& func) {
+    return ffi::Function::FromTyped(std::forward<Func>(func), name);
+  }
+
+  template <typename Class, typename R, typename... Args>
+  static TVM_FFI_INLINE Function GetMethod_(std::string name, R 
(Class::*func)(Args...) const) {
+    return GetMethod<Class>(std::string(name), func);
+  }
+
+  template <typename Class, typename R, typename... Args>
+  static TVM_FFI_INLINE Function GetMethod_(std::string name, R 
(Class::*func)(Args...)) {
+    return GetMethod<Class>(std::string(name), func);
+  }
+
+  template <typename... Extra>
+  void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) {
+    TVMFFIMethodInfo info;
+    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    info.type_schema = TVMFFIByteArray{nullptr, 0};
+    info.flags = 0;
+    // obtain the method function
+    info.method = AnyView(func).CopyToTVMFFIAny();
+    // apply method info traits
+    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0));
+  }
+};
+
 template <typename Class>
 class ObjectDef : public ReflectionDefBase {
  public:
diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc
index 0ca2034aa2..a0dc660b42 100644
--- a/ffi/src/ffi/container.cc
+++ b/ffi/src/ffi/container.cc
@@ -25,40 +25,11 @@
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/container/shape.h>
 #include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
 
 namespace tvm {
 namespace ffi {
 
-TVM_FFI_REGISTER_GLOBAL("ffi.Array").set_body_packed([](ffi::PackedArgs args, 
Any* ret) {
-  *ret = Array<Any>(args.data(), args.data() + args.size());
-});
-
-TVM_FFI_REGISTER_GLOBAL("ffi.ArrayGetItem")
-    .set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return 
n->at(i); });
-
-TVM_FFI_REGISTER_GLOBAL("ffi.ArraySize").set_body_typed([](const 
ffi::ArrayObj* n) -> int64_t {
-  return static_cast<int64_t>(n->size());
-});
-// Map
-TVM_FFI_REGISTER_GLOBAL("ffi.Map").set_body_packed([](ffi::PackedArgs args, 
Any* ret) {
-  TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
-  Map<Any, Any> data;
-  for (int i = 0; i < args.size(); i += 2) {
-    data.Set(args[i], args[i + 1]);
-  }
-  *ret = data;
-});
-
-TVM_FFI_REGISTER_GLOBAL("ffi.MapSize").set_body_typed([](const ffi::MapObj* n) 
-> int64_t {
-  return static_cast<int64_t>(n->size());
-});
-
-TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem")
-    .set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return 
n->at(k); });
-
-TVM_FFI_REGISTER_GLOBAL("ffi.MapCount")
-    .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return 
n->count(k); });
-
 // Favor struct outside function scope as MSVC may have bug for in fn scope 
struct.
 class MapForwardIterFunctor {
  public:
@@ -86,10 +57,33 @@ class MapForwardIterFunctor {
   ffi::MapObj::iterator end_;
 };
 
-TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor")
-    .set_body_typed([](const ffi::MapObj* n) -> ffi::Function {
-      return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), 
n->end()));
-    });
-
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef()
+      .def_packed("ffi.Array",
+                  [](ffi::PackedArgs args, Any* ret) {
+                    *ret = Array<Any>(args.data(), args.data() + args.size());
+                  })
+      .def("ffi.ArrayGetItem", [](const ffi::ArrayObj* n, int64_t i) -> Any { 
return n->at(i); })
+      .def("ffi.ArraySize",
+           [](const ffi::ArrayObj* n) -> int64_t { return 
static_cast<int64_t>(n->size()); })
+      .def_packed("ffi.Map",
+                  [](ffi::PackedArgs args, Any* ret) {
+                    TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
+                    Map<Any, Any> data;
+                    for (int i = 0; i < args.size(); i += 2) {
+                      data.Set(args[i], args[i + 1]);
+                    }
+                    *ret = data;
+                  })
+      .def("ffi.MapSize",
+           [](const ffi::MapObj* n) -> int64_t { return 
static_cast<int64_t>(n->size()); })
+      .def("ffi.MapGetItem", [](const ffi::MapObj* n, const Any& k) -> Any { 
return n->at(k); })
+      .def("ffi.MapCount",
+           [](const ffi::MapObj* n, const Any& k) -> int64_t { return 
n->count(k); })
+      .def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> 
ffi::Function {
+        return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), 
n->end()));
+      });
+});
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc
index b6bfe73af3..4df86a213a 100644
--- a/ffi/src/ffi/function.cc
+++ b/ffi/src/ffi/function.cc
@@ -28,6 +28,7 @@
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/memory.h>
+#include <tvm/ffi/reflection/reflection.h>
 #include <tvm/ffi/string.h>
 
 namespace tvm {
@@ -307,31 +308,30 @@ int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) {
   TVM_FFI_SAFE_CALL_END();
 }
 
-TVM_FFI_REGISTER_GLOBAL("ffi.FunctionRemoveGlobal")
-    .set_body_typed([](const tvm::ffi::String& name) -> bool {
-      return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
-    });
-
-TVM_FFI_REGISTER_GLOBAL("ffi.FunctionListGlobalNamesFunctor").set_body_typed([]()
 {
-  // NOTE: we return functor instead of array
-  // so list global function names do not need to depend on array
-  // this is because list global function names usually is a core api that 
happens
-  // before array ffi functions are available.
-  tvm::ffi::Array<tvm::ffi::String> names = 
tvm::ffi::GlobalFunctionTable::Global()->ListNames();
-  auto return_functor = [names](int64_t i) -> tvm::ffi::Any {
-    if (i < 0) {
-      return names.size();
-    } else {
-      return names[i];
-    }
-  };
-  return tvm::ffi::Function::FromTyped(return_functor);
-});
-
-TVM_FFI_REGISTER_GLOBAL("ffi.String").set_body_typed([](tvm::ffi::String val) 
-> tvm::ffi::String {
-  return val;
-});
-
-TVM_FFI_REGISTER_GLOBAL("ffi.Bytes").set_body_typed([](tvm::ffi::Bytes val) -> 
tvm::ffi::Bytes {
-  return val;
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef()
+      .def("ffi.FunctionRemoveGlobal",
+           [](const tvm::ffi::String& name) -> bool {
+             return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
+           })
+      .def("ffi.FunctionListGlobalNamesFunctor",
+           []() {
+             // NOTE: we return functor instead of array
+             // so list global function names do not need to depend on array
+             // this is because list global function names usually is a core 
api that happens
+             // before array ffi functions are available.
+             tvm::ffi::Array<tvm::ffi::String> names =
+                 tvm::ffi::GlobalFunctionTable::Global()->ListNames();
+             auto return_functor = [names](int64_t i) -> tvm::ffi::Any {
+               if (i < 0) {
+                 return names.size();
+               } else {
+                 return names[i];
+               }
+             };
+             return tvm::ffi::Function::FromTyped(return_functor);
+           })
+      .def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return 
val; })
+      .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return 
val; });
 });
diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc
index f3c48c8ad5..dc3a5fb1ec 100644
--- a/ffi/src/ffi/ndarray.cc
+++ b/ffi/src/ffi/ndarray.cc
@@ -23,23 +23,27 @@
 #include <tvm/ffi/c_api.h>
 #include <tvm/ffi/container/ndarray.h>
 #include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
 
 namespace tvm {
 namespace ffi {
 
-// Shape
-TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, 
Any* ret) {
-  int64_t* mutable_data;
-  ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(), 
&mutable_data);
-  for (int i = 0; i < args.size(); ++i) {
-    if (auto opt_int = args[i].try_cast<int64_t>()) {
-      mutable_data[i] = *opt_int;
-    } else {
-      TVM_FFI_THROW(ValueError) << "Expect shape to take list of int 
arguments";
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def_packed("ffi.Shape", [](ffi::PackedArgs args, Any* ret) 
{
+    int64_t* mutable_data;
+    ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(), 
&mutable_data);
+    for (int i = 0; i < args.size(); ++i) {
+      if (auto opt_int = args[i].try_cast<int64_t>()) {
+        mutable_data[i] = *opt_int;
+      } else {
+        TVM_FFI_THROW(ValueError) << "Expect shape to take list of int 
arguments";
+      }
     }
-  }
-  *ret = Shape(shape);
+    *ret = Shape(shape);
+  });
 });
+
 }  // namespace ffi
 }  // namespace tvm
 
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 9b193b757f..84a83e4b73 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -24,6 +24,7 @@
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
 #include <tvm/ffi/string.h>
 
 #include <memory>
@@ -404,7 +405,10 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
   *ret = ObjectRef(ptr);
 }
 
-TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs);
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", 
MakeObjectFromPackedArgs);
+});
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/testing.cc
index 6bc7968eab..a1747a279d 100644
--- a/ffi/src/ffi/testing.cc
+++ b/ffi/src/ffi/testing.cc
@@ -54,6 +54,12 @@ class TestObjectDerived : public TestObjectBase {
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjectDerived, TestObjectBase);
 };
 
+void TestRaiseError(String kind, String msg) {
+  throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE);
+}
+
+void TestApply(Function f, PackedArgs args, Any* ret) { f.CallPacked(args, 
ret); }
+
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
 
@@ -66,41 +72,27 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::ObjectDef<TestObjectDerived>()
       .def_ro("v_map", &TestObjectDerived::v_map)
       .def_ro("v_array", &TestObjectDerived::v_array);
-});
-
-void TestRaiseError(String kind, String msg) {
-  throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE);
-}
-
-TVM_FFI_REGISTER_GLOBAL("testing.test_raise_error").set_body_typed(TestRaiseError);
-
-TVM_FFI_REGISTER_GLOBAL("testing.nop").set_body_packed([](PackedArgs args, 
Any* ret) {
-  *ret = args[0];
-});
-
-TVM_FFI_REGISTER_GLOBAL("testing.echo").set_body_packed([](PackedArgs args, 
Any* ret) {
-  *ret = args[0];
-});
-
-void TestApply(Function f, PackedArgs args, Any* ret) { f.CallPacked(args, 
ret); }
-
-TVM_FFI_REGISTER_GLOBAL("testing.apply").set_body_packed([](PackedArgs args, 
Any* ret) {
-  auto f = args[0].cast<Function>();
-  TestApply(f, args.Slice(1), ret);
-});
-
-TVM_FFI_REGISTER_GLOBAL("testing.run_check_signal").set_body_typed([](int 
nsec) {
-  for (int i = 0; i < nsec; ++i) {
-    if (TVMFFIEnvCheckSignals() != 0) {
-      throw ffi::EnvErrorAlreadySet();
-    }
-    std::this_thread::sleep_for(std::chrono::seconds(1));
-  }
-  std::cout << "Function finished without catching signal" << std::endl;
-});
 
-TVM_FFI_REGISTER_GLOBAL("testing.object_use_count").set_body_typed([](const 
Object* obj) {
-  return obj->use_count();
+  refl::GlobalDef()
+      .def("testing.test_raise_error", TestRaiseError)
+      .def_packed("testing.nop", [](PackedArgs args, Any* ret) { *ret = 
args[0]; })
+      .def_packed("testing.echo", [](PackedArgs args, Any* ret) { *ret = 
args[0]; })
+      .def_packed("testing.apply",
+                  [](PackedArgs args, Any* ret) {
+                    auto f = args[0].cast<Function>();
+                    TestApply(f, args.Slice(1), ret);
+                  })
+      .def("testing.run_check_signal",
+           [](int nsec) {
+             for (int i = 0; i < nsec; ++i) {
+               if (TVMFFIEnvCheckSignals() != 0) {
+                 throw ffi::EnvErrorAlreadySet();
+               }
+               std::this_thread::sleep_for(std::chrono::seconds(1));
+             }
+             std::cout << "Function finished without catching signal" << 
std::endl;
+           })
+      .def("testing.object_use_count", [](const Object* obj) { return 
obj->use_count(); });
 });
 
 }  // namespace ffi
diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc
index 526e1ad03e..c3c484f333 100644
--- a/ffi/tests/cpp/test_function.cc
+++ b/ffi/tests/cpp/test_function.cc
@@ -131,7 +131,7 @@ TEST(Func, FromTyped) {
           EXPECT_EQ(error.kind(), "TypeError");
           EXPECT_EQ(error.message(),
                     "Mismatched number of arguments when calling: "
-                    "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> 
object.Function`. "
+                    "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> 
ffi.Function`. "
                     "Expected 3 but got 0 arguments");
           throw;
         }
@@ -236,11 +236,4 @@ TEST(Func, ObjectRefWithFallbackTraits) {
       ::tvm::ffi::Error);
 }
 
-TVM_FFI_REGISTER_GLOBAL("testing.Int_GetValue").set_body_method(&TIntObj::GetValue);
-
-TEST(Func, Register) {
-  Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue");
-  TInt a(12);
-  EXPECT_EQ(fget_value(a).cast<int>(), 12);
-}
 }  // namespace
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index 17494744ef..ce15fc14c4 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -153,4 +153,15 @@ TEST(Reflection, ForEachFieldInfo) {
   EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject));
 }
 
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def_method("testing.Int_GetValue", &TIntObj::GetValue);
+});
+
+TEST(Reflection, FuncRegister) {
+  Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue");
+  TInt a(12);
+  EXPECT_EQ(fget_value(a).cast<int>(), 12);
+}
+
 }  // namespace

Reply via email to