This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f5bb200f1 [FFI][REFACTOR] Enhance reflection (#18059)
0f5bb200f1 is described below

commit 0f5bb200f129ec4c0a235e065f452ca524d19bd8
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 15 07:46:42 2025 -0400

    [FFI][REFACTOR] Enhance reflection (#18059)
    
    This PR enhances and refactors the reflection module to support
    nanobind/pybind-style reflection definitions.
    
    This is one step toward upgrading the overall reflection
    mechanism in the project.
---
 ffi/include/tvm/ffi/c_api.h                 | 100 ++++++++---
 ffi/include/tvm/ffi/function.h              |   8 +-
 ffi/include/tvm/ffi/object.h                |   8 +-
 ffi/include/tvm/ffi/reflection/reflection.h | 265 +++++++++++++++++++++++++---
 ffi/include/tvm/ffi/string.h                |   2 -
 ffi/include/tvm/ffi/type_traits.h           |  18 +-
 ffi/src/ffi/object.cc                       |  48 ++++-
 ffi/tests/cpp/test_any.cc                   |   3 +
 ffi/tests/cpp/test_array.cc                 |   1 +
 ffi/tests/cpp/test_reflection.cc            |  89 +++++++++-
 ffi/tests/cpp/test_tuple.cc                 |   1 +
 ffi/tests/cpp/test_variant.cc               |   1 +
 ffi/tests/cpp/testing_object.h              |  10 +-
 13 files changed, 475 insertions(+), 79 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 7cf7543f48..a43f2e83b5 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -303,12 +303,54 @@ typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* 
result);
  */
 typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value);
 
+/*!
+ * \brief bitmask of the field.
+ */
+#ifdef __cplusplus
+enum TVMFFIFieldFlagBitMask : int32_t {
+#else
+typedef enum {
+#endif
+  /*! \brief The field is writable. */
+  TVMFFIFieldFlagBitMaskWritable = 1 << 0,
+  /*! \brief The field has default value. */
+  TVMFFIFieldFlagBitMaskHasDefault = 1 << 1,
+  /*! \brief The field is a static method. */
+  TVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2,
+#ifdef __cplusplus
+};
+#else
+} TVMFFIFieldFlagBitMask;
+#endif
+
 /*!
  * \brief Information support for optional object reflection.
  */
 typedef struct {
   /*! \brief The name of the field. */
   TVMFFIByteArray name;
+  /*! \brief The docstring about the field. */
+  TVMFFIByteArray doc;
+  /*!
+   * \brief bitmask flags of the field.
+   */
+  int64_t flags;
+  /*!
+   * \brief Byte offset of the field.
+   */
+  int64_t byte_offset;
+  /*! \brief The getter to access the field. */
+  TVMFFIFieldGetter getter;
+  /*!
+   * \brief The setter to access the field.
+   * \note The setter is set even if the field is readonly for serialization.
+   */
+  TVMFFIFieldSetter setter;
+  /*!
+   * \brief The default value of the field, stored as an AnyView;
+   *        only valid when flags has TVMFFIFieldFlagBitMaskHasDefault set.
+   */
+  TVMFFIAny default_value;
   /*!
    * \brief Records the static type kind of the field.
    *
@@ -317,28 +359,18 @@ typedef struct {
    *  - TVMFFITypeIndex::kTVMFFIObject for general objects
    *    - The value is nullable when kTVMFFIObject is chosen
    * - static object type kinds such as Map, Dict, String
-   * - POD type index
+   * - POD type index; note that it does not indicate the storage size of the field.
    * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info
    *   about the field.
    *
-   * \note This information is helpful in designing serializer
-   * of the field. As it helps to narrow down the type of the
-   * object. It also helps to provide opportunities to enable
-   * short-cut access to the field.
+   * When the value is a type index of Object type, the field is stored as an ObjectRef.
+   *
+   * \note This information may be helpful in designing a serializer,
+   * as it narrows down the field type so we don't have to
+   * print the type_key for cases like POD types.
+   * It also provides opportunities to enable short-cut getters for ObjectRef fields.
    */
   int32_t field_static_type_index;
-  /*!
-   * \brief Mark whether field is readonly.
-   */
-  int32_t readonly;
-  /*!
-   * \brief Byte offset of the field.
-   */
-  int64_t byte_offset;
-  /*! \brief The getter to access the field. */
-  TVMFFIFieldGetter getter;
-  /*! \brief The setter to access the field. */
-  TVMFFIFieldSetter setter;
 } TVMFFIFieldInfo;
 
 /*!
@@ -347,11 +379,15 @@ typedef struct {
 typedef struct {
   /*! \brief The name of the field. */
   TVMFFIByteArray name;
+  /*! \brief The docstring about the method. */
+  TVMFFIByteArray doc;
+  /*! \brief bitmask flags of the method. */
+  int64_t flags;
   /*!
-   * \brief The method wrapped as Function
-   * \note The first argument to the method is always the self.
+   * \brief The method wrapped as ffi::Function, stored as AnyView.
+   * \note The first argument to the method is always self for instance methods.
    */
-  TVMFFIObjectHandle method;
+  TVMFFIAny method;
 } TVMFFIMethodInfo;
 
 /*!
@@ -379,6 +415,7 @@ typedef struct {
   int32_t num_fields;
   /*! \brief number of reflection acccesible methods. */
   int32_t num_methods;
+
   /*! \brief The reflection field information. */
   TVMFFIFieldInfo* fields;
   /*! \brief The reflection method. */
@@ -522,12 +559,29 @@ TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const 
TVMFFIByteArray* name, void* symbol)
 // Section: Type reflection support APIs
 //------------------------------------------------------------
 /*!
- * \brief Register type field information for rutnime reflection.
+ * \brief Register type field information for runtime reflection.
  * \param type_index The type index
  * \param info The field info to be registered.
  * \return 0 when success, nonzero when failure happens
  */
-TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const 
TVMFFIFieldInfo* info);
+TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const 
TVMFFIFieldInfo* info);
+
+/*!
+ * \brief Register type method information for runtime reflection.
+ * \param type_index The type index
+ * \param info The method info to be registered.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const 
TVMFFIMethodInfo* info);
+
+/*!
+ * \brief Get dynamic type info by type index.
+ *
+ * \param type_index The type index
+ * \param result The output type information
+ * \return The type info
+ */
+TVM_FFI_DLL const TVMFFITypeInfo* TVMFFITypeGetMethod(int32_t type_index);
 
 //------------------------------------------------------------
 // Section: DLPack support APIs
@@ -638,7 +692,7 @@ TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const 
char* filename, int lin
  *
  * \return 0 if success, -1 if error occured
  */
-TVM_FFI_DLL int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key,
+TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key,
                                               int32_t static_type_index, 
int32_t type_depth,
                                               int32_t num_child_slots,
                                               int32_t child_slots_can_overflow,
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index 128c67830e..eb3f390bcd 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -836,14 +836,14 @@ class Function::Registry {
     if constexpr (std::is_base_of_v<ObjectRef, T>) {
       auto fwrap = [f](T target, Args... params) -> R {
         // call method pointer
-        return (target.*f)(params...);
+        return (target.*f)(std::forward<Args>(params)...);
       };
       return Register(ffi::Function::FromTyped(fwrap, name_));
     }
     if constexpr (std::is_base_of_v<Object, T>) {
       auto fwrap = [f](const T* target, Args... params) -> R {
         // call method pointer
-        return (const_cast<T*>(target)->*f)(params...);
+        return (const_cast<T*>(target)->*f)(std::forward<Args>(params)...);
       };
       return Register(ffi::Function::FromTyped(fwrap, name_));
     }
@@ -857,14 +857,14 @@ class Function::Registry {
     if constexpr (std::is_base_of_v<ObjectRef, T>) {
       auto fwrap = [f](const T target, Args... params) -> R {
         // call method pointer
-        return (target.*f)(params...);
+        return (target.*f)(std::forward<Args>(params)...);
       };
       return Register(ffi::Function::FromTyped(fwrap, name_));
     }
     if constexpr (std::is_base_of_v<Object, T>) {
       auto fwrap = [f](const T* target, Args... params) -> R {
         // call method pointer
-        return (target->*f)(params...);
+        return (target->*f)(std::forward<Args>(params)...);
       };
       return Register(ffi::Function::FromTyped(fwrap, name_));
     }
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index e86689ebe2..72e6f0a1f8 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -103,6 +103,9 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t 
object_type_index);
  *       This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO
  *       It is still OK to sub-class a terminal object type T and construct it 
using make_object.
  *       But IsInstance check will only show that the object type is T(instead 
of the sub-class).
+ * - _type_mutable:
+ *      Whether to expose casting to a non-const pointer
+ *      ObjectType* from Any/AnyView. Defaults to false, so it is not exposed.
  *
  * The following two fields are necessary for base classes that can be 
sub-classed.
  *
@@ -191,6 +194,7 @@ class Object {
 
   // Default object type properties for sub-classes
   static constexpr bool _type_final = false;
+  static constexpr bool _type_mutable = false;
   static constexpr uint32_t _type_child_slots = 0;
   static constexpr bool _type_child_slots_can_overflow = true;
   // NOTE: static type index field of the class
@@ -546,7 +550,7 @@ struct ObjectPtrEqual {
                   "Need to set _type_child_slots when parent specifies it.");  
               \
     TVMFFIByteArray type_key{TypeName::_type_key,                              
               \
                              
std::char_traits<char>::length(TypeName::_type_key)};            \
-    static int32_t tindex = TVMFFIGetOrAllocTypeIndex(                         
               \
+    static int32_t tindex = TVMFFITypeGetOrAllocIndex(                         
               \
         &type_key, TypeName::_type_index, TypeName::_type_depth, 
TypeName::_type_child_slots, \
         TypeName::_type_child_slots_can_overflow, 
ParentType::_GetOrAllocRuntimeTypeIndex()); \
     return tindex;                                                             
               \
@@ -576,7 +580,7 @@ struct ObjectPtrEqual {
                   "Need to set _type_child_slots when parent specifies it.");  
               \
     TVMFFIByteArray type_key{TypeName::_type_key,                              
               \
                              
std::char_traits<char>::length(TypeName::_type_key)};            \
-    static int32_t tindex = TVMFFIGetOrAllocTypeIndex(                         
               \
+    static int32_t tindex = TVMFFITypeGetOrAllocIndex(                         
               \
         &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots,     
               \
         TypeName::_type_child_slots_can_overflow, 
ParentType::_GetOrAllocRuntimeTypeIndex()); \
     return tindex;                                                             
               \
diff --git a/ffi/include/tvm/ffi/reflection/reflection.h 
b/ffi/include/tvm/ffi/reflection/reflection.h
index 766b9b8099..a5ab6f4fe8 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -29,10 +29,31 @@
 #include <tvm/ffi/type_traits.h>
 
 #include <string>
+#include <utility>
 
 namespace tvm {
 namespace ffi {
-namespace details {
+/*! \brief Reflection namespace */
+namespace reflection {
+
+/*! \brief Trait that can be used to set field info */
+struct FieldInfoTrait {};
+
+/*!
+ * \brief Trait that can be used to set field default value
+ */
+class DefaultValue : public FieldInfoTrait {
+ public:
+  explicit DefaultValue(Any value) : value_(value) {}
+
+  void Apply(TVMFFIFieldInfo* info) const {
+    info->default_value = AnyView(value_).CopyToTVMFFIAny();
+    info->flags |= TVMFFIFieldFlagBitMaskHasDefault;
+  }
+
+ private:
+  Any value_;
+};
 
 /*!
  * \brief Get the byte offset of a class member field.
@@ -50,37 +71,108 @@ inline int64_t GetFieldByteOffsetToObject(T 
Class::*field_ptr) {
   return field_offset_to_class - 
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
 }
 
-struct ReflectionDefFinish {};
-
 class ReflectionDef {
  public:
-  explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {}
+  explicit ReflectionDef(int32_t type_index, const char* type_key)
+      : type_index_(type_index), type_key_(type_key) {}
+
+  /*!
+   * \brief Define a readonly field.
+   *
+   * \tparam Class The class type.
+   * \tparam T The field type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the field.
+   * \param field_ptr The pointer to the field.
+   * \param extra The extra arguments that can be docstring or default value.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Class, typename T, typename... Extra>
+  ReflectionDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... 
extra) {
+    RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
+    return *this;
+  }
 
-  template <typename Class, typename T>
-  ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) {
-    RegisterField(name, field_ptr, true);
+  /*!
+   * \brief Define a read-write field.
+   *
+   * \tparam Class The class type.
+   * \tparam T The field type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the field.
+   * \param field_ptr The pointer to the field.
+   * \param extra The extra arguments that can be docstring or default value.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Class, typename T, typename... Extra>
+  ReflectionDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... 
extra) {
+    RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
     return *this;
   }
 
-  template <typename Class, typename T>
-  ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) {
-    RegisterField(name, field_ptr, false);
+  /*!
+   * \brief Define a method.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  ReflectionDef& def(const char* name, Func&& func, Extra&&... extra) {
+    RegisterMethod(name, false, std::forward<Func>(func), 
std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a static method.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  ReflectionDef& def_static(const char* name, Func&& func, Extra&&... extra) {
+    RegisterMethod(name, true, std::forward<Func>(func), 
std::forward<Extra>(extra)...);
     return *this;
   }
 
  private:
-  template <typename Class, typename T>
-  void RegisterField(const char* name, T Class::*field_ptr, bool readonly) {
+  template <typename Class, typename T, typename... ExtraArgs>
+  void RegisterField(const char* name, T Class::*field_ptr, bool writable,
+                     ExtraArgs&&... extra_args) {
     TVMFFIFieldInfo info;
     info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
     info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
     // store byte offset and setter, getter
     // so the same setter can be reused for all the same type
     info.byte_offset = GetFieldByteOffsetToObject<Class, T>(field_ptr);
-    info.readonly = readonly;
+    info.flags = 0;
+    if (writable) {
+      info.flags |= TVMFFIFieldFlagBitMaskWritable;
+    }
     info.getter = FieldGetter<T>;
     info.setter = FieldSetter<T>;
-    TVM_FFI_CHECK_SAFE_CALL(TVMFFIRegisterTypeField(type_index_, &info));
+    // initialize default value to nullptr
+    info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    // apply field info traits
+    ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
+    // call register
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
   }
 
   template <typename T>
@@ -97,14 +189,70 @@ class ReflectionDef {
     TVM_FFI_SAFE_CALL_END();
   }
 
+  template <typename T>
+  static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
+    if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
+      value.Apply(info);
+    }
+    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
+      info->doc = TVMFFIByteArray{value, 
std::char_traits<char>::length(value)};
+    }
+  }
+
+  // register a method
+  template <typename Func, typename... Extra>
+  void RegisterMethod(const char* name, bool is_static, Func&& func, 
Extra&&... extra) {
+    TVMFFIMethodInfo info;
+    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    info.flags = 0;
+    if (is_static) {
+      info.flags |= TVMFFIFieldFlagBitMaskIsStaticMethod;
+    }
+    // obtain the method function
+    Function method = GetMethod(std::string(type_key_) + "." + name, 
std::forward<Func>(func));
+    info.method = AnyView(method).CopyToTVMFFIAny();
+    // apply method info traits
+    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
+  }
+
+  template <typename T>
+  static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
+    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
+      info->doc = TVMFFIByteArray{value, 
std::char_traits<char>::length(value)};
+    }
+  }
+
+  template <typename Class, typename R, typename... Args>
+  static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
+    auto fwrap = [func](const Class* target, Args... params) -> R {
+      return 
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
+    };
+    return ffi::Function::FromTyped(fwrap, name);
+  }
+
+  template <typename Class, typename R, typename... Args>
+  static Function GetMethod(std::string name, R (Class::*func)(Args...) const) 
{
+    auto fwrap = [func](const Class* target, Args... params) -> R {
+      return (target->*func)(std::forward<Args>(params)...);
+    };
+    return ffi::Function::FromTyped(fwrap, name);
+  }
+
+  template <typename Func>
+  static Function GetMethod(std::string name, Func&& func) {
+    return ffi::Function::FromTyped(std::forward<Func>(func), name);
+  }
+
   int32_t type_index_;
+  const char* type_key_;
 };
 
 /*!
  * \brief helper function to get reflection field info by type key and field 
name
  */
-inline const TVMFFIFieldInfo* GetReflectionFieldInfo(std::string_view type_key,
-                                                     const char* field_name) {
+inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const 
char* field_name) {
   int32_t type_index;
   TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
   TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
@@ -115,14 +263,18 @@ inline const TVMFFIFieldInfo* 
GetReflectionFieldInfo(std::string_view type_key,
     }
   }
   TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in " 
<< type_key;
+  TVM_FFI_UNREACHABLE();
 }
 
 /*!
  * \brief helper wrapper class to obtain a getter.
  */
-class ReflectionFieldGetter {
+class FieldGetter {
  public:
-  explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {}
+  explicit FieldGetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {}
+
+  explicit FieldGetter(std::string_view type_key, const char* field_name)
+      : FieldGetter(GetFieldInfo(type_key, field_name)) {}
 
   Any operator()(const Object* obj_ptr) const {
     Any result;
@@ -140,16 +292,81 @@ class ReflectionFieldGetter {
   const TVMFFIFieldInfo* field_info_;
 };
 
-#define TVM_FFI_REFLECTION_REG_VAR_DEF \
-  static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::ReflectionDef& 
__TVMFFIReflectionReg
+/*!
+ * \brief helper wrapper class to obtain a setter.
+ */
+class FieldSetter {
+ public:
+  explicit FieldSetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {}
+
+  explicit FieldSetter(std::string_view type_key, const char* field_name)
+      : FieldSetter(GetFieldInfo(type_key, field_name)) {}
+
+  void operator()(const Object* obj_ptr, AnyView value) const {
+    const void* addr = reinterpret_cast<const char*>(obj_ptr) + 
field_info_->byte_offset;
+    TVM_FFI_CHECK_SAFE_CALL(
+        field_info_->setter(const_cast<void*>(addr), reinterpret_cast<const 
TVMFFIAny*>(&value)));
+  }
+
+  void operator()(const ObjectPtr<Object>& obj_ptr, AnyView value) const {
+    operator()(obj_ptr.get(), value);
+  }
+
+  void operator()(const ObjectRef& obj, AnyView value) const { 
operator()(obj.get(), value); }
+
+ private:
+  const TVMFFIFieldInfo* field_info_;
+};
+
+/*!
+ * \brief helper function to get reflection method info by type key and method 
name
+ *
+ * \param type_key The type key.
+ * \param method_name The name of the method.
+ * \return The method info.
+ */
+inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const 
char* method_name) {
+  int32_t type_index;
+  TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
+  const TypeInfo* info = TVMFFIGetTypeInfo(type_index);
+  for (int32_t i = 0; i < info->num_methods; ++i) {
+    if (std::strncmp(info->methods[i].name.data, method_name, 
info->methods[i].name.size) == 0) {
+      return &(info->methods[i]);
+    }
+  }
+  TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in 
" << type_key;
+  TVM_FFI_UNREACHABLE();
+}
+
+/*!
+ * \brief helper function to get the reflection method function by type key and method name
+ *
+ * \param type_key The type key.
+ * \param method_name The name of the method.
+ * \return The method function.
+ */
+inline Function GetMethod(std::string_view type_key, const char* method_name) {
+  const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name);
+  return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
+}
+
+#define TVM_FFI_REFLECTION_REG_VAR_DEF                                         
 \
+  static inline TVM_FFI_ATTRIBUTE_UNUSED 
::tvm::ffi::reflection::ReflectionDef& \
+      __TVMFFIReflectionReg
 
 /*!
  * helper macro to define a reflection definition for an object
  */
-#define TVM_FFI_REFLECTION_DEF(TypeName)                            \
-  TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
-      
::tvm::ffi::details::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex())
-}  // namespace details
+#define TVM_FFI_REFLECTION_DEF(TypeName)                                       
      \
+  TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) =            
      \
+      
::tvm::ffi::reflection::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex(), \
+                                            TypeName::_type_key)
+
+}  // namespace reflection
+
+/*! \brief Shortcut to the reflection namespace */
+namespace refl = reflection;
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_REFLECTION_REFLECTION_H_
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index c3eceff905..19df2e8e3d 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -401,7 +401,6 @@ TVM_FFI_INLINE std::string_view 
ToStringView(TVMFFIByteArray str) {
 template <int N>
 struct TypeTraits<char[N]> : public TypeTraitsBase {
   // NOTE: only enable implicit conversion into AnyView
-  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr;
   static constexpr bool storage_enabled = false;
 
   static TVM_FFI_INLINE void CopyToAnyView(const char src[N], TVMFFIAny* 
result) {
@@ -417,7 +416,6 @@ struct TypeTraits<char[N]> : public TypeTraitsBase {
 
 template <>
 struct TypeTraits<const char*> : public TypeTraitsBase {
-  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr;
   static constexpr bool storage_enabled = false;
 
   static TVM_FFI_INLINE void CopyToAnyView(const char* src, TVMFFIAny* result) 
{
diff --git a/ffi/include/tvm/ffi/type_traits.h 
b/ffi/include/tvm/ffi/type_traits.h
index 02c9a90edc..5c291b5535 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -84,6 +84,7 @@ inline constexpr bool use_default_type_traits_v = true;
 struct TypeTraitsBase {
   static constexpr bool convert_enabled = true;
   static constexpr bool storage_enabled = true;
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny;
   // get mismatched type when result mismatches the trait.
   // this function is called after TryCastFromAnyView fails
   // to get more detailed type information in runtime
@@ -588,17 +589,18 @@ struct ObjectRefWithFallbackTraitsBase : public 
ObjectRefTypeTraitsBase<ObjectRe
 
 // Traits for weak pointer of object
 // NOTE: we require the weak pointer cast from
+
 template <typename TObject>
-struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, 
TObject>>>
+struct TypeTraits<TObject*, std::enable_if_t<std::is_base_of_v<Object, 
TObject>>>
     : public TypeTraitsBase {
-  static TVM_FFI_INLINE void CopyToAnyView(const TObject* src, TVMFFIAny* 
result) {
+  static TVM_FFI_INLINE void CopyToAnyView(TObject* src, TVMFFIAny* result) {
     TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
     result->type_index = obj_ptr->type_index;
     TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
     result->v_obj = obj_ptr;
   }
 
-  static TVM_FFI_INLINE void MoveToAny(const TObject* src, TVMFFIAny* result) {
+  static TVM_FFI_INLINE void MoveToAny(TObject* src, TVMFFIAny* result) {
     TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
     result->type_index = obj_ptr->type_index;
     TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
@@ -612,11 +614,17 @@ struct TypeTraits<const TObject*, 
std::enable_if_t<std::is_base_of_v<Object, TOb
            details::IsObjectInstance<TObject>(src->type_index);
   }
 
-  static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const 
TVMFFIAny* src) {
+  static TVM_FFI_INLINE TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* 
src) {
+    if constexpr (!std::is_const_v<TObject>) {
+      static_assert(TObject::_type_mutable, "TObject must be mutable to enable 
cast from Any");
+    }
     return details::ObjectUnsafe::RawObjectPtrFromUnowned<TObject>(src->v_obj);
   }
 
-  static TVM_FFI_INLINE std::optional<const TObject*> TryCastFromAnyView(const 
TVMFFIAny* src) {
+  static TVM_FFI_INLINE std::optional<TObject*> TryCastFromAnyView(const 
TVMFFIAny* src) {
+    if constexpr (!std::is_const_v<TObject>) {
+      static_assert(TObject::_type_mutable, "TObject must be mutable to enable 
cast from Any");
+    }
     if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src);
     return std::nullopt;
   }
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 63ec68790e..6ce149a7c9 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -54,6 +54,8 @@ class TypeTable {
     std::vector<int32_t> type_acenstors_data;
     /*! \brief type fields informaton */
     std::vector<TVMFFIFieldInfo> type_fields_data;
+    /*! \brief type methods information */
+    std::vector<TVMFFIMethodInfo> type_methods_data;
     // NOTE: the indices in [index, index + num_reserved_slots) are
     // reserved for the child-class of this type.
     /*! \brief Total number of slots reserved for the type and its children. */
@@ -186,12 +188,29 @@ class TypeTable {
     Entry* entry = GetTypeEntry(type_index);
     TVMFFIFieldInfo field_data = *info;
     field_data.name = this->CopyString(info->name);
+    field_data.doc = this->CopyString(info->doc);
+    if (info->flags & TVMFFIFieldFlagBitMaskHasDefault) {
+      field_data.default_value =
+          
this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny();
+    } else {
+      field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+    }
     entry->type_fields_data.push_back(field_data);
     // refresh ptr as the data can change
     entry->fields = entry->type_fields_data.data();
     entry->num_fields = static_cast<int32_t>(entry->type_fields_data.size());
   }
 
+  void RegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info) {
+    Entry* entry = GetTypeEntry(type_index);
+    TVMFFIMethodInfo method_data = *info;
+    method_data.name = this->CopyString(info->name);
+    method_data.doc = this->CopyString(info->doc);
+    method_data.method = 
this->CopyAny(AnyView::CopyFromTVMFFIAny(info->method)).CopyToTVMFFIAny();
+    entry->type_methods_data.push_back(method_data);
+    entry->methods = entry->type_methods_data.data();
+    entry->num_methods = static_cast<int32_t>(entry->type_methods_data.size());
+  }
   void Dump(int min_children_count) {
     std::vector<int> num_children(type_table_.size(), 0);
     // expected child slots compute the expected slots
@@ -262,16 +281,25 @@ class TypeTable {
   }
 
   TVMFFIByteArray CopyString(TVMFFIByteArray str) {
-    std::unique_ptr<std::string> val = std::make_unique<std::string>(str.data, 
str.size);
-    TVMFFIByteArray c_val{val->data(), val->length()};
-    string_pool_.emplace_back(std::move(val));
+    if (str.size == 0) {
+      return TVMFFIByteArray{nullptr, 0};
+    }
+    String val = String(str.data, str.size);
+    TVMFFIByteArray c_val{val.data(), val.length()};
+    any_pool_.emplace_back(std::move(val));
     return c_val;
   }
 
+  AnyView CopyAny(Any val) {
+    AnyView view = AnyView(val);
+    any_pool_.emplace_back(std::move(val));
+    return view;
+  }
+
   int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin};
   std::vector<std::unique_ptr<Entry>> type_table_;
   std::unordered_map<std::string, int32_t> type_key2index_;
-  std::vector<std::unique_ptr<std::string>> string_pool_;
+  std::vector<Any> any_pool_;
 };
 }  // namespace ffi
 }  // namespace tvm
@@ -288,13 +316,19 @@ int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, 
int32_t* out_tindex) {
   TVM_FFI_SAFE_CALL_END();
 }
 
-int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) {
+int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info) {
   TVM_FFI_SAFE_CALL_BEGIN();
   tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info);
   TVM_FFI_SAFE_CALL_END();
 }
 
-int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, int32_t 
static_type_index,
+int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info) 
{
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::TypeTable::Global()->RegisterTypeMethod(type_index, info);
+  TVM_FFI_SAFE_CALL_END();
+}
+
+int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t 
static_type_index,
                                   int32_t type_depth, int32_t num_child_slots,
                                   int32_t child_slots_can_overflow, int32_t 
parent_type_index) {
   TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
@@ -302,7 +336,7 @@ int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, int32_t stati
   return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex(
       s_type_key, static_type_index, type_depth, num_child_slots, child_slots_can_overflow,
       parent_type_index);
-  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetOrAllocTypeIndex);
+  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFITypeGetOrAllocIndex);
 }
 
 const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) {
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index 3ad81cd118..eea18c7c64 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -232,6 +232,9 @@ TEST(Any, Object) {
   EXPECT_EQ(v1.use_count(), 3);
   EXPECT_TRUE(any2.as<TInt>().has_value());
 
+  any2 = const_cast<TIntObj*>(v1_ptr);
+  EXPECT_TRUE(any2.as<TInt>().has_value());
+
   // convert to raw opaque ptr
   void* raw_v1_ptr = const_cast<TIntObj*>(v1_ptr);
   any2 = raw_v1_ptr;
diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc
index cb42f32c6c..321af7ae16 100644
--- a/ffi/tests/cpp/test_array.cc
+++ b/ffi/tests/cpp/test_array.cc
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
 #include <tvm/ffi/container/array.h>
+#include <tvm/ffi/function.h>
 
 #include "./testing_object.h"
 
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index fec167d257..76e8d35a99 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -20,33 +20,108 @@
 #include <gtest/gtest.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection/reflection.h>
+#include <tvm/ffi/string.h>
 
 #include "./testing_object.h"
 
 namespace {
-
 using namespace tvm::ffi;
 using namespace tvm::ffi::testing;
 
+TVM_FFI_REFLECTION_DEF(TFloatObj)
+    .def_rw("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0))
+    .def("sub", [](const TFloatObj* self, double other) -> double { return self->value - other; })
+    .def("add", &TFloatObj::Add, "add method");
+
+TVM_FFI_REFLECTION_DEF(TIntObj)
+    .def_ro("value", &TIntObj::value)
+    .def_static("static_add", &TInt::StaticAdd, "static add method");
+
+TVM_FFI_REFLECTION_DEF(TPrimExprObj)
+    .def_ro("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float"))
+    .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0))
+    .def("sub", [](TPrimExprObj* self, double other) -> double {
+      // this is ok because TPrimExprObj is declared as mutable
+      return self->value - other;
+    });
+
 struct A : public Object {
   int64_t x;
   int64_t y;
 };
 
+TVM_FFI_REFLECTION_DEF(A).def_ro("x", &A::x).def_rw("y", &A::y);
+
 TEST(Reflection, GetFieldByteOffset) {
-  EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject));
-  EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject));
-  EXPECT_EQ(details::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject));
+  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject));
+  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject));
+  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject));
 }
 
 TEST(Reflection, FieldGetter) {
   ObjectRef a = TInt(10);
-  details::ReflectionFieldGetter getter(details::GetReflectionFieldInfo("test.Int", "value"));
+  reflection::FieldGetter getter("test.Int", "value");
   EXPECT_EQ(getter(a).cast<int>(), 10);
 
   ObjectRef b = TFloat(10.0);
-  details::ReflectionFieldGetter getter_float(
-      details::GetReflectionFieldInfo("test.Float", "value"));
+  reflection::FieldGetter getter_float("test.Float", "value");
   EXPECT_EQ(getter_float(b).cast<double>(), 10.0);
 }
+
+TEST(Reflection, FieldSetter) {
+  ObjectRef a = TFloat(10.0);
+  reflection::FieldSetter setter("test.Float", "value");
+  setter(a, 20.0);
+  EXPECT_EQ(a.as<TFloatObj>()->value, 20.0);
+}
+
+TEST(Reflection, FieldInfo) {
+  const TVMFFIFieldInfo* info_int = reflection::GetFieldInfo("test.Int", "value");
+  EXPECT_FALSE(info_int->flags & TVMFFIFieldFlagBitMaskHasDefault);
+  EXPECT_FALSE(info_int->flags & TVMFFIFieldFlagBitMaskWritable);
+  EXPECT_EQ(Bytes(info_int->doc).operator std::string(), "");
+
+  const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value");
+  EXPECT_EQ(info_float->default_value.v_float64, 10.0);
+  EXPECT_TRUE(info_float->flags & TVMFFIFieldFlagBitMaskHasDefault);
+  EXPECT_TRUE(info_float->flags & TVMFFIFieldFlagBitMaskWritable);
+  EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field");
+
+  const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype");
+  AnyView default_value = AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value);
+  EXPECT_EQ(default_value.cast<String>(), "float");
+  EXPECT_EQ(default_value.as<String>().value().use_count(), 2);
+  EXPECT_TRUE(info_prim_expr_dtype->flags & TVMFFIFieldFlagBitMaskHasDefault);
+  EXPECT_FALSE(info_prim_expr_dtype->flags & TVMFFIFieldFlagBitMaskWritable);
+  EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field");
+}
+
+TEST(Reflection, MethodInfo) {
+  const TVMFFIMethodInfo* info_int_static_add = reflection::GetMethodInfo("test.Int", "static_add");
+  EXPECT_TRUE(info_int_static_add->flags & TVMFFIFieldFlagBitMaskIsStaticMethod);
+  EXPECT_EQ(Bytes(info_int_static_add->doc).operator std::string(), "static add method");
+
+  const TVMFFIMethodInfo* info_float_add = reflection::GetMethodInfo("test.Float", "add");
+  EXPECT_FALSE(info_float_add->flags & TVMFFIFieldFlagBitMaskIsStaticMethod);
+  EXPECT_EQ(Bytes(info_float_add->doc).operator std::string(), "add method");
+
+  const TVMFFIMethodInfo* info_float_sub = reflection::GetMethodInfo("test.Float", "sub");
+  EXPECT_FALSE(info_float_sub->flags & TVMFFIFieldFlagBitMaskIsStaticMethod);
+  EXPECT_EQ(Bytes(info_float_sub->doc).operator std::string(), "");
+}
+
+TEST(Reflection, CallMethod) {
+  Function static_int_add = reflection::GetMethod("test.Int", "static_add");
+  EXPECT_EQ(static_int_add(TInt(1), TInt(2)).cast<TInt>()->value, 3);
+
+  Function float_add = reflection::GetMethod("test.Float", "add");
+  EXPECT_EQ(float_add(TFloat(1), 2.0).cast<double>(), 3.0);
+
+  Function float_sub = reflection::GetMethod("test.Float", "sub");
+  EXPECT_EQ(float_sub(TFloat(1), 2.0).cast<double>(), -1.0);
+
+  Function prim_expr_sub = reflection::GetMethod("test.PrimExpr", "sub");
+  EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
+}
+
 }  // namespace
diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc
index e0f69d8200..5735e86eca 100644
--- a/ffi/tests/cpp/test_tuple.cc
+++ b/ffi/tests/cpp/test_tuple.cc
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
 #include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/function.h>
 
 #include "./testing_object.h"
 
diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc
index 451913c992..b140e7db6e 100644
--- a/ffi/tests/cpp/test_variant.cc
+++ b/ffi/tests/cpp/test_variant.cc
@@ -20,6 +20,7 @@
 #include <tvm/ffi/any.h>
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/function.h>
 #include <tvm/ffi/memory.h>
 
 #include "./testing_object.h"
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index 69a91efc46..8a91848845 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -22,7 +22,6 @@
 
 #include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
-#include <tvm/ffi/reflection/reflection.h>
 #include <tvm/ffi/string.h>
 
 namespace tvm {
@@ -65,12 +64,12 @@ class TIntObj : public TNumberObj {
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj);
 };
 
-TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value);
-
 class TInt : public TNumber {
  public:
   explicit TInt(int64_t value) { data_ = make_object<TIntObj>(value); }
 
+  static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value + rhs->value); }
+
   TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj);
 };
 
@@ -80,12 +79,12 @@ class TFloatObj : public TNumberObj {
 
   TFloatObj(double value) : value(value) {}
 
+  double Add(double other) const { return value + other; }
+
   static constexpr const char* _type_key = "test.Float";
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj);
 };
 
-TVM_FFI_REFLECTION_DEF(TFloatObj).def_readonly("value", &TFloatObj::value);
-
 class TFloat : public TNumber {
  public:
   explicit TFloat(double value) { data_ = make_object<TFloatObj>(value); }
@@ -102,6 +101,7 @@ class TPrimExprObj : public Object {
   TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {}
 
   static constexpr const char* _type_key = "test.PrimExpr";
+  static constexpr bool _type_mutable = true;
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TPrimExprObj, Object);
 };
 


Reply via email to