This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e564cd  feat: add `DefaultFactory` support to field reflection (#446)
5e564cd is described below

commit 5e564cdfb932af63915fbeb5a5aa30671f55ae2c
Author: Junru Shao <[email protected]>
AuthorDate: Sat Feb 14 18:14:35 2026 -0800

    feat: add `DefaultFactory` support to field reflection (#446)
    
    `DefaultValue` stores a single static default shared across all
    instances created via reflection. For mutable defaults (Array, Map,
    etc.) this causes aliasing: every object receives the same underlying
    container. `DefaultFactory` fixes this by storing a callable `() -> Any`
    that is invoked each time a default is needed, producing a fresh value
    per instance—mirroring Python dataclass `default_factory`.
    
    Concrete changes:
    
    - Rename `TVMFFIFieldInfo::default_value` → `default_value_or_factory`
    to reflect that the slot now holds either a value or a factory.
    - Add `kTVMFFIFieldFlagBitMaskDefaultFromFactory` (1 << 5) to
    `TVMFFIFieldFlagBitMask`.
    - Add `reflection::DefaultFactory` trait (registry.h), symmetric to
    `DefaultValue`.
    - Add `reflection::SetFieldToDefault` helper (accessor.h) that resolves
    the default—calling the factory when the flag is set—so the three
    consumption sites (creator.h, reflection_extra.cc, serialization.cc)
    share one implementation.
    - Propagate the rename through Rust (`c_api.rs`) and Cython (`base.pxi`)
    bindings.
    - Add `TestObjWithFactory` + three tests exercising flag inspection,
    per-instance freshness, and explicit-value bypass.
---
 include/tvm/ffi/c_api.h               | 20 ++++++++--
 include/tvm/ffi/reflection/accessor.h | 22 +++++++++++
 include/tvm/ffi/reflection/creator.h  |  2 +-
 include/tvm/ffi/reflection/overload.h |  2 +-
 include/tvm/ffi/reflection/registry.h | 33 +++++++++++++++-
 python/tvm_ffi/cython/base.pxi        |  3 +-
 rust/tvm-ffi-sys/src/c_api.rs         |  8 ++--
 src/ffi/extra/reflection_extra.cc     |  3 +-
 src/ffi/extra/serialization.cc        |  2 +-
 src/ffi/object.cc                     |  7 ++--
 tests/cpp/test_reflection.cc          | 72 ++++++++++++++++++++++++++++++++++-
 11 files changed, 156 insertions(+), 18 deletions(-)

diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 5089b6e..67715ca 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -861,6 +861,15 @@ typedef enum {
    * This is an optional meta-data for structural eq/hash.
    */
   kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
+  /*!
+   * \brief The default_value_or_factory is a callable factory function () -> Any.
+   *
+   * When this flag is set along with kTVMFFIFieldFlagBitMaskHasDefault,
+   * the default_value_or_factory field contains a Function that should be
+   * called with no arguments to produce the default value, rather than
+   * being used directly as the default value.
+   */
+  kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5,
 #ifdef __cplusplus
 };
 #else
@@ -960,10 +969,15 @@ typedef struct {
    */
   TVMFFIFieldSetter setter;
   /*!
-   * \brief The default value of the field, this field hold AnyView,
-   *        valid when flags set kTVMFFIFieldFlagBitMaskHasDefault
+   * \brief The default value or default factory of the field.
+   *
+   * When flags has kTVMFFIFieldFlagBitMaskHasDefault set:
+   * - If kTVMFFIFieldFlagBitMaskDefaultFromFactory is NOT set,
+   *   this holds the static default value as AnyView.
+   * - If kTVMFFIFieldFlagBitMaskDefaultFromFactory IS set,
+   *   this holds a Function (() -> Any) that produces the default.
    */
-  TVMFFIAny default_value;
+  TVMFFIAny default_value_or_factory;
   /*!
    * \brief Records the static type kind of the field.
    *
diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h
index b49da51..68c8b0d 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -199,6 +199,28 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) {
   return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
 }
 
+/*!
+ * \brief Set a field to its default value, calling the factory if applicable.
+ *
+ * When kTVMFFIFieldFlagBitMaskDefaultFromFactory is set, extracts the
+ * Function from default_value_or_factory, calls it with no arguments,
+ * and uses the result. Otherwise, passes default_value_or_factory directly
+ * to the setter.
+ *
+ * \param field_info The field info (must have kTVMFFIFieldFlagBitMaskHasDefault set).
+ * \param field_addr The address of the field in the object.
+ */
+inline void SetFieldToDefault(const TVMFFIFieldInfo* field_info, void* field_addr) {
+  if (field_info->flags & kTVMFFIFieldFlagBitMaskDefaultFromFactory) {
+    Function factory =
+        AnyView::CopyFromTVMFFIAny(field_info->default_value_or_factory).cast<Function>();
+    Any default_val = factory();
+    field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&default_val));
+  } else {
+    field_info->setter(field_addr, &(field_info->default_value_or_factory));
+  }
+}
+
 /*!
  * \brief Visit each field info of the type info and run callback.
  *
diff --git a/include/tvm/ffi/reflection/creator.h b/include/tvm/ffi/reflection/creator.h
index 774eb8b..a7e860c 100644
--- a/include/tvm/ffi/reflection/creator.h
+++ b/include/tvm/ffi/reflection/creator.h
@@ -79,7 +79,7 @@ class ObjectCreator {
         field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
         ++match_field_count;
       } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
-        field_info->setter(field_addr, &(field_info->default_value));
+        SetFieldToDefault(field_info, field_addr);
       } else {
         TVM_FFI_THROW(TypeError) << "Required field `"
                                  << String(field_info->name.data, field_info->name.size)
diff --git a/include/tvm/ffi/reflection/overload.h b/include/tvm/ffi/reflection/overload.h
index a85174c..81ced69 100644
--- a/include/tvm/ffi/reflection/overload.h
+++ b/include/tvm/ffi/reflection/overload.h
@@ -449,7 +449,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
     info.getter = ReflectionDefBase::FieldGetter<T>;
     info.setter = ReflectionDefBase::FieldSetter<T>;
     // initialize default value to nullptr
-    info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+    info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
     info.doc = TVMFFIByteArray{nullptr, 0};
     info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
     // apply field info traits
diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h
index cc4ec50..4cd4663 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -152,7 +152,7 @@ class DefaultValue : public InfoTrait {
    * \param info The field info.
    */
   TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
-    info->default_value = AnyView(value_).CopyToTVMFFIAny();
+    info->default_value_or_factory = AnyView(value_).CopyToTVMFFIAny();
     info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
   }
 
@@ -160,6 +160,35 @@ class DefaultValue : public InfoTrait {
   Any value_;
 };
 
+/*!
+ * \brief Trait that can be used to set field default factory.
+ *
+ * A default factory is a callable () -> Any that is invoked each time
+ * a default value is needed, producing a fresh value. This is important
+ * for mutable defaults (e.g., Array, Map) to avoid aliasing.
+ */
+class DefaultFactory : public InfoTrait {
+ public:
+  /*!
+   * \brief Constructor
+   * \param factory The factory function to be called to produce default values.
+   */
+  explicit DefaultFactory(Function factory) : factory_(std::move(factory)) {}
+
+  /*!
+   * \brief Apply the default factory to the field info
+   * \param info The field info.
+   */
+  TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
+    info->default_value_or_factory = AnyView(factory_).CopyToTVMFFIAny();
+    info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
+    info->flags |= kTVMFFIFieldFlagBitMaskDefaultFromFactory;
+  }
+
+ private:
+  Function factory_;
+};
+
 /*!
  * \brief Trait that can be used to attach field flag
  */
@@ -653,7 +682,7 @@ class ObjectDef : public ReflectionDefBase {
     info.getter = FieldGetter<T>;
     info.setter = FieldSetter<T>;
     // initialize default value to nullptr
-    info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+    info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
     info.doc = TVMFFIByteArray{nullptr, 0};
     info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
     // apply field info traits
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index a07e850..4512936 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -203,6 +203,7 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFIFieldFlagBitMaskWritable = 1 << 0
         kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
         kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
+        kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5
 
     ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
     ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept
@@ -218,7 +219,7 @@ cdef extern from "tvm/ffi/c_api.h":
         int64_t offset
         TVMFFIFieldGetter getter
         TVMFFIFieldSetter setter
-        TVMFFIAny default_value
+        TVMFFIAny default_value_or_factory
         int32_t field_static_type_index
 
     ctypedef struct TVMFFIMethodInfo:
diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs
index e0bf085..d94967b 100644
--- a/rust/tvm-ffi-sys/src/c_api.rs
+++ b/rust/tvm-ffi-sys/src/c_api.rs
@@ -286,9 +286,11 @@ pub struct TVMFFIFieldInfo {
     /// The setter to access the field
     /// The setter is set even if the field is readonly for serialization
     pub setter: Option<TVMFFIFieldSetter>,
-    /// The default value of the field, this field hold AnyView,
-    /// valid when flags set kTVMFFIFieldFlagBitMaskHasDefault
-    pub default_value: TVMFFIAny,
+    /// The default value or factory of the field, this field holds AnyView.
+    /// Valid when flags set kTVMFFIFieldFlagBitMaskHasDefault.
+    /// When kTVMFFIFieldFlagBitMaskDefaultFromFactory is also set,
+    /// this is a callable factory function () -> Any.
+    pub default_value_or_factory: TVMFFIAny,
     /// Records the static type kind of the field
     pub field_static_type_index: i32,
 }
diff --git a/src/ffi/extra/reflection_extra.cc b/src/ffi/extra/reflection_extra.cc
index 02d422d..6699416 100644
--- a/src/ffi/extra/reflection_extra.cc
+++ b/src/ffi/extra/reflection_extra.cc
@@ -22,6 +22,7 @@
  * \brief Extra reflection registrations. *
  */
 #include <tvm/ffi/reflection/access_path.h>
+#include <tvm/ffi/reflection/accessor.h>
 #include <tvm/ffi/reflection/registry.h>
 
 namespace tvm {
@@ -78,7 +79,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
         field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
         keys_found[arg_index] = true;
       } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
-        field_info->setter(field_addr, &(field_info->default_value));
+        reflection::SetFieldToDefault(field_info, field_addr);
       } else {
         TVM_FFI_THROW(TypeError) << "Required field `"
                                  << String(field_info->name.data, field_info->name.size)
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index 1e5e98c..f639d87 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -396,7 +396,7 @@ class ObjectGraphDeserializer {
         Any field_value = decode_field_value(field_info, data_object[field_name]);
         field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
       } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
-        field_info->setter(field_addr, &(field_info->default_value));
+        reflection::SetFieldToDefault(field_info, field_addr);
       } else {
         TVM_FFI_THROW(TypeError) << "Required field `"
                                  << String(field_info->name.data, field_info->name.size)
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 8c5c137..7e5e00e 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -211,10 +211,11 @@ class TypeTable {
     field_data.doc = this->CopyString(info->doc);
     field_data.metadata = this->CopyString(info->metadata);
     if (info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
-      field_data.default_value =
-          this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny();
+      field_data.default_value_or_factory =
+          this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value_or_factory))
+              .CopyToTVMFFIAny();
     } else {
-      field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+      field_data.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
     }
     entry->type_fields_data.push_back(field_data);
     // refresh ptr as the data can change
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index 82d73a2..e77368c 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -18,6 +18,7 @@
  * under the License.
  */
 #include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection/access_path.h>
@@ -101,13 +102,14 @@ TEST(Reflection, FieldInfo) {
   EXPECT_EQ(Bytes(info_int->doc).operator std::string(), "");
 
   const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value");
-  EXPECT_EQ(info_float->default_value.v_float64, 10.0);
+  EXPECT_EQ(info_float->default_value_or_factory.v_float64, 10.0);
   EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault);
   EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable);
   EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field");
 
   const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype");
-  AnyView default_value = AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value);
+  AnyView default_value =
+      AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value_or_factory);
   EXPECT_EQ(default_value.cast<String>(), "float");
   EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault);
   EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable);
@@ -295,6 +297,25 @@ TEST(Reflection, AccessPath) {
   EXPECT_FALSE(root_parent.has_value());
 }
 
+struct TestObjWithFactory : public Object {
+  Array<ObjectRef> items;
+  int64_t count;
+
+  explicit TestObjWithFactory(UnsafeInit) {}
+
+  [[maybe_unused]] static constexpr bool _type_mutable = true;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjWithFactory", TestObjWithFactory, Object);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectDef<TestObjWithFactory>()
+      .def_ro("items", &TestObjWithFactory::items,
+              refl::DefaultFactory(
+                  Function::FromTyped([]() -> Array<ObjectRef> { return Array<ObjectRef>(); })))
+      .def_ro("count", &TestObjWithFactory::count, refl::DefaultValue(static_cast<int64_t>(0)));
+}
+
 struct TestObjWithAny : public Object {
   Any value;
   explicit TestObjWithAny(Any value) : value(std::move(value)) {}
@@ -350,4 +371,51 @@ TEST(Reflection, InitWithAnyView) {
   ASSERT_TRUE(obj3.as<TestObjWithAnyView>() != nullptr);
   EXPECT_EQ(obj3.as<TestObjWithAnyView>()->value.cast<String>(), "hello");
 }
+TEST(Reflection, DefaultFactoryFlag) {
+  const TVMFFIFieldInfo* info_items = reflection::GetFieldInfo("test.TestObjWithFactory", "items");
+  EXPECT_TRUE(info_items->flags & kTVMFFIFieldFlagBitMaskHasDefault);
+  EXPECT_TRUE(info_items->flags & kTVMFFIFieldFlagBitMaskDefaultFromFactory);
+
+  const TVMFFIFieldInfo* info_count = reflection::GetFieldInfo("test.TestObjWithFactory", "count");
+  EXPECT_TRUE(info_count->flags & kTVMFFIFieldFlagBitMaskHasDefault);
+  EXPECT_FALSE(info_count->flags & kTVMFFIFieldFlagBitMaskDefaultFromFactory);
+}
+
+TEST(Reflection, DefaultFactoryCreation) {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectCreator creator("test.TestObjWithFactory");
+
+  // Create two objects without providing "items" - each should get a fresh Array
+  Any obj1 = creator(Map<String, Any>({{"count", static_cast<int64_t>(42)}}));
+  Any obj2 = creator(Map<String, Any>({{"count", static_cast<int64_t>(99)}}));
+
+  auto* p1 = obj1.as<TestObjWithFactory>();
+  auto* p2 = obj2.as<TestObjWithFactory>();
+
+  ASSERT_NE(p1, nullptr);
+  ASSERT_NE(p2, nullptr);
+  EXPECT_EQ(p1->count, 42);
+  EXPECT_EQ(p2->count, 99);
+  // Both should have empty arrays
+  EXPECT_EQ(p1->items.size(), 0);
+  EXPECT_EQ(p2->items.size(), 0);
+  // Crucially, the arrays should be distinct objects (not aliased)
+  EXPECT_NE(p1->items.get(), p2->items.get());
+}
+
+TEST(Reflection, DefaultFactoryNotCalledWhenProvided) {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectCreator creator("test.TestObjWithFactory");
+
+  Array<ObjectRef> custom_items;
+  custom_items.push_back(TInt(1));
+  Any obj =
+      creator(Map<String, Any>({{"items", custom_items}, {"count", static_cast<int64_t>(5)}}));
+
+  auto* p = obj.as<TestObjWithFactory>();
+  ASSERT_NE(p, nullptr);
+  EXPECT_EQ(p->items.size(), 1);
+  EXPECT_EQ(p->count, 5);
+}
+
 }  // namespace

Reply via email to