This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new e268eb1d feat(ffi)!: centralize object creation with
CreateEmptyObject/HasCreator (#501)
e268eb1d is described below
commit e268eb1d2d55ae97bdaab7ab52f381f436fd0c82
Author: Junru Shao <[email protected]>
AuthorDate: Tue Mar 10 13:48:04 2026 -0700
feat(ffi)!: centralize object creation with CreateEmptyObject/HasCreator
(#501)
## Summary
- Add `CreateEmptyObject()` and `HasCreator()` inline helpers in
`reflection/creator.h` (which now includes `function.h`) that unify the
two-step "check creator, call creator" pattern. `CreateEmptyObject()` tries
the native `metadata->creator` fast path first, then falls back to the
`__ffi_new__` type attribute for Python-defined types; `HasCreator()`
reports whether either path is available.
- Deduplicate four call sites in `creator.h`, `init.h`,
`reflection_extra.cc`, and `serialization.cc` to use the new centralized
helpers.
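- Fully qualify bare `details::` references as `::tvm::ffi::details::` in
reflection headers (`overload.h`, `registry.h`, `init.h`) so that the
`details` prefix cannot resolve to a different `details` namespace that
happens to be visible from the enclosing scope (a name-lookup ambiguity,
not an ADL issue).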
## Test plan
- [x] Existing C++ tests pass (object creation via reflection,
serialization roundtrip)
- [x] Existing Python tests pass (c_class / py_class decorator,
reflection-based construction)
- [x] CI lint (clang-format, clang-tidy) passes on changed headers
---
include/tvm/ffi/reflection/creator.h | 75 +++++++++++++++++++++++++++++------
include/tvm/ffi/reflection/init.h | 20 ++++------
include/tvm/ffi/reflection/overload.h | 13 +++---
include/tvm/ffi/reflection/registry.h | 21 +++++-----
src/ffi/extra/reflection_extra.cc | 10 +----
src/ffi/extra/serialization.cc | 10 +----
6 files changed, 90 insertions(+), 59 deletions(-)
diff --git a/include/tvm/ffi/reflection/creator.h
b/include/tvm/ffi/reflection/creator.h
index 300ad512..977e6ded 100644
--- a/include/tvm/ffi/reflection/creator.h
+++ b/include/tvm/ffi/reflection/creator.h
@@ -23,13 +23,72 @@
#ifndef TVM_FFI_REFLECTION_CREATOR_H_
#define TVM_FFI_REFLECTION_CREATOR_H_
-#include <tvm/ffi/any.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>
namespace tvm {
namespace ffi {
+/*!
+ * \brief Create an empty object via the type's native creator or
``__ffi_new__`` type attr.
+ *
+ * Falls back to the ``__ffi_new__`` type attribute (used by Python-defined
types)
+ * when the native ``metadata->creator`` is NULL.
+ *
+ * \param type_info The type info for the object to create.
+ * \return An owned ObjectPtr to the newly allocated (zero-initialized) object.
+ * \throws RuntimeError if neither creator nor __ffi_new__ is available.
+ */
+inline ObjectPtr<Object> CreateEmptyObject(const TVMFFITypeInfo* type_info) {
+ // Fast path: native C++ creator
+ if (type_info->metadata != nullptr && type_info->metadata->creator !=
nullptr) {
+ TVMFFIObjectHandle handle;
+ TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
+ return
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+ }
+ // Fallback: __ffi_new__ type attr (Python-defined types)
+ constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
+ const TVMFFITypeAttrColumn* column =
TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
+ if (column != nullptr) {
+ int32_t offset = type_info->type_index - column->begin_index;
+ if (offset >= 0 && offset < column->size) {
+ AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]);
+ if (auto opt_func = attr_view.try_cast<Function>()) {
+ ObjectRef obj_ref = (*opt_func)().cast<ObjectRef>();
+ return
details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(obj_ref));
+ }
+ }
+ }
+ TVM_FFI_THROW(RuntimeError) << "Type `" <<
TypeIndexToTypeKey(type_info->type_index)
+ << "` does not support reflection creation"
+ << " (no native creator or __ffi_new__ type
attr)";
+}
+
+/*!
+ * \brief Check whether a type supports reflection creation.
+ *
+ * Returns true if the type has a native creator or a ``__ffi_new__`` type
attr.
+ *
+ * \param type_info The type info to check.
+ * \return true if CreateEmptyObject would succeed.
+ */
+inline bool HasCreator(const TVMFFITypeInfo* type_info) {
+ if (type_info->metadata != nullptr && type_info->metadata->creator !=
nullptr) {
+ return true;
+ }
+ constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
+ const TVMFFITypeAttrColumn* column =
TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
+ if (column != nullptr) {
+ int32_t offset = type_info->type_index - column->begin_index;
+ if (offset >= 0 && offset < column->size &&
+ column->data[offset].type_index == TypeIndex::kTVMFFIFunction) {
+ return true;
+ }
+ }
+ return false;
+}
+
namespace reflection {
/*!
* \brief helper wrapper class of TVMFFITypeInfo to create object based on
reflection.
@@ -48,13 +107,8 @@ class ObjectCreator {
* \param type_info The type info.
*/
explicit ObjectCreator(const TVMFFITypeInfo* type_info) :
type_info_(type_info) {
- int32_t type_index = type_info->type_index;
- if (type_info->metadata == nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
- << "` does not have reflection registered";
- }
- if (type_info->metadata->creator == nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
+ if (!HasCreator(type_info)) {
+ TVM_FFI_THROW(RuntimeError) << "Type `" <<
TypeIndexToTypeKey(type_info->type_index)
<< "` does not support default constructor, "
<< "as a result cannot be created via
reflection";
}
@@ -66,10 +120,7 @@ class ObjectCreator {
* \return The created object.
*/
Any operator()(const Map<String, Any>& fields) const {
- TVMFFIObjectHandle handle;
- TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
- ObjectPtr<Object> ptr =
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+ ObjectPtr<Object> ptr = CreateEmptyObject(type_info_);
size_t match_field_count = 0;
ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
String field_name(field_info->name);
diff --git a/include/tvm/ffi/reflection/init.h
b/include/tvm/ffi/reflection/init.h
index 337753ce..8853f054 100644
--- a/include/tvm/ffi/reflection/init.h
+++ b/include/tvm/ffi/reflection/init.h
@@ -29,6 +29,7 @@
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/creator.h>
#include <tvm/ffi/string.h>
#include <algorithm>
@@ -69,10 +70,8 @@ inline Function MakeInit(int32_t type_index) {
};
// ---- Pre-compute field analysis (once per type) -------------------------
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
- TVM_FFI_ICHECK(type_info->metadata != nullptr)
- << "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection
metadata";
- TVM_FFI_ICHECK(type_info->metadata->creator != nullptr)
- << "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator";
+ TVM_FFI_ICHECK(HasCreator(type_info)) << "Type `" <<
TypeIndexToTypeKey(type_index)
+ << "` has no creator or __ffi_new__
for __ffi_init__";
auto info = std::make_shared<AutoInitInfo>();
info->type_key = std::string_view(type_info->type_key.data,
type_info->type_key.size);
@@ -101,16 +100,11 @@ inline Function MakeInit(int32_t type_index) {
// Eagerly resolve the KWARGS sentinel via global function registry.
ObjectRef kwargs_sentinel =
Function::GetGlobalRequired("ffi.GetKwargsObject")().cast<ObjectRef>();
- // Cache pointers for the lambda (avoid repeated lookups).
- TVMFFIObjectCreator creator = type_info->metadata->creator;
return Function::FromPacked(
- [info, kwargs_sentinel, creator](PackedArgs args, Any* rv) {
- // ---- 1. Create object via creator
------------------------------------
- TVMFFIObjectHandle handle;
- TVM_FFI_CHECK_SAFE_CALL(creator(&handle));
- ObjectPtr<Object> obj_ptr =
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+ [info, kwargs_sentinel, type_info](PackedArgs args, Any* rv) {
+ // ---- 1. Create object via CreateEmptyObject
--------------------------
+ ObjectPtr<Object> obj_ptr = CreateEmptyObject(type_info);
// ---- 2. Find KWARGS sentinel position
--------------------------------
int kwargs_pos = -1;
@@ -219,7 +213,7 @@ inline void RegisterAutoInit(int32_t type_index) {
info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod;
info.method = AnyView(auto_init_fn).CopyToTVMFFIAny();
static const std::string kMetadata =
- "{\"type_schema\":" +
std::string(details::TypeSchemaImpl<Function>::v()) +
+ "{\"type_schema\":" +
std::string(::tvm::ffi::details::TypeSchemaImpl<Function>::v()) +
",\"auto_init\":true}";
info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info));
diff --git a/include/tvm/ffi/reflection/overload.h
b/include/tvm/ffi/reflection/overload.h
index 8647c3c6..80e9bce9 100644
--- a/include/tvm/ffi/reflection/overload.h
+++ b/include/tvm/ffi/reflection/overload.h
@@ -404,14 +404,15 @@ class OverloadObjectDef : private ObjectDef<Class> {
template <typename Func>
static auto GetOverloadMethod(std::string name, Func&& func) {
using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
- using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
+ using OverloadFn =
::tvm::ffi::details::OverloadedFunction<std::decay_t<WrapFn>>;
return
ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
std::move(name));
}
template <typename Func>
static auto NewOverload(std::string name, Func&& func) {
- return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)),
std::move(name));
+ return
::tvm::ffi::details::CreateNewOverload(WrapFunction(std::forward<Func>(func)),
+ std::move(name));
}
template <typename... ExtraArgs>
@@ -452,7 +453,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// initialize default value to nullptr
info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
- info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
+ info.metadata_.emplace_back("type_schema",
::tvm::ffi::details::TypeSchema<T>::v());
// apply field info traits
((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
// call register
@@ -464,7 +465,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func,
Extra&&... extra) {
- using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+ using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
@@ -478,7 +479,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// if an overload method exists, register to existing overload function
if (const auto overload_it = registered_fields_.find(name);
overload_it != registered_fields_.end()) {
- details::OverloadBase* overload_ptr = overload_it->second;
+ ::tvm::ffi::details::OverloadBase* overload_ptr = overload_it->second;
return overload_ptr->Register(NewOverload(std::move(method_name),
std::forward<Func>(func)));
}
@@ -496,7 +497,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
}
- std::unordered_map<std::string, details::OverloadBase*> registered_fields_;
+ std::unordered_map<std::string, ::tvm::ffi::details::OverloadBase*>
registered_fields_;
};
} // namespace reflection
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 719d6121..17c0078b 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -362,7 +362,8 @@ template <typename Class, typename T>
TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
int64_t field_offset_to_class =
reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
- return field_offset_to_class -
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
+ return field_offset_to_class -
+ ::tvm::ffi::details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
}
/// \cond Doxygen_Suppress
@@ -371,7 +372,7 @@ class ReflectionDefBase {
template <typename T>
static int FieldGetter(void* field, TVMFFIAny* result) {
TVM_FFI_SAFE_CALL_BEGIN();
- *result =
details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
+ *result =
::tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
TVM_FFI_SAFE_CALL_END();
}
@@ -390,7 +391,7 @@ class ReflectionDefBase {
static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
TVM_FFI_SAFE_CALL_BEGIN();
ObjectPtr<T> obj = make_object<T>();
- *result =
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
+ *result =
::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
TVM_FFI_SAFE_CALL_END();
}
@@ -398,7 +399,7 @@ class ReflectionDefBase {
static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) {
TVM_FFI_SAFE_CALL_BEGIN();
ObjectPtr<T> obj = make_object<T>(UnsafeInit{});
- *result =
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
+ *result =
::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
TVM_FFI_SAFE_CALL_END();
}
@@ -499,7 +500,7 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
- using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+ using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func),
std::string(name)),
FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
return *this;
@@ -519,8 +520,8 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
- RegisterFunc(name, ffi::Function::FromPacked(func),
details::TypeSchemaImpl<Function>::v(),
- std::forward<Extra>(extra)...);
+ RegisterFunc(name, ffi::Function::FromPacked(func),
+ ::tvm::ffi::details::TypeSchemaImpl<Function>::v(),
std::forward<Extra>(extra)...);
return *this;
}
@@ -540,7 +541,7 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
- using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+ using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
return *this;
@@ -915,7 +916,7 @@ class ObjectDef : public ReflectionDefBase {
// initialize default value to nullptr
info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
- info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
+ info.metadata_.emplace_back("type_schema",
::tvm::ffi::details::TypeSchema<T>::v());
// apply field info traits
((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
// call register
@@ -927,7 +928,7 @@ class ObjectDef : public ReflectionDefBase {
// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func,
Extra&&... extra) {
- using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+ using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
diff --git a/src/ffi/extra/reflection_extra.cc
b/src/ffi/extra/reflection_extra.cc
index b5ced5c2..44e8ac3c 100644
--- a/src/ffi/extra/reflection_extra.cc
+++ b/src/ffi/extra/reflection_extra.cc
@@ -42,15 +42,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any*
ret) {
TVM_FFI_ICHECK(args.size() % 2 == 1);
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
-
- if (type_info->metadata == nullptr || type_info->metadata->creator ==
nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
- << "` does not support reflection creation";
- }
- TVMFFIObjectHandle handle;
- TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
- ObjectPtr<Object> ptr =
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+ ObjectPtr<Object> ptr = CreateEmptyObject(type_info);
std::vector<String> keys;
std::vector<bool> keys_found;
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index 80b96ec7..c1fb6211 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -371,15 +371,7 @@ class ObjectGraphDeserializer {
}
// otherwise, we go over the fields and create the data.
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
- if (type_info->metadata == nullptr || type_info->metadata->creator ==
nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
- << "` does not support default constructor"
- << ", so ToJSONGraph is not supported for
this type";
- }
- TVMFFIObjectHandle handle;
- TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
- ObjectPtr<Object> ptr =
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+ ObjectPtr<Object> ptr = CreateEmptyObject(type_info);
auto decode_field_value = [&](const TVMFFIFieldInfo* field_info,
const json::Value& data) -> Any {