This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e268eb1d feat(ffi)!: centralize object creation with CreateEmptyObject/HasCreator (#501)
e268eb1d is described below

commit e268eb1d2d55ae97bdaab7ab52f381f436fd0c82
Author: Junru Shao <[email protected]>
AuthorDate: Tue Mar 10 13:48:04 2026 -0700

    feat(ffi)!: centralize object creation with CreateEmptyObject/HasCreator (#501)
    
    ## Summary
    
    - Add `CreateEmptyObject()` and `HasCreator()` inline helpers in
    `reflection/creator.h` (which now includes `function.h`) that unify the
    two-step "check creator, call creator" pattern. These try the native
    `metadata->creator` fast path first, then fall back to the `__ffi_new__`
    type attribute for Python-defined types.
    - Deduplicate four call sites in `creator.h`, `init.h`,
    `reflection_extra.cc`, and `serialization.cc` to use the new centralized
    helpers.
    - Qualify bare `details::` references to `::tvm::ffi::details::` in
    reflection headers (`overload.h`, `registry.h`, `init.h`) to prevent
    ADL/lookup ambiguity when included from other namespaces.
    
    ## Test plan
    
    - [x] Existing C++ tests pass (object creation via reflection,
    serialization roundtrip)
    - [x] Existing Python tests pass (c_class / py_class decorator,
    reflection-based construction)
    - [x] CI lint (clang-format, clang-tidy) passes on changed headers
---
 include/tvm/ffi/reflection/creator.h  | 75 +++++++++++++++++++++++++++++------
 include/tvm/ffi/reflection/init.h     | 20 ++++------
 include/tvm/ffi/reflection/overload.h | 13 +++---
 include/tvm/ffi/reflection/registry.h | 21 +++++-----
 src/ffi/extra/reflection_extra.cc     | 10 +----
 src/ffi/extra/serialization.cc        | 10 +----
 6 files changed, 90 insertions(+), 59 deletions(-)

diff --git a/include/tvm/ffi/reflection/creator.h 
b/include/tvm/ffi/reflection/creator.h
index 300ad512..977e6ded 100644
--- a/include/tvm/ffi/reflection/creator.h
+++ b/include/tvm/ffi/reflection/creator.h
@@ -23,13 +23,72 @@
 #ifndef TVM_FFI_REFLECTION_CREATOR_H_
 #define TVM_FFI_REFLECTION_CREATOR_H_
 
-#include <tvm/ffi/any.h>
 #include <tvm/ffi/container/map.h>
+#include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/accessor.h>
 #include <tvm/ffi/string.h>
 
 namespace tvm {
 namespace ffi {
+/*!
+ * \brief Create an empty object via the type's native creator or 
``__ffi_new__`` type attr.
+ *
+ * Falls back to the ``__ffi_new__`` type attribute (used by Python-defined 
types)
+ * when the native ``metadata->creator`` is NULL.
+ *
+ * \param type_info The type info for the object to create.
+ * \return An owned ObjectPtr to the newly allocated (zero-initialized) object.
+ * \throws RuntimeError if neither creator nor __ffi_new__ is available.
+ */
+inline ObjectPtr<Object> CreateEmptyObject(const TVMFFITypeInfo* type_info) {
+  // Fast path: native C++ creator
+  if (type_info->metadata != nullptr && type_info->metadata->creator != 
nullptr) {
+    TVMFFIObjectHandle handle;
+    TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
+    return 
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+  }
+  // Fallback: __ffi_new__ type attr (Python-defined types)
+  constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
+  const TVMFFITypeAttrColumn* column = 
TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
+  if (column != nullptr) {
+    int32_t offset = type_info->type_index - column->begin_index;
+    if (offset >= 0 && offset < column->size) {
+      AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]);
+      if (auto opt_func = attr_view.try_cast<Function>()) {
+        ObjectRef obj_ref = (*opt_func)().cast<ObjectRef>();
+        return 
details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(obj_ref));
+      }
+    }
+  }
+  TVM_FFI_THROW(RuntimeError) << "Type `" << 
TypeIndexToTypeKey(type_info->type_index)
+                              << "` does not support reflection creation"
+                              << " (no native creator or __ffi_new__ type 
attr)";
+}
+
+/*!
+ * \brief Check whether a type supports reflection creation.
+ *
+ * Returns true if the type has a native creator or a ``__ffi_new__`` type 
attr.
+ *
+ * \param type_info The type info to check.
+ * \return true if CreateEmptyObject would succeed.
+ */
+inline bool HasCreator(const TVMFFITypeInfo* type_info) {
+  if (type_info->metadata != nullptr && type_info->metadata->creator != 
nullptr) {
+    return true;
+  }
+  constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
+  const TVMFFITypeAttrColumn* column = 
TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
+  if (column != nullptr) {
+    int32_t offset = type_info->type_index - column->begin_index;
+    if (offset >= 0 && offset < column->size &&
+        column->data[offset].type_index == TypeIndex::kTVMFFIFunction) {
+      return true;
+    }
+  }
+  return false;
+}
+
 namespace reflection {
 /*!
  * \brief helper wrapper class of TVMFFITypeInfo to create object based on 
reflection.
@@ -48,13 +107,8 @@ class ObjectCreator {
    * \param type_info The type info.
    */
   explicit ObjectCreator(const TVMFFITypeInfo* type_info) : 
type_info_(type_info) {
-    int32_t type_index = type_info->type_index;
-    if (type_info->metadata == nullptr) {
-      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
-                                  << "` does not have reflection registered";
-    }
-    if (type_info->metadata->creator == nullptr) {
-      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
+    if (!HasCreator(type_info)) {
+      TVM_FFI_THROW(RuntimeError) << "Type `" << 
TypeIndexToTypeKey(type_info->type_index)
                                   << "` does not support default constructor, "
                                   << "as a result cannot be created via 
reflection";
     }
@@ -66,10 +120,7 @@ class ObjectCreator {
    * \return The created object.
    */
   Any operator()(const Map<String, Any>& fields) const {
-    TVMFFIObjectHandle handle;
-    TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
-    ObjectPtr<Object> ptr =
-        
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+    ObjectPtr<Object> ptr = CreateEmptyObject(type_info_);
     size_t match_field_count = 0;
     ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
       String field_name(field_info->name);
diff --git a/include/tvm/ffi/reflection/init.h 
b/include/tvm/ffi/reflection/init.h
index 337753ce..8853f054 100644
--- a/include/tvm/ffi/reflection/init.h
+++ b/include/tvm/ffi/reflection/init.h
@@ -29,6 +29,7 @@
 #include <tvm/ffi/function_details.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/creator.h>
 #include <tvm/ffi/string.h>
 
 #include <algorithm>
@@ -69,10 +70,8 @@ inline Function MakeInit(int32_t type_index) {
   };
   // ---- Pre-compute field analysis (once per type) -------------------------
   const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
-  TVM_FFI_ICHECK(type_info->metadata != nullptr)
-      << "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection 
metadata";
-  TVM_FFI_ICHECK(type_info->metadata->creator != nullptr)
-      << "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator";
+  TVM_FFI_ICHECK(HasCreator(type_info)) << "Type `" << 
TypeIndexToTypeKey(type_index)
+                                        << "` has no creator or __ffi_new__ 
for __ffi_init__";
 
   auto info = std::make_shared<AutoInitInfo>();
   info->type_key = std::string_view(type_info->type_key.data, 
type_info->type_key.size);
@@ -101,16 +100,11 @@ inline Function MakeInit(int32_t type_index) {
   // Eagerly resolve the KWARGS sentinel via global function registry.
   ObjectRef kwargs_sentinel =
       Function::GetGlobalRequired("ffi.GetKwargsObject")().cast<ObjectRef>();
-  // Cache pointers for the lambda (avoid repeated lookups).
-  TVMFFIObjectCreator creator = type_info->metadata->creator;
 
   return Function::FromPacked(
-      [info, kwargs_sentinel, creator](PackedArgs args, Any* rv) {
-        // ---- 1. Create object via creator 
------------------------------------
-        TVMFFIObjectHandle handle;
-        TVM_FFI_CHECK_SAFE_CALL(creator(&handle));
-        ObjectPtr<Object> obj_ptr =
-            
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+      [info, kwargs_sentinel, type_info](PackedArgs args, Any* rv) {
+        // ---- 1. Create object via CreateEmptyObject 
--------------------------
+        ObjectPtr<Object> obj_ptr = CreateEmptyObject(type_info);
 
         // ---- 2. Find KWARGS sentinel position 
--------------------------------
         int kwargs_pos = -1;
@@ -219,7 +213,7 @@ inline void RegisterAutoInit(int32_t type_index) {
   info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod;
   info.method = AnyView(auto_init_fn).CopyToTVMFFIAny();
   static const std::string kMetadata =
-      "{\"type_schema\":" + 
std::string(details::TypeSchemaImpl<Function>::v()) +
+      "{\"type_schema\":" + 
std::string(::tvm::ffi::details::TypeSchemaImpl<Function>::v()) +
       ",\"auto_init\":true}";
   info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()};
   TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info));
diff --git a/include/tvm/ffi/reflection/overload.h 
b/include/tvm/ffi/reflection/overload.h
index 8647c3c6..80e9bce9 100644
--- a/include/tvm/ffi/reflection/overload.h
+++ b/include/tvm/ffi/reflection/overload.h
@@ -404,14 +404,15 @@ class OverloadObjectDef : private ObjectDef<Class> {
   template <typename Func>
   static auto GetOverloadMethod(std::string name, Func&& func) {
     using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
-    using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
+    using OverloadFn = 
::tvm::ffi::details::OverloadedFunction<std::decay_t<WrapFn>>;
     return 
ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
                                                         std::move(name));
   }
 
   template <typename Func>
   static auto NewOverload(std::string name, Func&& func) {
-    return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)), 
std::move(name));
+    return 
::tvm::ffi::details::CreateNewOverload(WrapFunction(std::forward<Func>(func)),
+                                                  std::move(name));
   }
 
   template <typename... ExtraArgs>
@@ -452,7 +453,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
     // initialize default value to nullptr
     info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
     info.doc = TVMFFIByteArray{nullptr, 0};
-    info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
+    info.metadata_.emplace_back("type_schema", 
::tvm::ffi::details::TypeSchema<T>::v());
     // apply field info traits
     ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
     // call register
@@ -464,7 +465,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
   // register a method
   template <typename Func, typename... Extra>
   void RegisterMethod(const char* name, bool is_static, Func&& func, 
Extra&&... extra) {
-    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+    using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
     MethodInfoBuilder info;
     info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
     info.doc = TVMFFIByteArray{nullptr, 0};
@@ -478,7 +479,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
     // if an overload method exists, register to existing overload function
     if (const auto overload_it = registered_fields_.find(name);
         overload_it != registered_fields_.end()) {
-      details::OverloadBase* overload_ptr = overload_it->second;
+      ::tvm::ffi::details::OverloadBase* overload_ptr = overload_it->second;
       return overload_ptr->Register(NewOverload(std::move(method_name), 
std::forward<Func>(func)));
     }
 
@@ -496,7 +497,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
     TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
   }
 
-  std::unordered_map<std::string, details::OverloadBase*> registered_fields_;
+  std::unordered_map<std::string, ::tvm::ffi::details::OverloadBase*> 
registered_fields_;
 };
 
 }  // namespace reflection
diff --git a/include/tvm/ffi/reflection/registry.h 
b/include/tvm/ffi/reflection/registry.h
index 719d6121..17c0078b 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -362,7 +362,8 @@ template <typename Class, typename T>
 TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
   int64_t field_offset_to_class =
       reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
-  return field_offset_to_class - 
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
+  return field_offset_to_class -
+         ::tvm::ffi::details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
 }
 
 /// \cond Doxygen_Suppress
@@ -371,7 +372,7 @@ class ReflectionDefBase {
   template <typename T>
   static int FieldGetter(void* field, TVMFFIAny* result) {
     TVM_FFI_SAFE_CALL_BEGIN();
-    *result = 
details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
+    *result = 
::tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
     TVM_FFI_SAFE_CALL_END();
   }
 
@@ -390,7 +391,7 @@ class ReflectionDefBase {
   static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
     TVM_FFI_SAFE_CALL_BEGIN();
     ObjectPtr<T> obj = make_object<T>();
-    *result = 
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
+    *result = 
::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
     TVM_FFI_SAFE_CALL_END();
   }
 
@@ -398,7 +399,7 @@ class ReflectionDefBase {
   static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) {
     TVM_FFI_SAFE_CALL_BEGIN();
     ObjectPtr<T> obj = make_object<T>(UnsafeInit{});
-    *result = 
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
+    *result = 
::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
     TVM_FFI_SAFE_CALL_END();
   }
 
@@ -499,7 +500,7 @@ class GlobalDef : public ReflectionDefBase {
    */
   template <typename Func, typename... Extra>
   GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
-    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+    using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
     RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), 
std::string(name)),
                  FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
     return *this;
@@ -519,8 +520,8 @@ class GlobalDef : public ReflectionDefBase {
    */
   template <typename Func, typename... Extra>
   GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
-    RegisterFunc(name, ffi::Function::FromPacked(func), 
details::TypeSchemaImpl<Function>::v(),
-                 std::forward<Extra>(extra)...);
+    RegisterFunc(name, ffi::Function::FromPacked(func),
+                 ::tvm::ffi::details::TypeSchemaImpl<Function>::v(), 
std::forward<Extra>(extra)...);
     return *this;
   }
 
@@ -540,7 +541,7 @@ class GlobalDef : public ReflectionDefBase {
    */
   template <typename Func, typename... Extra>
   GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
-    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+    using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
     RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
                  FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
     return *this;
@@ -915,7 +916,7 @@ class ObjectDef : public ReflectionDefBase {
     // initialize default value to nullptr
     info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
     info.doc = TVMFFIByteArray{nullptr, 0};
-    info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
+    info.metadata_.emplace_back("type_schema", 
::tvm::ffi::details::TypeSchema<T>::v());
     // apply field info traits
     ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
     // call register
@@ -927,7 +928,7 @@ class ObjectDef : public ReflectionDefBase {
   // register a method
   template <typename Func, typename... Extra>
   void RegisterMethod(const char* name, bool is_static, Func&& func, 
Extra&&... extra) {
-    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+    using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
     MethodInfoBuilder info;
     info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
     info.doc = TVMFFIByteArray{nullptr, 0};
diff --git a/src/ffi/extra/reflection_extra.cc 
b/src/ffi/extra/reflection_extra.cc
index b5ced5c2..44e8ac3c 100644
--- a/src/ffi/extra/reflection_extra.cc
+++ b/src/ffi/extra/reflection_extra.cc
@@ -42,15 +42,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* 
ret) {
 
   TVM_FFI_ICHECK(args.size() % 2 == 1);
   const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
-
-  if (type_info->metadata == nullptr || type_info->metadata->creator == 
nullptr) {
-    TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
-                                << "` does not support reflection creation";
-  }
-  TVMFFIObjectHandle handle;
-  TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
-  ObjectPtr<Object> ptr =
-      
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+  ObjectPtr<Object> ptr = CreateEmptyObject(type_info);
 
   std::vector<String> keys;
   std::vector<bool> keys_found;
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index 80b96ec7..c1fb6211 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -371,15 +371,7 @@ class ObjectGraphDeserializer {
     }
     // otherwise, we go over the fields and create the data.
     const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
-    if (type_info->metadata == nullptr || type_info->metadata->creator == 
nullptr) {
-      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
-                                  << "` does not support default constructor"
-                                  << ", so ToJSONGraph is not supported for 
this type";
-    }
-    TVMFFIObjectHandle handle;
-    TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
-    ObjectPtr<Object> ptr =
-        
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+    ObjectPtr<Object> ptr = CreateEmptyObject(type_info);
 
     auto decode_field_value = [&](const TVMFFIFieldInfo* field_info,
                                   const json::Value& data) -> Any {

Reply via email to