This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 5e564cd feat: add `DefaultFactory` support to field reflection (#446)
5e564cd is described below
commit 5e564cdfb932af63915fbeb5a5aa30671f55ae2c
Author: Junru Shao <[email protected]>
AuthorDate: Sat Feb 14 18:14:35 2026 -0800
feat: add `DefaultFactory` support to field reflection (#446)
`DefaultValue` stores a single static default shared across all
instances created via reflection. For mutable defaults (Array, Map,
etc.), this causes aliasing: every object receives the same underlying
container. `DefaultFactory` fixes this by storing a callable `() -> Any`
that is invoked each time a default is needed, producing a fresh value
per instance, mirroring the `default_factory` parameter of Python's
`dataclasses.field()`.
Concrete changes:
- Rename `TVMFFIFieldInfo::default_value` → `default_value_or_factory`
to reflect that the slot now holds either a value or a factory.
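- Add `kTVMFFIFieldFlagBitMaskDefaultFromFactory` (1 << 5) to
`TVMFFIFieldFlagBitMask`; it is only meaningful alongside
`kTVMFFIFieldFlagBitMaskHasDefault`, which `DefaultFactory` also sets.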
- Add `reflection::DefaultFactory` trait (registry.h), symmetric to
`DefaultValue`.
- Add `reflection::SetFieldToDefault` helper (accessor.h) that resolves
the default—calling the factory when the flag is set—so the three
consumption sites (creator.h, reflection_extra.cc, serialization.cc)
share one implementation.
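- Propagate the rename through the remaining C++ sites (`overload.h`,
`object.cc`) and the Rust (`c_api.rs`) and Cython (`base.pxi`) bindings;
`base.pxi` also declares the new flag constant.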
- Add `TestObjWithFactory` and three tests covering flag inspection,
per-instance freshness, and skipping the factory when a value is given.
---
include/tvm/ffi/c_api.h | 20 ++++++++--
include/tvm/ffi/reflection/accessor.h | 22 +++++++++++
include/tvm/ffi/reflection/creator.h | 2 +-
include/tvm/ffi/reflection/overload.h | 2 +-
include/tvm/ffi/reflection/registry.h | 33 +++++++++++++++-
python/tvm_ffi/cython/base.pxi | 3 +-
rust/tvm-ffi-sys/src/c_api.rs | 8 ++--
src/ffi/extra/reflection_extra.cc | 3 +-
src/ffi/extra/serialization.cc | 2 +-
src/ffi/object.cc | 7 ++--
tests/cpp/test_reflection.cc | 72 ++++++++++++++++++++++++++++++++++-
11 files changed, 156 insertions(+), 18 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 5089b6e..67715ca 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -861,6 +861,15 @@ typedef enum {
* This is an optional meta-data for structural eq/hash.
*/
kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
+ /*!
+ * \brief The default_value_or_factory is a callable factory function () ->
Any.
+ *
+ * When this flag is set along with kTVMFFIFieldFlagBitMaskHasDefault,
+ * the default_value_or_factory field contains a Function that should be
+ * called with no arguments to produce the default value, rather than
+ * being used directly as the default value.
+ */
+ kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5,
#ifdef __cplusplus
};
#else
@@ -960,10 +969,15 @@ typedef struct {
*/
TVMFFIFieldSetter setter;
/*!
- * \brief The default value of the field, this field hold AnyView,
- * valid when flags set kTVMFFIFieldFlagBitMaskHasDefault
+ * \brief The default value or default factory of the field.
+ *
+ * When flags has kTVMFFIFieldFlagBitMaskHasDefault set:
+ * - If kTVMFFIFieldFlagBitMaskDefaultFromFactory is NOT set,
+ * this holds the static default value as AnyView.
+ * - If kTVMFFIFieldFlagBitMaskDefaultFromFactory IS set,
+ * this holds a Function (() -> Any) that produces the default.
*/
- TVMFFIAny default_value;
+ TVMFFIAny default_value_or_factory;
/*!
* \brief Records the static type kind of the field.
*
diff --git a/include/tvm/ffi/reflection/accessor.h
b/include/tvm/ffi/reflection/accessor.h
index b49da51..68c8b0d 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -199,6 +199,28 @@ inline Function GetMethod(std::string_view type_key, const
char* method_name) {
return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
}
+/*!
+ * \brief Set a field to its default value, calling the factory if applicable.
+ *
+ * When kTVMFFIFieldFlagBitMaskDefaultFromFactory is set, extracts the
+ * Function from default_value_or_factory, calls it with no arguments,
+ * and uses the result. Otherwise, passes default_value_or_factory directly
+ * to the setter.
+ *
+ * \param field_info The field info (must have
kTVMFFIFieldFlagBitMaskHasDefault set).
+ * \param field_addr The address of the field in the object.
+ */
+inline void SetFieldToDefault(const TVMFFIFieldInfo* field_info, void*
field_addr) {
+ if (field_info->flags & kTVMFFIFieldFlagBitMaskDefaultFromFactory) {
+ Function factory =
+
AnyView::CopyFromTVMFFIAny(field_info->default_value_or_factory).cast<Function>();
+ Any default_val = factory();
+ field_info->setter(field_addr, reinterpret_cast<const
TVMFFIAny*>(&default_val));
+ } else {
+ field_info->setter(field_addr, &(field_info->default_value_or_factory));
+ }
+}
+
/*!
* \brief Visit each field info of the type info and run callback.
*
diff --git a/include/tvm/ffi/reflection/creator.h
b/include/tvm/ffi/reflection/creator.h
index 774eb8b..a7e860c 100644
--- a/include/tvm/ffi/reflection/creator.h
+++ b/include/tvm/ffi/reflection/creator.h
@@ -79,7 +79,7 @@ class ObjectCreator {
field_info->setter(field_addr, reinterpret_cast<const
TVMFFIAny*>(&field_value));
++match_field_count;
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
- field_info->setter(field_addr, &(field_info->default_value));
+ SetFieldToDefault(field_info, field_addr);
} else {
TVM_FFI_THROW(TypeError) << "Required field `"
<< String(field_info->name.data,
field_info->name.size)
diff --git a/include/tvm/ffi/reflection/overload.h
b/include/tvm/ffi/reflection/overload.h
index a85174c..81ced69 100644
--- a/include/tvm/ffi/reflection/overload.h
+++ b/include/tvm/ffi/reflection/overload.h
@@ -449,7 +449,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
info.getter = ReflectionDefBase::FieldGetter<T>;
info.setter = ReflectionDefBase::FieldSetter<T>;
// initialize default value to nullptr
- info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+ info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
// apply field info traits
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index cc4ec50..4cd4663 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -152,7 +152,7 @@ class DefaultValue : public InfoTrait {
* \param info The field info.
*/
TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
- info->default_value = AnyView(value_).CopyToTVMFFIAny();
+ info->default_value_or_factory = AnyView(value_).CopyToTVMFFIAny();
info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
}
@@ -160,6 +160,35 @@ class DefaultValue : public InfoTrait {
Any value_;
};
+/*!
+ * \brief Trait that can be used to set field default factory.
+ *
+ * A default factory is a callable () -> Any that is invoked each time
+ * a default value is needed, producing a fresh value. This is important
+ * for mutable defaults (e.g., Array, Map) to avoid aliasing.
+ */
+class DefaultFactory : public InfoTrait {
+ public:
+ /*!
+ * \brief Constructor
+ * \param factory The factory function to be called to produce default
values.
+ */
+ explicit DefaultFactory(Function factory) : factory_(std::move(factory)) {}
+
+ /*!
+ * \brief Apply the default factory to the field info
+ * \param info The field info.
+ */
+ TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
+ info->default_value_or_factory = AnyView(factory_).CopyToTVMFFIAny();
+ info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
+ info->flags |= kTVMFFIFieldFlagBitMaskDefaultFromFactory;
+ }
+
+ private:
+ Function factory_;
+};
+
/*!
* \brief Trait that can be used to attach field flag
*/
@@ -653,7 +682,7 @@ class ObjectDef : public ReflectionDefBase {
info.getter = FieldGetter<T>;
info.setter = FieldSetter<T>;
// initialize default value to nullptr
- info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+ info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
// apply field info traits
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index a07e850..4512936 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -203,6 +203,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIFieldFlagBitMaskWritable = 1 << 0
kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
+ kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5
ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value)
noexcept
@@ -218,7 +219,7 @@ cdef extern from "tvm/ffi/c_api.h":
int64_t offset
TVMFFIFieldGetter getter
TVMFFIFieldSetter setter
- TVMFFIAny default_value
+ TVMFFIAny default_value_or_factory
int32_t field_static_type_index
ctypedef struct TVMFFIMethodInfo:
diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs
index e0bf085..d94967b 100644
--- a/rust/tvm-ffi-sys/src/c_api.rs
+++ b/rust/tvm-ffi-sys/src/c_api.rs
@@ -286,9 +286,11 @@ pub struct TVMFFIFieldInfo {
/// The setter to access the field
/// The setter is set even if the field is readonly for serialization
pub setter: Option<TVMFFIFieldSetter>,
- /// The default value of the field, this field hold AnyView,
- /// valid when flags set kTVMFFIFieldFlagBitMaskHasDefault
- pub default_value: TVMFFIAny,
+ /// The default value or factory of the field, this field holds AnyView.
+ /// Valid when flags set kTVMFFIFieldFlagBitMaskHasDefault.
+ /// When kTVMFFIFieldFlagBitMaskDefaultFromFactory is also set,
+ /// this is a callable factory function () -> Any.
+ pub default_value_or_factory: TVMFFIAny,
/// Records the static type kind of the field
pub field_static_type_index: i32,
}
diff --git a/src/ffi/extra/reflection_extra.cc
b/src/ffi/extra/reflection_extra.cc
index 02d422d..6699416 100644
--- a/src/ffi/extra/reflection_extra.cc
+++ b/src/ffi/extra/reflection_extra.cc
@@ -22,6 +22,7 @@
* \brief Extra reflection registrations. *
*/
#include <tvm/ffi/reflection/access_path.h>
+#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
@@ -78,7 +79,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret)
{
field_info->setter(field_addr, reinterpret_cast<const
TVMFFIAny*>(&field_value));
keys_found[arg_index] = true;
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
- field_info->setter(field_addr, &(field_info->default_value));
+ reflection::SetFieldToDefault(field_info, field_addr);
} else {
TVM_FFI_THROW(TypeError) << "Required field `"
<< String(field_info->name.data,
field_info->name.size)
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index 1e5e98c..f639d87 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -396,7 +396,7 @@ class ObjectGraphDeserializer {
Any field_value = decode_field_value(field_info,
data_object[field_name]);
field_info->setter(field_addr, reinterpret_cast<const
TVMFFIAny*>(&field_value));
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
- field_info->setter(field_addr, &(field_info->default_value));
+ reflection::SetFieldToDefault(field_info, field_addr);
} else {
TVM_FFI_THROW(TypeError) << "Required field `"
<< String(field_info->name.data,
field_info->name.size)
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 8c5c137..7e5e00e 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -211,10 +211,11 @@ class TypeTable {
field_data.doc = this->CopyString(info->doc);
field_data.metadata = this->CopyString(info->metadata);
if (info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
- field_data.default_value =
-
this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny();
+ field_data.default_value_or_factory =
+
this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value_or_factory))
+ .CopyToTVMFFIAny();
} else {
- field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+ field_data.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
}
entry->type_fields_data.push_back(field_data);
// refresh ptr as the data can change
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index 82d73a2..e77368c 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -18,6 +18,7 @@
* under the License.
*/
#include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/access_path.h>
@@ -101,13 +102,14 @@ TEST(Reflection, FieldInfo) {
EXPECT_EQ(Bytes(info_int->doc).operator std::string(), "");
const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float",
"value");
- EXPECT_EQ(info_float->default_value.v_float64, 10.0);
+ EXPECT_EQ(info_float->default_value_or_factory.v_float64, 10.0);
EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault);
EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable);
EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value
field");
const TVMFFIFieldInfo* info_prim_expr_dtype =
reflection::GetFieldInfo("test.PrimExpr", "dtype");
- AnyView default_value =
AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value);
+ AnyView default_value =
+
AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value_or_factory);
EXPECT_EQ(default_value.cast<String>(), "float");
EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault);
EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable);
@@ -295,6 +297,25 @@ TEST(Reflection, AccessPath) {
EXPECT_FALSE(root_parent.has_value());
}
+struct TestObjWithFactory : public Object {
+ Array<ObjectRef> items;
+ int64_t count;
+
+ explicit TestObjWithFactory(UnsafeInit) {}
+
+ [[maybe_unused]] static constexpr bool _type_mutable = true;
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjWithFactory",
TestObjWithFactory, Object);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TestObjWithFactory>()
+ .def_ro("items", &TestObjWithFactory::items,
+ refl::DefaultFactory(
+ Function::FromTyped([]() -> Array<ObjectRef> { return
Array<ObjectRef>(); })))
+ .def_ro("count", &TestObjWithFactory::count,
refl::DefaultValue(static_cast<int64_t>(0)));
+}
+
struct TestObjWithAny : public Object {
Any value;
explicit TestObjWithAny(Any value) : value(std::move(value)) {}
@@ -350,4 +371,51 @@ TEST(Reflection, InitWithAnyView) {
ASSERT_TRUE(obj3.as<TestObjWithAnyView>() != nullptr);
EXPECT_EQ(obj3.as<TestObjWithAnyView>()->value.cast<String>(), "hello");
}
+TEST(Reflection, DefaultFactoryFlag) {
+ const TVMFFIFieldInfo* info_items =
reflection::GetFieldInfo("test.TestObjWithFactory", "items");
+ EXPECT_TRUE(info_items->flags & kTVMFFIFieldFlagBitMaskHasDefault);
+ EXPECT_TRUE(info_items->flags & kTVMFFIFieldFlagBitMaskDefaultFromFactory);
+
+ const TVMFFIFieldInfo* info_count =
reflection::GetFieldInfo("test.TestObjWithFactory", "count");
+ EXPECT_TRUE(info_count->flags & kTVMFFIFieldFlagBitMaskHasDefault);
+ EXPECT_FALSE(info_count->flags & kTVMFFIFieldFlagBitMaskDefaultFromFactory);
+}
+
+TEST(Reflection, DefaultFactoryCreation) {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectCreator creator("test.TestObjWithFactory");
+
+ // Create two objects without providing "items" - each should get a fresh
Array
+ Any obj1 = creator(Map<String, Any>({{"count", static_cast<int64_t>(42)}}));
+ Any obj2 = creator(Map<String, Any>({{"count", static_cast<int64_t>(99)}}));
+
+ auto* p1 = obj1.as<TestObjWithFactory>();
+ auto* p2 = obj2.as<TestObjWithFactory>();
+
+ ASSERT_NE(p1, nullptr);
+ ASSERT_NE(p2, nullptr);
+ EXPECT_EQ(p1->count, 42);
+ EXPECT_EQ(p2->count, 99);
+ // Both should have empty arrays
+ EXPECT_EQ(p1->items.size(), 0);
+ EXPECT_EQ(p2->items.size(), 0);
+ // Crucially, the arrays should be distinct objects (not aliased)
+ EXPECT_NE(p1->items.get(), p2->items.get());
+}
+
+TEST(Reflection, DefaultFactoryNotCalledWhenProvided) {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectCreator creator("test.TestObjWithFactory");
+
+ Array<ObjectRef> custom_items;
+ custom_items.push_back(TInt(1));
+ Any obj =
+ creator(Map<String, Any>({{"items", custom_items}, {"count",
static_cast<int64_t>(5)}}));
+
+ auto* p = obj.as<TestObjWithFactory>();
+ ASSERT_NE(p, nullptr);
+ EXPECT_EQ(p->items.size(), 1);
+ EXPECT_EQ(p->count, 5);
+}
+
} // namespace