This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c7c5150d2 [FFI] Introduce FFI reflection support in python (#18065)
3c7c5150d2 is described below

commit 3c7c5150d2285a8899c5ecd1a71400b3955de6ce
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jun 17 08:25:53 2025 -0400

    [FFI] Introduce FFI reflection support in python (#18065)
    
    This PR brings up new reflection support in Python.
    The new reflection directly attaches properties
    and methods to the class objects themselves, making
    attribute access more efficient than the old mechanism.
    
    It will also support a broader set of value types
    that are compatible with the FFI system.
    
    For now the old and new mechanisms will co-exist,
    and we will phase out the old mechanism as we migrate
    the most-needed features into the new one.
---
 ffi/include/tvm/ffi/memory.h                |   8 +-
 ffi/include/tvm/ffi/reflection/reflection.h | 159 +++++++++++++++++-----------
 ffi/include/tvm/ffi/string.h                |  12 +++
 ffi/src/ffi/object.cc                       |  77 ++++++++++++++
 ffi/src/ffi/testing.cc                      |  42 ++++++++
 ffi/tests/cpp/test_reflection.cc            |  10 +-
 python/tvm/ffi/__init__.py                  |   1 +
 python/tvm/ffi/cython/base.pxi              |  47 ++++++++
 python/tvm/ffi/cython/error.pxi             |   1 +
 python/tvm/ffi/cython/function.pxi          |  95 +++++++++++++++++
 python/tvm/ffi/cython/ndarray.pxi           |   2 +-
 python/tvm/ffi/registry.py                  |   1 +
 python/tvm/ffi/testing.py                   |  63 +++++++++++
 tests/python/ffi/test_object.py             |  70 ++++++++++++
 14 files changed, 519 insertions(+), 69 deletions(-)

diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index eb317d2bbd..02537df79c 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -70,7 +70,7 @@ class ObjAllocatorBase {
    * \param args The arguments.
    */
   template <typename T, typename... Args>
-  inline ObjectPtr<T> make_object(Args&&... args) {
+  ObjectPtr<T> make_object(Args&&... args) {
     using Handler = typename Derived::template Handler<T>;
     static_assert(std::is_base_of<Object, T>::value, "make can only be used to 
create Object");
     T* ptr = Handler::New(static_cast<Derived*>(this), 
std::forward<Args>(args)...);
@@ -89,7 +89,7 @@ class ObjAllocatorBase {
    * \param args The arguments.
    */
   template <typename ArrayType, typename ElemType, typename... Args>
-  inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... 
args) {
+  ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
     using Handler = typename Derived::template ArrayHandler<ArrayType, 
ElemType>;
     static_assert(std::is_base_of<Object, ArrayType>::value,
                   "make_inplace_array can only be used to create Object");
@@ -109,7 +109,9 @@ class SimpleObjAllocator : public 
ObjAllocatorBase<SimpleObjAllocator> {
   template <typename T>
   class Handler {
    public:
-    using StorageType = typename std::aligned_storage<sizeof(T), 
alignof(T)>::type;
+    struct alignas(T) StorageType {
+      char data[sizeof(T)];
+    };
 
     template <typename... Args>
     static T* New(SimpleObjAllocator*, Args&&... args) {
diff --git a/ffi/include/tvm/ffi/reflection/reflection.h 
b/ffi/include/tvm/ffi/reflection/reflection.h
index bd2f5cb9c7..6187a74825 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -46,7 +46,7 @@ class DefaultValue : public FieldInfoTrait {
  public:
   explicit DefaultValue(Any value) : value_(value) {}
 
-  void Apply(TVMFFIFieldInfo* info) const {
+  TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
     info->default_value = AnyView(value_).CopyToTVMFFIAny();
     info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
   }
@@ -65,16 +65,89 @@ class DefaultValue : public FieldInfoTrait {
  * \returns The byteoffset
  */
 template <typename Class, typename T>
-inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
+TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
   int64_t field_offset_to_class =
       reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
   return field_offset_to_class - 
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
 }
 
+/*!
+ * \brief Shared, class-agnostic helpers for reflection definition builders (e.g. ObjectDef).
+ *
+ * Provides the C-ABI field getter/setter thunks, the default object creator,
+ * the appliers that fill TVMFFIFieldInfo / TVMFFIMethodInfo / TVMFFITypeExtraInfo
+ * from extra definition arguments, and adapters turning member functions into Functions.
+ */
+class ReflectionDefBase {
+ protected:
+  /*! \brief C-ABI field getter: copies the T stored at \p field into \p result. */
+  template <typename T>
+  static int FieldGetter(void* field, TVMFFIAny* result) {
+    TVM_FFI_SAFE_CALL_BEGIN();
+    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
+    TVM_FFI_SAFE_CALL_END();
+  }
+
+  /*! \brief C-ABI field setter: casts \p value to T and assigns it to the field at \p field. */
+  template <typename T>
+  static int FieldSetter(void* field, const TVMFFIAny* value) {
+    TVM_FFI_SAFE_CALL_BEGIN();
+    *reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
+    TVM_FFI_SAFE_CALL_END();
+  }
+
+  /*! \brief C-ABI creator: default-constructs a T and hands the owned pointer to \p result. */
+  template <typename T>
+  static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
+    TVM_FFI_SAFE_CALL_BEGIN();
+    ObjectPtr<T> obj = make_object<T>();
+    *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
+    TVM_FFI_SAFE_CALL_END();
+  }
+
+  /*!
+   * \brief Whether an extra definition argument of type T is a doc string.
+   *
+   * A string literal deduced through `const T&` gives `T = char[N]`, which decays to
+   * `char*`; a named `const char*` variable gives `T = const char*`. Both are accepted
+   * (previously only the former was, so `const char*` docs were silently dropped).
+   */
+  template <typename T>
+  static constexpr bool IsDocString() {
+    using U = std::decay_t<T>;
+    return std::is_same_v<U, char*> || std::is_same_v<U, const char*>;
+  }
+
+  /*! \brief View a NUL-terminated doc string as a TVMFFIByteArray (empty when null). */
+  static TVM_FFI_INLINE TVMFFIByteArray MakeDocByteArray(const char* doc) {
+    if (doc == nullptr) return TVMFFIByteArray{nullptr, 0};
+    return TVMFFIByteArray{doc, std::char_traits<char>::length(doc)};
+  }
+
+  /*! \brief Apply one extra def_ro/def_rw argument (FieldInfoTrait or doc string) to \p info. */
+  template <typename T>
+  static TVM_FFI_INLINE void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
+    if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
+      value.Apply(info);
+    }
+    if constexpr (IsDocString<T>()) {
+      info->doc = MakeDocByteArray(value);
+    }
+  }
+
+  /*! \brief Apply one extra def/def_static argument (doc string) to \p info. */
+  template <typename T>
+  static TVM_FFI_INLINE void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
+    if constexpr (IsDocString<T>()) {
+      info->doc = MakeDocByteArray(value);
+    }
+  }
+
+  /*! \brief Apply one extra ObjectDef constructor argument (doc string) to \p info. */
+  template <typename T>
+  static TVM_FFI_INLINE void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) {
+    if constexpr (IsDocString<T>()) {
+      info->doc = MakeDocByteArray(value);
+    }
+  }
+
+  /*!
+   * \brief Wrap a non-const member function; the receiver is passed as `const Class*`
+   *        and const_cast back, so this is only meant for classes that allow mutation.
+   */
+  template <typename Class, typename R, typename... Args>
+  static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) {
+    auto fwrap = [func](const Class* target, Args... params) -> R {
+      return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
+    };
+    return ffi::Function::FromTyped(fwrap, name);
+  }
+
+  /*! \brief Wrap a const member function, taking the receiver as `const Class*`. */
+  template <typename Class, typename R, typename... Args>
+  static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
+    auto fwrap = [func](const Class* target, Args... params) -> R {
+      return (target->*func)(std::forward<Args>(params)...);
+    };
+    return ffi::Function::FromTyped(fwrap, name);
+  }
+
+  /*! \brief Wrap any other callable (lambda, free function) directly. */
+  template <typename Class, typename Func>
+  static TVM_FFI_INLINE Function GetMethod(std::string name, Func&& func) {
+    return ffi::Function::FromTyped(std::forward<Func>(func), name);
+  }
+};
+
 template <typename Class>
-class ObjectDef {
+class ObjectDef : public ReflectionDefBase {
  public:
-  ObjectDef() : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), 
type_key_(Class::_type_key) {}
+  template <typename... ExtraArgs>
+  explicit ObjectDef(ExtraArgs&&... extra_args)
+      : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), 
type_key_(Class::_type_key) {
+    RegisterExtraInfo(std::forward<ExtraArgs>(extra_args)...);
+  }
 
   /*!
    * \brief Define a readonly field.
@@ -90,7 +163,7 @@ class ObjectDef {
    * \return The reflection definition.
    */
   template <typename T, typename... Extra>
-  ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
+  TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, 
Extra&&... extra) {
     RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
     return *this;
   }
@@ -109,7 +182,8 @@ class ObjectDef {
    * \return The reflection definition.
    */
   template <typename T, typename... Extra>
-  ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
+  TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, 
Extra&&... extra) {
+    static_assert(Class::_type_mutable, "Only mutable classes are supported 
for writable fields");
     RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
     return *this;
   }
@@ -127,7 +201,7 @@ class ObjectDef {
    * \return The reflection definition.
    */
   template <typename Func, typename... Extra>
-  ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
+  TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... 
extra) {
     RegisterMethod(name, false, std::forward<Func>(func), 
std::forward<Extra>(extra)...);
     return *this;
   }
@@ -145,12 +219,26 @@ class ObjectDef {
    * \return The reflection definition.
    */
   template <typename Func, typename... Extra>
-  ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
+  TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, 
Extra&&... extra) {
     RegisterMethod(name, true, std::forward<Func>(func), 
std::forward<Extra>(extra)...);
     return *this;
   }
 
  private:
+  /*!
+   * \brief Register the per-type extra info (total size, creator, doc) for Class.
+   * \param extra_args Extra ObjectDef constructor arguments; doc-string arguments
+   *        are applied through ApplyExtraInfoTrait.
+   */
+  template <typename... ExtraArgs>
+  void RegisterExtraInfo(ExtraArgs&&... extra_args) {
+    TVMFFITypeExtraInfo info;
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    info.total_size = sizeof(Class);
+    // Reflection creation is only available for default-constructible classes.
+    if constexpr (std::is_default_constructible_v<Class>) {
+      info.creator = ObjectCreatorDefault<Class>;
+    } else {
+      info.creator = nullptr;
+    }
+    (ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...);
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
+  }
+
   template <typename T, typename... ExtraArgs>
   void RegisterField(const char* name, T Class::*field_ptr, bool writable,
                      ExtraArgs&&... extra_args) {
@@ -178,30 +266,6 @@ class ObjectDef {
     TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
   }
 
-  template <typename T>
-  static int FieldGetter(void* field, TVMFFIAny* result) {
-    TVM_FFI_SAFE_CALL_BEGIN();
-    *result = 
details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
-    TVM_FFI_SAFE_CALL_END();
-  }
-
-  template <typename T>
-  static int FieldSetter(void* field, const TVMFFIAny* value) {
-    TVM_FFI_SAFE_CALL_BEGIN();
-    *reinterpret_cast<T*>(field) = 
AnyView::CopyFromTVMFFIAny(*value).cast<T>();
-    TVM_FFI_SAFE_CALL_END();
-  }
-
-  template <typename T>
-  static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
-    if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
-      value.Apply(info);
-    }
-    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
-      info->doc = TVMFFIByteArray{value, 
std::char_traits<char>::length(value)};
-    }
-  }
-
   // register a method
   template <typename Func, typename... Extra>
   void RegisterMethod(const char* name, bool is_static, Func&& func, 
Extra&&... extra) {
@@ -214,41 +278,14 @@ class ObjectDef {
       info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
     }
     // obtain the method function
-    Function method = GetMethod(std::string(type_key_) + "." + name, 
std::forward<Func>(func));
+    Function method =
+        GetMethod<Class>(std::string(type_key_) + "." + name, 
std::forward<Func>(func));
     info.method = AnyView(method).CopyToTVMFFIAny();
     // apply method info traits
     ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
     TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
   }
 
-  template <typename T>
-  static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
-    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
-      info->doc = TVMFFIByteArray{value, 
std::char_traits<char>::length(value)};
-    }
-  }
-
-  template <typename R, typename... Args>
-  static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
-    auto fwrap = [func](const Class* target, Args... params) -> R {
-      return 
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
-    };
-    return ffi::Function::FromTyped(fwrap, name);
-  }
-
-  template <typename R, typename... Args>
-  static Function GetMethod(std::string name, R (Class::*func)(Args...) const) 
{
-    auto fwrap = [func](const Class* target, Args... params) -> R {
-      return (target->*func)(std::forward<Args>(params)...);
-    };
-    return ffi::Function::FromTyped(fwrap, name);
-  }
-
-  template <typename Func>
-  static Function GetMethod(std::string name, Func&& func) {
-    return ffi::Function::FromTyped(std::forward<Func>(func), name);
-  }
-
   int32_t type_index_;
   const char* type_key_;
 };
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index 19df2e8e3d..dee2d89c08 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -306,6 +306,18 @@ class String : public ObjectRef {
     return Bytes::memncmp(data(), other, size(), std::strlen(other));
   }
 
+  /*!
+   * \brief Three-way comparison against a raw byte array.
+   *
+   * \param other The TVMFFIByteArray to compare with.
+   *
+   * \return zero if both byte sequences are equal, a negative value if this
+   * string orders before \p other, and a positive value otherwise.
+   */
+  int compare(const TVMFFIByteArray& other) const {
+    const char* rhs_data = other.data;
+    const auto rhs_size = other.size;
+    return Bytes::memncmp(data(), rhs_data, size(), rhs_size);
+  }
+
   /*!
    * \brief Returns a pointer to the char array in the string.
    *
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 793d3e2728..fa77e2b264 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -315,6 +315,83 @@ class TypeTable {
   Map<String, int64_t> type_key2index_;
   std::vector<Any> any_pool_;
 };
+
+/*!
+ * \brief Create an object of a registered type through reflection.
+ *
+ * Packed layout: (type_key, name0, value0, name1, value1, ...).
+ * The object is built with the type's registered creator. Then every reflected field
+ * (ancestors first, in parent-to-child order, then the type itself) is set either from
+ * the matching keyword argument or from its registered default value.
+ *
+ * \param args The packed arguments described above.
+ * \param ret Receives the created object.
+ * \throws RuntimeError if the type is unknown or has no reflection creator.
+ * \throws TypeError if a required field is missing or an unknown field name is given.
+ */
+void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
+  // Validate the arity before reading args[0], so an empty call is never indexed.
+  TVM_FFI_ICHECK(args.size() % 2 == 1);
+  String type_key = args[0].cast<String>();
+
+  int32_t type_index;
+  TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()};
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
+  const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
+  if (type_info == nullptr) {
+    TVM_FFI_THROW(RuntimeError) << "Cannot find type `" << type_key << "`";
+  }
+
+  if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) {
+    TVM_FFI_THROW(RuntimeError) << "Type `" << type_key << "` does not support reflection creation";
+  }
+  TVMFFIObjectHandle handle;
+  TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle));
+  ObjectPtr<Object> ptr =
+      details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+
+  // keys[k] is the name at args[2 * k + 1]; its value lives at args[2 * k + 2].
+  std::vector<String> keys;
+  keys.reserve(args.size() / 2);
+  for (int i = 1; i < args.size(); i += 2) {
+    keys.push_back(args[i].cast<String>());
+  }
+  std::vector<bool> keys_found(keys.size(), false);
+
+  // Index of the first not-yet-consumed key equal to field_name, or keys.size() if none.
+  auto search_field = [&](const TVMFFIByteArray& field_name) {
+    for (size_t i = 0; i < keys.size(); ++i) {
+      if (keys_found[i]) continue;
+      if (keys[i].compare(field_name) == 0) {
+        return i;
+      }
+    }
+    return keys.size();
+  };
+
+  // Set every field declared directly on tinfo from a kwarg or its default.
+  auto update_fields = [&](const TVMFFITypeInfo* tinfo) {
+    for (int i = 0; i < tinfo->num_fields; ++i) {
+      const TVMFFIFieldInfo* field_info = tinfo->fields + i;
+      size_t arg_index = search_field(field_info->name);
+      void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
+      if (arg_index < keys.size()) {
+        AnyView field_value = args[arg_index * 2 + 2];
+        field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
+        keys_found[arg_index] = true;
+      } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
+        field_info->setter(field_addr, &(field_info->default_value));
+      } else {
+        TVM_FFI_THROW(TypeError) << "Required field `"
+                                 << String(field_info->name.data, field_info->name.size)
+                                 << "` not set in type `" << type_key << "`";
+      }
+    }
+  };
+
+  // Iterate through ancestors in parent-to-child order, skipping the first one
+  // since it is always the root object.
+  TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject);
+  for (int i = 1; i < type_info->type_depth; ++i) {
+    update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i]));
+  }
+  update_fields(type_info);
+
+  // Any kwarg not consumed above names a field the type does not have.
+  for (size_t i = 0; i < keys.size(); ++i) {
+    if (!keys_found[i]) {
+      TVM_FFI_THROW(TypeError) << "Type `" << type_key << "` does not have field `" << keys[i]
+                               << "`";
+    }
+  }
+  *ret = ObjectRef(ptr);
+}
+
+// Exposed as ffi.MakeObjectFromPackedArgs (used by tvm.ffi.testing.create_object).
+TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs);
+
 }  // namespace ffi
 }  // namespace tvm
 
diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/testing.cc
index 050ac28c47..6bc7968eab 100644
--- a/ffi/src/ffi/testing.cc
+++ b/ffi/src/ffi/testing.cc
@@ -17,7 +17,10 @@
  * under the License.
  */
 // This file is used for testing the FFI API.
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
 #include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
 
 #include <chrono>
 #include <iostream>
@@ -26,6 +29,45 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief Mutable reflection test object with scalar fields and one method.
+ *
+ * Registered as "testing.TestObjectBase" and exercised by the Python reflection tests.
+ */
+class TestObjectBase : public Object {
+ public:
+  int64_t v_i64;
+  double v_f64;
+  String v_str;
+
+  /*! \brief Return v_i64 + other. */
+  int64_t AddI64(int64_t other) const { return v_i64 + other; }
+
+  // Mutable so def_rw fields can be registered (ObjectDef::def_rw static_asserts this).
+  static constexpr bool _type_mutable = true;
+  // NOTE(review): presumably reserves one type-index slot for the single direct
+  // subclass (TestObjectDerived); confirm against TVM_FFI_DECLARE_BASE_OBJECT_INFO.
+  static constexpr uint32_t _type_child_slots = 1;
+  static constexpr const char* _type_key = "testing.TestObjectBase";
+  TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjectBase, Object);
+};
+
+/*!
+ * \brief Final reflection test object adding container fields to TestObjectBase.
+ *
+ * Registered as "testing.TestObjectDerived". v_map and v_array are registered
+ * without defaults, so reflection-based creation must supply them.
+ */
+class TestObjectDerived : public TestObjectBase {
+ public:
+  Map<Any, Any> v_map;
+  Array<Any> v_array;
+
+  static constexpr const char* _type_key = "testing.TestObjectDerived";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjectDerived, TestObjectBase);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+
+  // TestObjectBase: writable v_i64/v_str, read-only v_f64, all defaulted, plus add_i64.
+  refl::ObjectDef<TestObjectBase> base_def;
+  base_def.def_rw("v_i64", &TestObjectBase::v_i64, refl::DefaultValue(10), "i64 field");
+  base_def.def_ro("v_f64", &TestObjectBase::v_f64, refl::DefaultValue(10.0));
+  base_def.def_rw("v_str", &TestObjectBase::v_str, refl::DefaultValue("hello"));
+  base_def.def("add_i64", &TestObjectBase::AddI64, "add_i64 method");
+
+  // TestObjectDerived: read-only container fields with no defaults.
+  refl::ObjectDef<TestObjectDerived> derived_def;
+  derived_def.def_ro("v_map", &TestObjectDerived::v_map);
+  derived_def.def_ro("v_array", &TestObjectDerived::v_array);
+});
+
 void TestRaiseError(String kind, String msg) {
   throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE);
 }
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index 64b3a6f590..450cb9dbcb 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -32,13 +32,15 @@ using namespace tvm::ffi::testing;
 struct A : public Object {
   int64_t x;
   int64_t y;
+
+  static constexpr bool _type_mutable = true;
 };
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
 
   refl::ObjectDef<TFloatObj>()
-      .def_rw("value", &TFloatObj::value, "float value field", 
refl::DefaultValue(10.0))
+      .def_ro("value", &TFloatObj::value, "float value field", 
refl::DefaultValue(10.0))
       .def("sub", [](const TFloatObj* self, double other) -> double { return 
self->value - other; })
       .def("add", &TFloatObj::Add, "add method");
 
@@ -47,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
       .def_static("static_add", &TInt::StaticAdd, "static add method");
 
   refl::ObjectDef<TPrimExprObj>()
-      .def_ro("dtype", &TPrimExprObj::dtype, "dtype field", 
refl::DefaultValue("float"))
+      .def_rw("dtype", &TPrimExprObj::dtype, "dtype field", 
refl::DefaultValue("float"))
       .def_ro("value", &TPrimExprObj::value, "value field", 
refl::DefaultValue(0))
       .def("sub", [](TPrimExprObj* self, double other) -> double {
         // this is ok because TPrimExprObj is declared asmutable
@@ -89,7 +91,7 @@ TEST(Reflection, FieldInfo) {
   const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", 
"value");
   EXPECT_EQ(info_float->default_value.v_float64, 10.0);
   EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault);
-  EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable);
+  EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable);
   EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value 
field");
 
   const TVMFFIFieldInfo* info_prim_expr_dtype = 
reflection::GetFieldInfo("test.PrimExpr", "dtype");
@@ -97,7 +99,7 @@ TEST(Reflection, FieldInfo) {
   EXPECT_EQ(default_value.cast<String>(), "float");
   EXPECT_EQ(default_value.as<String>().value().use_count(), 2);
   EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault);
-  EXPECT_FALSE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable);
+  EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable);
   EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype 
field");
 }
 
diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py
index 0a8b223405..b507064e34 100644
--- a/python/tvm/ffi/__init__.py
+++ b/python/tvm/ffi/__init__.py
@@ -30,6 +30,7 @@ from .ndarray import Device, device
 from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, 
hexagon, webgpu
 from .ndarray import from_dlpack, NDArray, Shape
 from .container import Array, Map
+from . import testing
 
 
 __all__ = [
diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index e18d52fc8d..50831be462 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -134,6 +134,52 @@ cdef extern from "tvm/ffi/c_api.h":
         void* handle, const TVMFFIAny* args, int32_t num_args,
         TVMFFIAny* result) noexcept
 
+    cdef enum TVMFFIFieldFlagBitMask:
+        kTVMFFIFieldFlagBitMaskWritable = 1 << 0
+        kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
+        kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
+
+    ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept;
+    ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) 
noexcept;
+    ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept;
+
+    ctypedef struct TVMFFIFieldInfo:
+        TVMFFIByteArray name
+        TVMFFIByteArray doc
+        TVMFFIByteArray type_schema
+        int64_t flags
+        int64_t size
+        int64_t alignment
+        int64_t offset
+        TVMFFIFieldGetter getter
+        TVMFFIFieldSetter setter
+        TVMFFIAny default_value
+        int32_t field_static_type_index
+
+    ctypedef struct TVMFFIMethodInfo:
+        TVMFFIByteArray name
+        TVMFFIByteArray doc
+        TVMFFIByteArray type_schema
+        int64_t flags
+        TVMFFIAny method
+
+    ctypedef struct TVMFFITypeExtraInfo:
+        TVMFFIByteArray doc
+        TVMFFIObjectCreator creator
+        int64_t total_size
+
+    ctypedef struct TVMFFITypeInfo:
+        int32_t type_index
+        int32_t type_depth
+        TVMFFIByteArray type_key
+        const int32_t* type_acenstors
+        uint64_t type_key_hash
+        int32_t num_fields
+        int32_t num_methods
+        const TVMFFIFieldInfo* fields
+        const TVMFFIMethodInfo* methods
+        const TVMFFITypeExtraInfo* extra_info
+
     int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil
     int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil
     int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t 
num_args,
@@ -161,6 +207,7 @@ cdef extern from "tvm/ffi/c_api.h":
     int TVMFFINDArrayToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) 
nogil
     int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src,
                                         DLManagedTensorVersioned** out) nogil
+    const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil
     TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
     TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil
     TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil
diff --git a/python/tvm/ffi/cython/error.pxi b/python/tvm/ffi/cython/error.pxi
index 3a19573b8f..8da630873e 100644
--- a/python/tvm/ffi/cython/error.pxi
+++ b/python/tvm/ffi/cython/error.pxi
@@ -113,6 +113,7 @@ cdef class Error(Object):
     def traceback(self):
         return 
bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).traceback))
 
+
 _register_object_by_index(kTVMFFIError, Error)
 
 
diff --git a/python/tvm/ffi/cython/function.pxi 
b/python/tvm/ffi/cython/function.pxi
index 294a1246b2..640fff7af5 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -230,6 +230,101 @@ class Function(Object):
 _register_object_by_index(kTVMFFIFunction, Function)
 
 
+cdef class FieldGetter:
+    # Property getter for one reflected field: calls the C getter on the
+    # field located `offset` bytes past the object's handle.
+    # (No class docstring on purpose: property() would inherit it as the
+    # doc of every undocumented field.)
+    cdef TVMFFIFieldGetter getter
+    cdef int64_t offset
+
+    def __call__(self, Object obj):
+        cdef int c_api_ret_code
+        cdef TVMFFIAny result
+        cdef char* base_ptr = <char*>obj.chandle
+        cdef void* field_ptr = base_ptr + self.offset
+        # initialize the result slot to None before the getter writes it
+        result.type_index = kTVMFFINone
+        result.v_int64 = 0
+        c_api_ret_code = self.getter(field_ptr, &result)
+        CHECK_CALL(c_api_ret_code)
+        return make_ret(result)
+
+
+cdef class FieldSetter:
+    # Property setter for one writable reflected field: converts the Python
+    # value to a TVMFFIAny and calls the C setter on the field located
+    # `offset` bytes past the object's handle.
+    cdef TVMFFIFieldSetter setter
+    cdef int64_t offset
+
+    def __call__(self, Object obj, value):
+        cdef TVMFFIAny[1] packed_args
+        cdef int c_api_ret_code
+        cdef void* field_ptr = (<char*>obj.chandle) + self.offset
+        # temp_args collects temporaries produced by make_args
+        # (presumably to keep them alive until the setter returns)
+        temp_args = []
+        make_args((value,), &packed_args[0], temp_args)
+        c_api_ret_code = self.setter(field_ptr, &packed_args[0])
+        # NOTE: logic is same as check_call
+        # directly inline here to simplify traceback
+        if c_api_ret_code == 0:
+            return
+        elif c_api_ret_code == -2:
+            raise_existing_error()
+        raise move_from_last_error().py_error()
+
+
+cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
+    """Return a Python object for the function stored in ``method.method``."""
+    cdef TVMFFIAny owned
+    cdef int ret_code = TVMFFIAnyViewToOwnedAny(&(method.method), &owned)
+    CHECK_CALL(ret_code)
+    return make_ret(owned)
+
+
+def _make_method_wrapper(method_func, bint is_static):
+    """Wrap an FFI function as a plain Python function for use as a class attribute.
+
+    Each call creates a fresh closure over ``method_func``. A ``def`` written
+    directly inside the registration loop would instead capture the loop
+    variable itself, so every method would end up calling the last one.
+    """
+    if is_static:
+        def static_wrapper(*args):
+            return method_func(*args)
+        return static_wrapper
+
+    def member_wrapper(self, *args):
+        return method_func(self, *args)
+    return member_wrapper
+
+
+def _add_class_attrs_by_reflection(int type_index, object cls):
+    """Attach the reflected fields and methods of ``type_index`` to ``cls``.
+
+    Fields become properties (read-only fields get no setter); methods become
+    instance methods or staticmethods according to their flags.
+
+    Parameters
+    ----------
+    type_index : int
+        The runtime type index whose reflection info is read.
+    cls : type
+        The Python class to decorate in place.
+
+    Returns
+    -------
+    cls : type
+        The same class.
+    """
+    cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index)
+    cdef const TVMFFIFieldInfo* field
+    cdef const TVMFFIMethodInfo* method
+    cdef int i
+
+    if info == NULL:
+        # no type info for this index: nothing to attach
+        return cls
+
+    for i in range(info.num_fields):
+        field = &(info.fields[i])
+        getter = FieldGetter.__new__(FieldGetter)
+        (<FieldGetter>getter).getter = field.getter
+        (<FieldGetter>getter).offset = field.offset
+        if field.flags & kTVMFFIFieldFlagBitMaskWritable:
+            setter = FieldSetter.__new__(FieldSetter)
+            (<FieldSetter>setter).setter = field.setter
+            (<FieldSetter>setter).offset = field.offset
+        else:
+            setter = None
+        doc = (
+            py_str(PyBytes_FromStringAndSize(field.doc.data, field.doc.size))
+            if field.doc.size != 0
+            else None
+        )
+        name = py_str(PyBytes_FromStringAndSize(field.name.data, field.name.size))
+        setattr(cls, name, property(getter, setter, doc=doc))
+
+    for i in range(info.num_methods):
+        method = &(info.methods[i])
+        name = py_str(PyBytes_FromStringAndSize(method.name.data, method.name.size))
+        doc = (
+            py_str(PyBytes_FromStringAndSize(method.doc.data, method.doc.size))
+            if method.doc.size != 0
+            else None
+        )
+        is_static = (method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0
+        method_pyfunc = _make_method_wrapper(_get_method_from_method_info(method), is_static)
+        # always name the wrapper after the reflected method, documented or not
+        method_pyfunc.__name__ = name
+        if doc is not None:
+            method_pyfunc.__doc__ = doc
+        setattr(cls, name, staticmethod(method_pyfunc) if is_static else method_pyfunc)
+
+    return cls
+
+
 def _register_global_func(name, pyfunc, override):
     cdef TVMFFIObjectHandle chandle
     cdef int c_api_ret_code
diff --git a/python/tvm/ffi/cython/ndarray.pxi 
b/python/tvm/ffi/cython/ndarray.pxi
index b8534b41b3..9dfe1222dc 100644
--- a/python/tvm/ffi/cython/ndarray.pxi
+++ b/python/tvm/ffi/cython/ndarray.pxi
@@ -23,7 +23,6 @@ _CLASS_NDARRAY = None
 def _set_class_ndarray(cls):
     global _CLASS_NDARRAY
     _CLASS_NDARRAY = cls
-    _register_object_by_index(kTVMFFINDArray, cls)
 
 
 cdef const char* _c_str_dltensor = "dltensor"
@@ -268,6 +267,7 @@ cdef class NDArray(Object):
 
 
 _set_class_ndarray(NDArray)
+_register_object_by_index(kTVMFFINDArray, NDArray)
 
 
 cdef inline object make_ret_dltensor(TVMFFIAny result):
diff --git a/python/tvm/ffi/registry.py b/python/tvm/ffi/registry.py
index 58df08d90c..9302b25173 100644
--- a/python/tvm/ffi/registry.py
+++ b/python/tvm/ffi/registry.py
@@ -50,6 +50,7 @@ def register_object(type_key=None):
             if _SKIP_UNKNOWN_OBJECTS:
                 return cls
             raise ValueError("Cannot find object type index for %s" % 
object_name)
+        core._add_class_attrs_by_reflection(type_index, cls)
         core._register_object_by_index(type_index, cls)
         return cls
 
diff --git a/python/tvm/ffi/testing.py b/python/tvm/ffi/testing.py
new file mode 100644
index 0000000000..843a10c896
--- /dev/null
+++ b/python/tvm/ffi/testing.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utilities."""
+
+from . import _ffi_api
+from .core import Object
+from .registry import register_object
+
+
+@register_object("testing.TestObjectBase")
+class TestObjectBase(Object):
+    """
+    Test object base class.
+    """
+
+
+@register_object("testing.TestObjectDerived")
+class TestObjectDerived(TestObjectBase):
+    """
+    Test object derived class.
+    """
+
+
+def create_object(type_key: str, **kwargs) -> Object:
+    """
+    Make an object by reflection.
+
+    Parameters
+    ----------
+    type_key : str
+        The type key of the object.
+    kwargs : dict
+        The keyword arguments to the object.
+
+    Returns
+    -------
+    obj : object
+        The created object.
+
+    Note
+    ----
+    This function is only used for testing purposes and should
+    not be used in other cases.
+    """
+    args = [type_key]
+    for k, v in kwargs.items():
+        args.append(k)
+        args.append(v)
+    return _ffi_api.MakeObjectFromPackedArgs(*args)
diff --git a/tests/python/ffi/test_object.py b/tests/python/ffi/test_object.py
new file mode 100644
index 0000000000..d333cbca08
--- /dev/null
+++ b/tests/python/ffi/test_object.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from tvm import ffi as tvm_ffi
+
+
+def test_make_object():
+    # with default values
+    obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase")
+    assert obj0.v_i64 == 10
+    assert obj0.v_f64 == 10.0
+    assert obj0.v_str == "hello"
+
+
+def test_method():
+    obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
+    assert obj0.add_i64(1) == 13
+    assert type(obj0).add_i64.__doc__ == "add_i64 method"
+    assert type(obj0).v_i64.__doc__ == "i64 field"
+
+
+def test_setter():
+    # test setter
+    obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, 
v_str="hello")
+    assert obj0.v_i64 == 10
+    obj0.v_i64 = 11
+    assert obj0.v_i64 == 11
+    obj0.v_str = "world"
+    assert obj0.v_str == "world"
+
+    with pytest.raises(TypeError):
+        obj0.v_str = 1
+
+    with pytest.raises(TypeError):
+        obj0.v_i64 = "hello"
+
+
+def test_derived_object():
+    with pytest.raises(TypeError):
+        obj0 = tvm_ffi.testing.create_object("testing.TestObjectDerived")
+
+    v_map = tvm_ffi.convert({"a": 1})
+    v_array = tvm_ffi.convert([1, 2, 3])
+
+    obj0 = tvm_ffi.testing.create_object(
+        "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array
+    )
+    assert obj0.v_map.same_as(v_map)
+    assert obj0.v_array.same_as(v_array)
+    assert obj0.v_i64 == 20
+    assert obj0.v_f64 == 10.0
+    assert obj0.v_str == "hello"
+
+    obj0.v_i64 = 21
+    assert obj0.v_i64 == 21

Reply via email to