This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ce8875ea04 [FFI] Update typeinfo to speedup parent reflection (#18083)
ce8875ea04 is described below

commit ce8875ea04837ace434303fce4c310ee2b3a442f
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 19 13:52:14 2025 -0400

    [FFI] Update typeinfo to speedup parent reflection (#18083)
    
    This PR updates the typeinfo to speed up parent reflection.
    It also rewrites a few `if constexpr` branches with an explicit
    `else` so the untaken branch is eliminated early in compilation.
---
 ffi/include/tvm/ffi/c_api.h                 |  4 +--
 ffi/include/tvm/ffi/container/tuple.h       |  3 +-
 ffi/include/tvm/ffi/object.h                | 52 ++++++++++++++++-------------
 ffi/include/tvm/ffi/reflection/reflection.h | 25 ++++++++++++++
 ffi/include/tvm/ffi/string.h                |  9 +++++
 ffi/src/ffi/object.cc                       | 11 +++---
 ffi/tests/cpp/test_object.cc                |  4 +--
 ffi/tests/cpp/test_reflection.cc            | 30 ++++++++++++++---
 ffi/tests/cpp/testing_object.h              | 10 ++++++
 9 files changed, 109 insertions(+), 39 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 668d239ca1..545604a473 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -437,7 +437,7 @@ typedef struct {
 /*!
  * \brief Runtime type information for object type checking.
  */
-typedef struct {
+typedef struct TVMFFITypeInfo {
   /*!
    *\brief The runtime type index,
    * It can be allocated during runtime if the type is dynamic.
@@ -452,7 +452,7 @@ typedef struct {
    * \note To keep things simple, we do not allow multiple inheritance so the
    *       hieracy stays as a tree
    */
-  const int32_t* type_acenstors;
+  const struct TVMFFITypeInfo** type_acenstors;
   // The following fields are used for reflection
   /*! \brief Cached hash value of the type key, used for consistent structural 
hashing. */
   uint64_t type_key_hash;
diff --git a/ffi/include/tvm/ffi/container/tuple.h 
b/ffi/include/tvm/ffi/container/tuple.h
index 27f08e7fc9..e88768df6c 100644
--- a/ffi/include/tvm/ffi/container/tuple.h
+++ b/ffi/include/tvm/ffi/container/tuple.h
@@ -253,8 +253,9 @@ struct TypeTraits<Tuple<Types...>> : public 
ObjectRefTypeTraitsBase<Tuple<Types.
     }
     if constexpr (sizeof...(Rest) > 0) {
       return TryConvertElements<I + 1, Rest...>(std::move(arr));
+    } else {
+      return true;
     }
-    return true;
   }
 
   static TVM_FFI_INLINE std::string TypeStr() {
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 21dd8f3f7f..d032cc13e4 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -693,34 +693,38 @@ template <typename TargetType>
 TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
   static_assert(std::is_base_of_v<Object, TargetType>);
   // Everything is a subclass of object.
-  if constexpr (std::is_same<TargetType, Object>::value) return true;
-
-  if constexpr (TargetType::_type_final) {
+  if constexpr (std::is_same<TargetType, Object>::value) {
+    return true;
+  } else if constexpr (TargetType::_type_final) {
     // if the target type is a final type
     // then we only need to check the equivalence.
     return object_type_index == TargetType::RuntimeTypeIndex();
-  }
-
-  // if target type is a non-leaf type
-  // Check if type index falls into the range of reserved slots.
-  int32_t target_type_index = TargetType::RuntimeTypeIndex();
-  int32_t begin = target_type_index;
-  // The condition will be optimized by constant-folding.
-  if constexpr (TargetType::_type_child_slots != 0) {
-    // total_slots = child_slots + 1 (including self)
-    int32_t end = begin + TargetType::_type_child_slots + 1;
-    if (object_type_index >= begin && object_type_index < end) return true;
   } else {
-    if (object_type_index == begin) return true;
-  }
-  if (!TargetType::_type_child_slots_can_overflow) return false;
-  // Invariance: parent index is always smaller than the child.
-  if (object_type_index < target_type_index) return false;
-  // Do a runtime lookup of type information
-  // the function checks that the info exists
-  const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
-  return (type_info->type_depth > TargetType::_type_depth &&
-          type_info->type_acenstors[TargetType::_type_depth] == 
target_type_index);
+    // Explicitly enclose in else to eliminate this branch early in 
compilation.
+    // if target type is a non-leaf type
+    // Check if type index falls into the range of reserved slots.
+    int32_t target_type_index = TargetType::RuntimeTypeIndex();
+    int32_t begin = target_type_index;
+    // The condition will be optimized by constant-folding.
+    if constexpr (TargetType::_type_child_slots != 0) {
+      // total_slots = child_slots + 1 (including self)
+      int32_t end = begin + TargetType::_type_child_slots + 1;
+      if (object_type_index >= begin && object_type_index < end) return true;
+    } else {
+      if (object_type_index == begin) return true;
+    }
+    if constexpr (TargetType::_type_child_slots_can_overflow) {
+      // Invariance: parent index is always smaller than the child.
+      if (object_type_index < target_type_index) return false;
+      // Do a runtime lookup of type information
+      // the function checks that the info exists
+      const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
+      return (type_info->type_depth > TargetType::_type_depth &&
+              type_info->type_acenstors[TargetType::_type_depth]->type_index 
== target_type_index);
+    } else {
+      return false;
+    }
+  }
 }
 
 /*!
diff --git a/ffi/include/tvm/ffi/reflection/reflection.h 
b/ffi/include/tvm/ffi/reflection/reflection.h
index 6187a74825..d53a4817ad 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -392,6 +392,31 @@ inline Function GetMethod(std::string_view type_key, const 
char* method_name) {
   return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
 }
 
+/*!
+ * \brief Visit each field info of the type info and run callback.
+ *
+ * \tparam Callback The callback function type.
+ *
+ * \param type_info The type info.
+ * \param callback The callback function.
+ *
+ * \note This function calls both the child and parent type info.
+ */
+template <typename Callback>
+inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
+  // iterate through acenstors in parent to child order
+  // skip the first one since it is always the root object
+  for (int i = 1; i < type_info->type_depth; ++i) {
+    const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
+    for (int j = 0; j < parent_info->num_fields; ++j) {
+      callback(parent_info->fields + j);
+    }
+  }
+  for (int i = 0; i < type_info->num_fields; ++i) {
+    callback(type_info->fields + i);
+  }
+}
+
 }  // namespace reflection
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index dee2d89c08..734483e271 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -255,6 +255,15 @@ class String : public ObjectRef {
    */
   String(std::string&& other)  // NOLINT(*)
       : 
ObjectRef(make_object<details::BytesObjStdImpl<StringObj>>(std::move(other))) {}
+
+  /*!
+   * \brief constructor from TVMFFIByteArray
+   *
+   * \param other a TVMFFIByteArray.
+   */
+  explicit String(TVMFFIByteArray other)
+      : ObjectRef(details::MakeInplaceBytes<StringObj>(other.data, 
other.size)) {}
+
   /*!
    * \brief Swap this String with another string
    * \param other The other string
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index fa77e2b264..5becd89fb3 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -52,7 +52,7 @@ class TypeTable {
     /*! \brief stored type key */
     String type_key_data;
     /*! \brief acenstor information */
-    std::vector<int32_t> type_acenstors_data;
+    std::vector<const TVMFFITypeInfo*> type_acenstors_data;
     /*! \brief type fields informaton */
     std::vector<TVMFFIFieldInfo> type_fields_data;
     /*! \brief type methods informaton */
@@ -85,7 +85,7 @@ class TypeTable {
           type_acenstors_data[i] = parent->type_acenstors[i];
         }
         // set last type information to be parent
-        type_acenstors_data[parent->type_depth] = parent->type_index;
+        type_acenstors_data[parent->type_depth] = parent;
       }
       // initialize type info: no change to type_key and type_acenstors fields
       // after this line
@@ -234,7 +234,7 @@ class TypeTable {
     for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
       const Entry* ptr = it->get();
       if (ptr != nullptr && ptr->type_depth != 0) {
-        int parent_index = ptr->type_acenstors[ptr->type_depth - 1];
+        int parent_index = ptr->type_acenstors[ptr->type_depth - 
1]->type_index;
         num_children[parent_index] += num_children[ptr->type_index] + 1;
         if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) {
           expected_child_slots[ptr->type_index] = ptr->num_slots - 1;
@@ -247,7 +247,7 @@ class TypeTable {
       if (ptr != nullptr && num_children[ptr->type_index] >= 
min_children_count) {
         std::cerr << '[' << ptr->type_index << "]\t" << 
ToStringView(ptr->type_key);
         if (ptr->type_depth != 0) {
-          int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1];
+          int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 
1]->type_index;
           std::cerr << "\tparent=" << 
ToStringView(type_table_[parent_index]->type_key);
         } else {
           std::cerr << "\tparent=root";
@@ -375,9 +375,8 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* 
ret) {
 
   // iterate through acenstors in parent to child order
   // skip the first one since it is always the root object
-  TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject);
   for (int i = 1; i < type_info->type_depth; ++i) {
-    update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i]));
+    update_fields(type_info->type_acenstors[i]);
   }
   update_fields(type_info);
 
diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc
index c370ff51a4..4b53a70b42 100644
--- a/ffi/tests/cpp/test_object.cc
+++ b/ffi/tests/cpp/test_object.cc
@@ -55,8 +55,8 @@ TEST(Object, TypeInfo) {
   EXPECT_TRUE(info != nullptr);
   EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex());
   EXPECT_EQ(info->type_depth, 2);
-  EXPECT_EQ(info->type_acenstors[0], Object::_type_index);
-  EXPECT_EQ(info->type_acenstors[1], TNumberObj::_type_index);
+  EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index);
+  EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index);
   EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
 }
 
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index 450cb9dbcb..17494744ef 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -18,6 +18,7 @@
  * under the License.
  */
 #include <gtest/gtest.h>
+#include <tvm/ffi/container/map.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection/reflection.h>
 #include <tvm/ffi/string.h>
@@ -29,11 +30,20 @@ namespace {
 using namespace tvm::ffi;
 using namespace tvm::ffi::testing;
 
-struct A : public Object {
+struct TestObjA : public Object {
   int64_t x;
   int64_t y;
 
+  static constexpr const char* _type_key = "test.TestObjA";
   static constexpr bool _type_mutable = true;
+  TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjA, Object);
+};
+
+struct TestObjADerived : public TestObjA {
+  int64_t z;
+
+  static constexpr const char* _type_key = "test.TestObjADerived";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjADerived, TestObjA);
 };
 
 TVM_FFI_STATIC_INIT_BLOCK({
@@ -56,12 +66,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
         return self->value - other;
       });
 
-  refl::ObjectDef<A>().def_ro("x", &A::x).def_rw("y", &A::y);
+  refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y", 
&TestObjA::y);
+  refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);
 });
 
 TEST(Reflection, GetFieldByteOffset) {
-  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::x), 
sizeof(TVMFFIObject));
-  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::y), 8 + 
sizeof(TVMFFIObject));
+  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::x), 
sizeof(TVMFFIObject));
+  EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::y), 8 + 
sizeof(TVMFFIObject));
   EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), 
sizeof(TVMFFIObject));
 }
 
@@ -131,4 +142,15 @@ TEST(Reflection, CallMethod) {
   EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
 }
 
+TEST(Reflection, ForEachFieldInfo) {
+  const TypeInfo* info = 
TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex());
+  Map<String, int> field_name_to_offset;
+  reflection::ForEachFieldInfo(info, [&](const TVMFFIFieldInfo* field_info) {
+    field_name_to_offset.Set(String(field_info->name), field_info->offset);
+  });
+  EXPECT_EQ(field_name_to_offset["x"], sizeof(TVMFFIObject));
+  EXPECT_EQ(field_name_to_offset["y"], 8 + sizeof(TVMFFIObject));
+  EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject));
+}
+
 }  // namespace
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index 8a91848845..9c14f5590f 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -22,6 +22,7 @@
 
 #include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection/reflection.h>
 #include <tvm/ffi/string.h>
 
 namespace tvm {
@@ -81,6 +82,15 @@ class TFloatObj : public TNumberObj {
 
   double Add(double other) const { return value + other; }
 
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TFloatObj>()
+        .def_ro("value", &TFloatObj::value, "float value field", 
refl::DefaultValue(10.0))
+        .def("sub",
+             [](const TFloatObj* self, double other) -> double { return 
self->value - other; })
+        .def("add", &TFloatObj::Add, "add method");
+  }
+
   static constexpr const char* _type_key = "test.Float";
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj);
 };

Reply via email to