This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 61249b46e790efbd0389ad1af0cc5148ac4d5e2a
Author: tqchen <[email protected]>
AuthorDate: Fri Dec 27 08:02:30 2024 +0800

    [FFI] Initial reflection support
---
 ffi/cmake/Utils/CxxWarning.cmake       |   2 +-
 ffi/include/tvm/ffi/c_api.h            | 209 +++++++++++++++++++++++++--------
 ffi/include/tvm/ffi/function_details.h |   6 +-
 ffi/include/tvm/ffi/object.h           |  31 +----
 ffi/include/tvm/ffi/reflection.h       | 182 ++++++++++++++++++++++++++++
 ffi/include/tvm/ffi/type_traits.h      |   9 ++
 ffi/src/ffi/object.cc                  |  49 +++++++-
 ffi/tests/cpp/test_reflection.cc       |  49 ++++++++
 ffi/tests/cpp/testing_object.h         |   5 +-
 9 files changed, 455 insertions(+), 87 deletions(-)

diff --git a/ffi/cmake/Utils/CxxWarning.cmake b/ffi/cmake/Utils/CxxWarning.cmake
index 50ee5b616d..efa9781e55 100644
--- a/ffi/cmake/Utils/CxxWarning.cmake
+++ b/ffi/cmake/Utils/CxxWarning.cmake
@@ -1,7 +1,7 @@
 function(add_cxx_warning target_name)
   # GNU, Clang, or AppleClang
   if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang")
-    target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" 
"-Wpedantic")
+    target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" 
"-Wpedantic" "-Wno-unused-parameter")
     return()
   endif()
   # MSVC
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index c580410581..18c44e40da 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -51,27 +51,40 @@ enum TVMFFITypeIndex : int32_t {
 #else
 typedef enum {
 #endif
-  // [Section] On-stack POD Types: [0, kTVMFFIStaticObjectBegin)
+  // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
   // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
   // which is not owned by TVMFFIAny. It is required that the following
   // invariant holds:
   // - `Any::type_index` is never `kTVMFFIRawStr`
   // - `AnyView::type_index` can be `kTVMFFIRawStr`
+  //
+  // NOTE: kTVMFFIAny is a root type of everything
+  // we include it so TypeIndex captures all possible runtime values.
+  // `kTVMFFIAny` code will never appear in Any::type_index.
+  // However, it may appear in field annotations during reflection.
+  //
+  kTVMFFIAny = -1,
   kTVMFFINone = 0,
   kTVMFFIInt = 1,
-  kTVMFFIFloat = 2,
-  kTVMFFIOpaquePtr = 3,
-  kTVMFFIDataType = 4,
-  kTVMFFIDevice = 5,
-  kTVMFFIRawStr = 6,
+  kTVMFFIBool = 2,
+  kTVMFFIFloat = 3,
+  kTVMFFIOpaquePtr = 4,
+  kTVMFFIDataType = 5,
+  kTVMFFIDevice = 6,
+  kTVMFFIDLTensorPtr = 7,
+  kTVMFFIRawStr = 8,
   // [Section] Static Boxed: [kTVMFFIStaticObjectBegin, kTVMFFIDynObjectBegin)
+  // roughly order in terms of their ptential dependencies
   kTVMFFIStaticObjectBegin = 64,
   kTVMFFIObject = 64,
-  kTVMFFIArray = 65,
-  kTVMFFIMap = 66,
-  kTVMFFIError = 67,
-  kTVMFFIFunc = 68,
-  kTVMFFIStr = 69,
+  kTVMFFIStr = 65,
+  kTVMFFIError = 66,
+  kTVMFFIFunc = 67,
+  kTVMFFIArray = 68,
+  kTVMFFIMap = 69,
+  kTVMFFIShapeTuple = 70,
+  kTVMFFINDArray = 71,
+  kTVMFFIRuntimeModule = 72,
   // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo)
   // kTVMFFIDynObject is used to indicate that the type index
   // is dynamic and needs to be looked up at runtime
@@ -113,7 +126,10 @@ typedef struct TVMFFIAny {
    * \note The type index of Object and Any are shared in FFI.
    */
   int32_t type_index;
-  /*! \brief length for on-stack Any object, such as small-string */
+  /*!
+   * \brief length for on-stack Any object, such as small-string
+   * \note This field is reserved for future compact.
+   */
   int32_t small_len;
   union {                  // 8 bytes
     int64_t v_int64;       // integers
@@ -134,6 +150,99 @@ typedef struct {
   const char* bytes;
 } TVMFFIByteArray;
 
+/*!
+ * \brief Type that defines C-style safe call convention
+ *
+ * Safe call explicitly catches exception on function boundary.
+ *
+ * \param self The function handle
+ * \param num_args Number of input arguments
+ * \param args The input arguments to the call.
+ * \param result Store output result
+ *
+ * \return The call returns 0 if call is successful.
+ *  It returns non-zero value if there is an error.
+ *
+ *  Possible return error of the API functions:
+ *  * 0: success
+ *  * -1: error happens, can be retrieved by TVMFFIGetLastError
+ *  * -2: a frontend error occurred and recorded in the frontend.
+ *
+ * \note We decided to leverage TVMFFIGetLastError and TVMFFISetLastError
+ *  for C function error propagation. This design choice, while
+ *  introducing a dependency for TLS runtime, simplifies error
+ *  propagation in chains of calls in compiler codegen,
+ *  as errors need not be propagated through arguments but are simply
+ *  set in the runtime environment.
+ */
+typedef int (*TVMFFISafeCallType)(void* self, int32_t num_args, const TVMFFIAny* args,
+                                  TVMFFIAny* result);
+
+/*!
+ * \brief Read the field at `field` into `result`; returns 0 on success, nonzero otherwise.
+ * \param field The raw address of the field.
+ * \param result Stores the result.
+ */
+typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result);
+
+/*!
+ * \brief Write `value` into the field at `field`; returns 0 on success, nonzero otherwise.
+ * \param field The raw address of the field.
+ * \param value The value to set.
+ */
+typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value);
+
+/*!
+ * \brief Reflection information of a single object field.
+ */
+typedef struct {
+  /*! \brief The name of the field. */
+  const char* name;
+  /*!
+   * \brief Records the static type kind of the field.
+   *
+   * Possible values:
+   *
+   * - TVMFFITypeIndex::kTVMFFIObject for general objects;
+   *   the value is nullable when kTVMFFIObject is chosen.
+   * - static object type kinds such as Str, Array, Map.
+   * - a POD type index.
+   * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info
+   *   about the field.
+   *
+   * \note This information is helpful when designing a serializer
+   * for the field, as it narrows down the type of the stored value.
+   * It also provides opportunities to enable short-cut access
+   * to the field.
+   */
+  int32_t field_static_type_index;
+  /*!
+   * \brief Nonzero when the field is registered as read-only.
+   */
+  int32_t readonly;
+  /*!
+   * \brief Byte offset of the field from the start of the object.
+   */
+  int64_t byte_offset;
+  /*! \brief The getter used to read the field. */
+  TVMFFIFieldGetter getter;
+  /*! \brief The setter used to write the field. */
+  TVMFFIFieldSetter setter;
+} TVMFFIFieldInfo;
+
+/*!
+ * \brief Method information that can appear in reflection table.
+ */
+typedef struct {
+  /*! \brief The name of the method. */
+  const char* name;
+  /*!
+   * \brief The method wrapped as a Function object.
+   * \note The first argument to the method is always the object itself (self).
+   */
+  TVMFFIObjectHandle method;
+} TVMFFIMethodInfo;
+
 /*!
  * \brief Runtime type information for object type checking.
  */
@@ -155,38 +264,27 @@ typedef struct {
    *       hieracy stays as a tree
    */
   const int32_t* type_acenstors;
+  /*! \brief number of reflection accessible fields. */
+  int32_t num_fields;
+  /*! \brief number of reflection acccesible methods. */
+  int32_t num_methods;
+  /*! \brief The reflection field information. */
+  TVMFFIFieldInfo* fields;
+  /*! \brief The reflection method. */
+  TVMFFIMethodInfo* methods;
 } TVMFFITypeInfo;
 
 //------------------------------------------------------------
 // Section: User APIs to interact with the FFI
 //------------------------------------------------------------
 /*!
- * \brief Type that defines C-style safe call convention
- *
- * Safe call explicitly catches exception on function boundary.
- *
- * \param func The function handle
- * \param num_args Number if input arguments
- * \param args The input arguments to the call.
- * \param result Store output result
- *
- * \return The call return 0 if call is successful.
- *  It returns non-zero value if there is an error.
- *
- *  Possible return error of the API functions:
- *  * 0: success
- *  * -1: error happens, can be retrieved by TVMFFIGetLastError
- *  * -2: a frontend error occurred and recorded in the frontend.
- *
- * \note We decided to leverage TVMFFIGetLastError and TVMFFISetLastError
- *  for C function error propagation. This design choice, while
- *  introducing a dependency for TLS runtime, simplifies error
- *  propgation in chains of calls in compiler codegen.
- *  As we do not need to propagate error through argument but simply
- *  set them in the runtime environment.
+ * \brief Free an object handle by decreasing its reference count.
+ * \param obj The object handle.
+ * \note Internally we decrease the reference counter of the object.
+ *       The object will be freed when all references to the object are removed.
+ * \return 0 when success, nonzero when failure happens
  */
-typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const TVMFFIAny* args,
-                                  TVMFFIAny* result);
+TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj);
 
 /*!
  * \brief Create a FFIFunc by passing in callbacks from C callback.
@@ -223,15 +321,6 @@ TVM_FFI_DLL int TVMFFIFuncSetGlobal(const char* name, 
TVMFFIObjectHandle f, int
  */
 TVM_FFI_DLL int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out);
 
-/*!
- * \brief Free an object handle by decreasing reference
- * \param obj The object handle.
- * \note Internally we decrease the reference counter of the object.
- *       The object will be freed when every reference to the object are 
removed.
- * \return 0 when success, nonzero when failure happens
- */
-TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj);
-
 /*!
  * \brief Move the last error from the environment to result.
  *
@@ -249,6 +338,30 @@ TVM_FFI_DLL void TVMFFIMoveFromLastError(TVMFFIAny* 
result);
  */
 TVM_FFI_DLL void TVMFFISetLastError(const TVMFFIAny* error_view);
 
+/*!
+ * \brief Convert type key to type index.
+ * \param type_key The key of the type.
+ * \param out_tindex Stores the corresponding type index.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFITypeKey2Index(const char* type_key, int32_t* out_tindex);
+
+/*!
+ * \brief Register type field information for runtime reflection.
+ * \param type_index The type index
+ * \param info The field info to register; it is copied, including its name string.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info);
+
+/*! \brief Register type method information for runtime reflection.
+ * \param type_index The type index
+ * \param info The method info to be registered.
+ * \return 0 when success, nonzero when failure happens
+ * \note NOTE(review): this patch adds no definition of this symbol; confirm before use.
+ */
+TVM_FFI_DLL int TVMFFIRegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info);
+
 //------------------------------------------------------------
 // Section: Backend noexcept functions for internal use
 //
@@ -292,13 +405,13 @@ TVM_FFI_DLL int32_t TVMFFIGetOrAllocTypeIndex(const char* 
type_key, int32_t stat
                                               int32_t type_depth, int32_t 
num_child_slots,
                                               int32_t child_slots_can_overflow,
                                               int32_t parent_type_index);
+
 /*!
  * \brief Get dynamic type info by type index.
  *
  * \param type_index The type index
- * \param result The output type information
- *
- * \return 0 when success, nonzero when failure happens
+ *
+ * \return The type info
  */
 TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index);
 
diff --git a/ffi/include/tvm/ffi/function_details.h 
b/ffi/include/tvm/ffi/function_details.h
index 56b7ca7df3..0f6701e509 100644
--- a/ffi/include/tvm/ffi/function_details.h
+++ b/ffi/include/tvm/ffi/function_details.h
@@ -161,7 +161,7 @@ class MovableArgValueWithContext {
   template <typename Type>
   TVM_FFI_INLINE operator Type() {
     using TypeWithoutCR = std::remove_const_t<std::remove_reference_t<Type>>;
-    std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]);
+  std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]);
     if (opt.has_value()) {
       return std::move(*opt);
     }
@@ -210,8 +210,8 @@ struct unpack_call_dispatcher<R, 0, index, F> {
 template <int index, typename F>
 struct unpack_call_dispatcher<void, 0, index, F> {
   template <typename... Args>
-  TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig,
-                                 const F& f, int32_t num_args, const AnyView* args, Any* rv,
+  TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, int32_t,
+                                 const AnyView*, Any*,
                                  Args&&... unpacked_args) {
     f(std::forward<Args>(unpacked_args)...);
   }
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 6a202ba1f5..70b68f6fc6 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -39,7 +39,7 @@ using TypeInfo = TVMFFITypeInfo;
 namespace details {
 // Helper to perform
 // unsafe operations related to object
-struct ObjectUnsafe;
+class ObjectUnsafe;
 
 /*!
  * Check if the type_index is an instance of TargetObjectType.
@@ -445,32 +445,6 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
   static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \
   TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)
 
-/*!
- * \brief helper macro to declare a base object type that can be inherited.
- * \param TypeName The name of the current type.
- * \param ParentType The name of the ParentType
- */
-#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                 
               \
-  TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType);                            
               \
-  static int32_t _GetOrAllocRuntimeTypeIndex() {                               
               \
-    static int32_t tindex = TVMFFIGetOrAllocTypeIndex(                         
               \
-        TypeName::_type_key, -1, TypeName::_type_depth, 
TypeName::_type_child_slots,          \
-        TypeName::_type_child_slots_can_overflow, 
ParentType::_GetOrAllocRuntimeTypeIndex()); \
-    return tindex;                                                             
               \
-  }                                                                            
               \
-  static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); }  
               \
-  static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex()
-
-/*!
- * \brief helper macro to declare type information in a final class.
- * \param TypeName The name of the current type.
- * \param ParentType The name of the ParentType
- */
-#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
-  static const constexpr int _type_child_slots = 0;             \
-  static const constexpr bool _type_final = true;               \
-  TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
-
 /*
  * \brief Define object reference methods.
  * \param TypeName The object type name
@@ -539,7 +513,8 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t 
object_type_index) {
  * \note These functions are only supposed to be used by internal
  * implementations and not external users of the tvm::ffi
  */
-struct ObjectUnsafe {
+class ObjectUnsafe {
+ public:
   // NOTE: get ffi header from an object
   static TVM_FFI_INLINE TVMFFIObject* GetHeader(const Object* src) {
     return const_cast<TVMFFIObject*>(&(src->header_));
diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h
new file mode 100644
index 0000000000..51573d1bc2
--- /dev/null
+++ b/ffi/include/tvm/ffi/reflection.h
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/reflection.h
+ * \brief Base reflection support to access object fields.
+ */
+#ifndef TVM_FFI_REFLECTION_H_
+#define TVM_FFI_REFLECTION_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
+namespace tvm {
+namespace ffi {
+namespace details {
+
+/*!
+ * \brief Map a C++ field type to the static type index stored in TVMFFIFieldInfo.
+ *
+ * Types whose TypeTraits do not declare `field_static_type_index` (including types
+ * without TypeTraits support) fall back to kTVMFFIAny instead of failing to compile.
+ */
+template <typename T, typename = void>
+struct Type2FieldStaticTypeIndex {
+  static constexpr int32_t value = TypeIndex::kTVMFFIAny;
+};
+
+template <typename T>
+struct Type2FieldStaticTypeIndex<
+    T, std::void_t<decltype(TypeTraits<T>::field_static_type_index)>> {
+  static constexpr int32_t value = TypeTraits<T>::field_static_type_index;
+};
+
+/*!
+ * \brief Get the byte offset of a class member field.
+ *
+ * \tparam Class The class that declares the field.
+ * \tparam T The field type.
+ *
+ * \param field_ptr A pointer to a data member of Class.
+ * \returns The byte offset of the field from the start of a Class object.
+ *
+ * \note The address is formed against aligned scratch storage instead of a
+ *  null pointer, because `nullptr->*field_ptr` is undefined behavior.
+ *  No Class object is constructed and no field value is read.
+ */
+template <typename Class, typename T>
+inline int64_t GetFieldByteOffset(T Class::*field_ptr) {
+  alignas(Class) unsigned char storage[sizeof(Class)];
+  const Class* base = reinterpret_cast<const Class*>(storage);
+  const unsigned char* field = reinterpret_cast<const unsigned char*>(&(base->*field_ptr));
+  return static_cast<int64_t>(field - storage);
+}
+
+/*!
+ * \brief Builder that registers reflection information of one object type.
+ *
+ * TVM_FFI_DECLARE_BASE_OBJECT_INFO initializes `_type_index` from a ReflectionDef,
+ * so fields can be registered by chaining `def_*` calls; the builder then converts
+ * back to the type index through `operator int32_t`.
+ */
+class ReflectionDef {
+ public:
+  explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {}
+
+  /*! \brief Register a field and mark it read-only in its TVMFFIFieldInfo. */
+  template <typename Class, typename T>
+  ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) {
+    RegisterField(name, field_ptr, true);
+    return *this;
+  }
+
+  /*! \brief Register a field and mark it writable in its TVMFFIFieldInfo. */
+  template <typename Class, typename T>
+  ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) {
+    RegisterField(name, field_ptr, false);
+    return *this;
+  }
+
+  /*! \brief Convert back to the type index so the builder can initialize `_type_index`. */
+  operator int32_t() const { return type_index_; }
+
+ private:
+  /*!
+   * \brief Fill a TVMFFIFieldInfo for `field_ptr` and register it with the type table.
+   * \note A setter is recorded even for read-only fields.
+   */
+  template <typename Class, typename T>
+  void RegisterField(const char* name, T Class::*field_ptr, bool readonly) {
+    TVMFFIFieldInfo info;
+    info.name = name;
+    info.field_static_type_index = Type2FieldStaticTypeIndex<T>::value;
+    // Fields are accessed through byte_offset, so one getter/setter instantiation
+    // per field type T is shared by every field of that type.
+    info.byte_offset = GetFieldByteOffset<Class, T>(field_ptr);
+    info.readonly = readonly;
+    info.getter = FieldGetter<T>;
+    info.setter = FieldSetter<T>;
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFIRegisterTypeField(type_index_, &info));
+  }
+
+  /*! \brief Type-erased getter: box the T stored at `field` into `result`. */
+  template <typename T>
+  static int FieldGetter(void* field, TVMFFIAny* result) {
+    TVM_FFI_SAFE_CALL_BEGIN();
+    Any(*static_cast<T*>(field)).MoveToTVMFFIAny(result);
+    TVM_FFI_SAFE_CALL_END();
+  }
+
+  /*! \brief Type-erased setter: convert `value` to T and assign it to `field`. */
+  template <typename T>
+  static int FieldSetter(void* field, const TVMFFIAny* value) {
+    TVM_FFI_SAFE_CALL_BEGIN();
+    *static_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value);
+    TVM_FFI_SAFE_CALL_END();
+  }
+
+  /*! \brief The type index that fields are registered to. */
+  int32_t type_index_;
+};
+
+/*!
+ * \brief Look up reflection field info by type key and field name.
+ * \param type_key The registered type key, e.g. "test.Int".
+ * \param field_name The name of the field.
+ * \return The registered field info. It points into the type table's field storage,
+ *  which may be reallocated when more fields of the same type are registered.
+ * \note Throws via TVM_FFI_THROW(RuntimeError) when the field cannot be found.
+ */
+inline const TVMFFIFieldInfo* GetReflectionFieldInfo(const char* type_key, const char* field_name) {
+  int32_t type_index = -1;
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKey2Index(type_key, &type_index));
+  const TypeInfo* info = TVMFFIGetTypeInfo(type_index);
+  for (int32_t i = 0; i < info->num_fields; ++i) {
+    if (std::strcmp(info->fields[i].name, field_name) == 0) {
+      return &(info->fields[i]);
+    }
+  }
+  TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in " << type_key;
+}
+
+/*!
+ * \brief Callable wrapper that reads one reflected field from an object.
+ * \note The field info is borrowed and must outlive this getter.
+ */
+class ReflectionFieldGetter {
+ public:
+  explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {}
+
+  /*!
+   * \brief Read the field from a raw object pointer.
+   * \note Throws via TVM_FFI_THROW(RuntimeError) when obj_ptr is null.
+   */
+  Any operator()(const Object* obj_ptr) const {
+    if (obj_ptr == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "Cannot get field " << field_info_->name
+                                  << " from a null object";
+    }
+    Any result;
+    const void* addr = reinterpret_cast<const char*>(obj_ptr) + field_info_->byte_offset;
+    TVM_FFI_CHECK_SAFE_CALL(
+        field_info_->getter(const_cast<void*>(addr), reinterpret_cast<TVMFFIAny*>(&result)));
+    return result;
+  }
+
+  Any operator()(const ObjectPtr<Object>& obj_ptr) const { return operator()(obj_ptr.get()); }
+
+  Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); }
+
+ private:
+  /*! \brief Borrowed field info; its byte_offset and getter drive every read. */
+  const TVMFFIFieldInfo* field_info_;
+};
+
+}  // namespace details
+
+/*!
+ * \brief helper macro to declare a base object type that can be inherited.
+ * \param TypeName The name of the current type.
+ * \param ParentType The name of the ParentType
+ *
+ * The `_type_index` initializer is left open so reflection fields can be registered
+ * by chaining before the terminating semicolon, e.g.
+ * `TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MyObj, Object).def_readonly("value", &MyObj::value);`
+ */
+#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
+  TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType);                                                \
+  static int32_t _GetOrAllocRuntimeTypeIndex() {                                                   \
+    static int32_t tindex = TVMFFIGetOrAllocTypeIndex(                                             \
+        TypeName::_type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots,               \
+        TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex());      \
+    return tindex;                                                                                 \
+  }                                                                                                \
+  static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); }                      \
+  static inline int32_t _type_index =                                                              \
+      ::tvm::ffi::details::ReflectionDef(_GetOrAllocRuntimeTypeIndex())
+
+/*!
+ * \brief helper macro to declare type information in a final class.
+ * \param TypeName The name of the current type.
+ * \param ParentType The name of the ParentType
+ */
+#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
+  static const constexpr int _type_child_slots = 0;             \
+  static const constexpr bool _type_final = true;               \
+  TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
+
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_REFLECTION_H_
diff --git a/ffi/include/tvm/ffi/type_traits.h 
b/ffi/include/tvm/ffi/type_traits.h
index a46d04ad83..cf88eb7319 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -107,6 +107,8 @@ struct TypeTraitsBase {
 // None
 template <>
 struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone;
+
   static TVM_FFI_INLINE void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* 
result) {
     result->type_index = TypeIndex::kTVMFFINone;
     // invariant: the pointer field also equals nullptr
@@ -142,6 +144,8 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
 // Integer POD values
 template <typename Int>
 struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public 
TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt;
+
   static TVM_FFI_INLINE void CopyToAnyView(const Int& src, TVMFFIAny* result) {
     result->type_index = TypeIndex::kTVMFFIInt;
     result->v_int64 = static_cast<int64_t>(src);
@@ -171,6 +175,8 @@ struct TypeTraits<Int, 
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
 template <typename Float>
 struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>>
     : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat;
+
   static TVM_FFI_INLINE void CopyToAnyView(const Float& src, TVMFFIAny* 
result) {
     result->type_index = TypeIndex::kTVMFFIFloat;
     result->v_float64 = static_cast<double>(src);
@@ -205,6 +211,8 @@ struct TypeTraits<Float, 
std::enable_if_t<std::is_floating_point_v<Float>>>
 // void*
 template <>
 struct TypeTraits<void*> : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = 
TypeIndex::kTVMFFIOpaquePtr;
+
   static TVM_FFI_INLINE void CopyToAnyView(void* src, TVMFFIAny* result) {
     result->type_index = TypeIndex::kTVMFFIOpaquePtr;
     // maintain padding zero in 32bit platform
@@ -241,6 +249,7 @@ template <typename TObjRef>
 struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, 
TObjRef> &&
                                             
use_default_type_traits_v<TObjRef>>>
     : public TypeTraitsBase {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject;
   using ContainerType = typename TObjRef::ContainerType;
 
   static TVM_FFI_INLINE void CopyToAnyView(const TObjRef& src, TVMFFIAny* 
result) {
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 8941753df8..a0fe306e78 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -52,6 +52,8 @@ class TypeTable {
     std::string type_key_data;
     /*! \brief acenstor information */
     std::vector<int32_t> type_acenstors_data;
+    /*! \brief type fields informaton */
+    std::vector<TVMFFIFieldInfo> type_fields_data;
     // NOTE: the indices in [index, index + num_reserved_slots) are
     // reserved for the child-class of this type.
     /*! \brief Total number of slots reserved for the type and its children. */
@@ -87,6 +89,11 @@ class TypeTable {
       this->type_key = this->type_key_data.c_str();
       this->type_key_hash = std::hash<std::string>()(this->type_key_data);
       this->type_acenstors = type_acenstors_data.data();
+      // initialize the reflection information
+      this->num_fields = 0;
+      this->num_methods = 0;
+      this->fields = nullptr;
+      this->methods = nullptr;
     }
   };
 
@@ -165,13 +172,23 @@ class TypeTable {
     return it->second;
   }
 
-  const TypeInfo* GetTypeInfo(int32_t type_index) {
-    const TypeInfo* info = nullptr;
+  Entry* GetTypeEntry(int32_t type_index) {  // fails TVM_FFI_ICHECK on an unknown type_index
+    Entry* entry = nullptr;
     if (type_index >= 0 && static_cast<size_t>(type_index) < type_table_.size()) {
-      info = type_table_[type_index].get();
+      entry = type_table_[type_index].get();
     }
-    TVM_FFI_ICHECK(info != nullptr) << "Cannot find type info for type_index=" << type_index;
-    return info;
+    TVM_FFI_ICHECK(entry != nullptr) << "Cannot find type info for type_index=" << type_index;
+    return entry;
+  }
+
+  void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) {
+    Entry* entry = GetTypeEntry(type_index);
+    TVMFFIFieldInfo field_data = *info;  // copy: callers may pass a stack temporary
+    field_data.name = this->CopyString(info->name);  // own the name via string_pool_
+    entry->type_fields_data.push_back(field_data);  // NOTE(review): duplicates not rejected
+    // push_back may reallocate, so re-point the C-visible `fields` array afterwards.
+    entry->fields = entry->type_fields_data.data();
+    entry->num_fields = static_cast<int32_t>(entry->type_fields_data.size());
   }
 
   void Dump(int min_children_count) {
@@ -217,9 +234,17 @@ class TypeTable {
                               -1);
   }
 
+  const char* CopyString(const char* c_str) {  // returned pointer is owned by string_pool_
+    std::unique_ptr<std::string> val = std::make_unique<std::string>(c_str);
+    const char* c_val = val->c_str();  // the heap string never moves, so this stays valid
+    string_pool_.emplace_back(std::move(val));
+    return c_val;
+  }
+
   int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin};
   std::vector<std::unique_ptr<Entry>> type_table_;
   std::unordered_map<std::string, int32_t> type_key2index_;
+  std::vector<std::unique_ptr<std::string>> string_pool_;
 };
 }  // namespace ffi
 }  // namespace tvm
@@ -230,6 +255,18 @@ int TVMFFIObjectFree(TVMFFIObjectHandle handle) {
   TVM_FFI_SAFE_CALL_END();
 }
 
+int TVMFFITypeKey2Index(const char* type_key, int32_t* out_tindex) {  // C ABI over TypeTable
+  TVM_FFI_SAFE_CALL_BEGIN();
+  out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKey2Index(type_key);
+  TVM_FFI_SAFE_CALL_END();  // 0 on success, nonzero on failure (per c_api.h)
+}
+
+int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) {  // C ABI
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info);  // copies *info
+  TVM_FFI_SAFE_CALL_END();  // 0 on success, nonzero on failure (per c_api.h)
+}
+
 int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, int32_t 
static_type_index,
                                   int32_t type_depth, int32_t num_child_slots,
                                   int32_t child_slots_can_overflow, int32_t 
parent_type_index) {
@@ -242,6 +279,6 @@ int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, 
int32_t static_type_inde
 
 const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) {
   TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
-  return tvm::ffi::TypeTable::Global()->GetTypeInfo(type_index);
+  return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index);  // Entry* -> base TypeInfo*
   TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo);
 }
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
new file mode 100644
index 0000000000..c7901a4ca3
--- /dev/null
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection.h>
+#include "./testing_object.h"
+
+namespace {
+
+using namespace tvm::ffi;
+using namespace tvm::ffi::testing;
+
+struct A {
+  ObjectRef obj;
+  int32_t x;
+  int32_t y;
+};
+
+TEST(Reflection, GetFieldByteOffset) {
+  // Derive expected offsets from sizeof so the test also holds on 32-bit targets.
+  const int64_t x_offset = static_cast<int64_t>(sizeof(ObjectRef));
+  const int64_t y_offset = x_offset + static_cast<int64_t>(sizeof(int32_t));
+  EXPECT_EQ(details::GetFieldByteOffset(&A::x), x_offset);
+  EXPECT_EQ(details::GetFieldByteOffset(&A::y), y_offset);
+}
+
+TEST(Reflection, FieldGetter) {
+  ObjectRef a = TInt(10);
+  details::ReflectionFieldGetter getter(details::GetReflectionFieldInfo("test.Int", "value"));
+  EXPECT_EQ(getter(a).operator int(), 10);
+}
+
+TEST(Reflection, GetReflectionFieldInfoUnknownField) {
+  EXPECT_ANY_THROW(details::GetReflectionFieldInfo("test.Int", "no_such_field"));
+}
+}  // namespace
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index e660b2751e..a4d6b1353f 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -19,8 +19,10 @@
 
 #ifndef TVM_FFI_TESTING_OBJECT_H_
 #define TVM_FFI_TESTING_OBJECT_H_
+
 #include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection.h>
 
 namespace tvm {
 namespace ffi {
@@ -46,7 +48,8 @@ class TIntObj : public TNumberObj {
   TIntObj(int64_t value) : value(value) {}
 
   static constexpr const char* _type_key = "test.Int";
-  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj);
+
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj).def_readonly("value", 
&TIntObj::value);
 };
 
 class TInt : public TNumber {

Reply via email to