This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a06906cfd2 [REFACTOR] Upgrade NestedMsg<T> to use new ffi::Any mechanism (#18181)
a06906cfd2 is described below

commit a06906cfd273c9af657e0ada59ec33fe8b1ec3bc
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Aug 1 15:49:28 2025 -0400

    [REFACTOR] Upgrade NestedMsg<T> to use new ffi::Any mechanism (#18181)
    
    This PR upgrades NestedMsg<T> to store its payload in an ffi::Any
    instead of deriving from ObjectRef. This gives NestedMsg proper FFI
    type-trait support (conversion to and from Any) and enables NestedMsg
    for POD leaf types such as int64_t.
---
 include/tvm/relax/nested_msg.h | 122 +++++++++++++++++++++++++++++++++--------
 tests/cpp/nested_msg_test.cc   |  23 ++++++--
 2 files changed, 116 insertions(+), 29 deletions(-)

diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index af2db582d6..8620ad80bd 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -33,6 +33,7 @@
 #include <tvm/relax/expr.h>
 #include <tvm/relax/struct_info.h>
 
+#include <string>
 #include <utility>
 #include <vector>
 
@@ -115,7 +116,7 @@ namespace relax {
  *       use this class or logic of a similar kind.
  */
 template <typename T>
-class NestedMsg : public ObjectRef {
+class NestedMsg {
  public:
   // default constructors.
   NestedMsg() = default;
@@ -123,12 +124,6 @@ class NestedMsg : public ObjectRef {
   NestedMsg(NestedMsg<T>&&) = default;
   NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
   NestedMsg<T>& operator=(NestedMsg<T>&&) = default;
-  /*!
-   * \brief Construct from an ObjectPtr
-   *        whose type already satisfies the constraint
-   * \param ptr
-   */
-  explicit NestedMsg(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
   /*! \brief Nullopt handling */
   NestedMsg(std::nullopt_t) {}  // NOLINT(*)
   // nullptr handling.
@@ -140,16 +135,17 @@ class NestedMsg : public ObjectRef {
   }
   // normal value handling.
   NestedMsg(T other)  // NOLINT(*)
-      : ObjectRef(std::move(other)) {}
+      : data_(std::move(other)) {}
   NestedMsg<T>& operator=(T other) {
-    ObjectRef::operator=(std::move(other));
+    data_ = std::move(other);
     return *this;
   }
   // Array<NestedMsg<T>> handling
   NestedMsg(Array<NestedMsg<T>, void> other)  // NOLINT(*)
-      : ObjectRef(std::move(other)) {}
+      : data_(std::move(other)) {}
+
   NestedMsg<T>& operator=(Array<NestedMsg<T>, void> other) {
-    ObjectRef::operator=(std::move(other));
+    data_ = std::move(other);
     return *this;
   }

@@ -170,13 +166,16 @@ class NestedMsg : public ObjectRef {
   bool operator!=(std::nullptr_t) const { return data_ != nullptr; }

   /*! \return Whether the nested message is not-null leaf value */
-  bool IsLeaf() const { return data_ != nullptr && data_->IsInstance<LeafContainerType>(); }
+  bool IsLeaf() const {
+    return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
+           data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
+  }

   /*! \return Whether the nested message is null */
-  bool IsNull() const { return data_ == nullptr; }
+  bool IsNull() const { return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }

   /*! \return Whether the nested message is nested */
-  bool IsNested() const { return data_ != nullptr && data_->IsInstance<ffi::ArrayObj>(); }
+  bool IsNested() const { return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }

   /*!
    * \return The underlying leaf value.
@@ -184,7 +183,7 @@ class NestedMsg : public ObjectRef {
    */
   T LeafValue() const {
     ICHECK(IsLeaf());
-    return T(data_);
+    return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
   }

@@ -192,16 +191,19 @@ class NestedMsg : public ObjectRef {
    * \note This checks if the underlying data type is array.
    */
   Array<NestedMsg<T>, void> NestedArray() const {
     ICHECK(IsNested());
-    return Array<NestedMsg<T>, void>(data_);
+    // The unchecked copy below is only sound because of the ICHECK above.
+    return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<NestedMsg<T>, void>>(data_);
   }

-  using ContainerType = Object;
-  using LeafContainerType = typename T::ContainerType;
-
-  static_assert(std::is_base_of<ObjectRef, T>::value, "NestedMsg is only defined for ObjectRef.");
-
-  static constexpr bool _type_is_nullable = true;
+ private:
+  /*! \brief Payload: None (null), an Array of NestedMsg<T> (nested), or a T (leaf). */
+  ffi::Any data_;
+  // Wraps an already type-checked Any; used by ffi::TypeTraits<NestedMsg<T>>.
+  explicit NestedMsg(ffi::Any data) : data_(std::move(data)) {}
+  // ffi::TypeTraits<NestedMsg<T>> needs access to data_ and the private constructor.
+  template <typename, typename>
+  friend struct ffi::TypeTraits;
 };

 /*!
@@ -598,5 +596,86 @@ StructInfo TransformTupleLeaf(StructInfo sinfo, std::array<NestedMsg<T>, N> msgs
 }

 }  // namespace relax
+
+namespace ffi {
+
+template <typename T>
+inline constexpr bool use_default_type_traits_v<relax::NestedMsg<T>> = false;
+
+// FFI type traits for NestedMsg<T>. The payload is an ffi::Any that is either None,
+// an Array whose elements are NestedMsg<T>, or a leaf value of type T.
+template <typename T>
+struct TypeTraits<relax::NestedMsg<T>> : public TypeTraitsBase {
+  TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg<T>& src, TVMFFIAny* result) {
+    *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg<T> src, TVMFFIAny* result) {
+    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
+  }
+
+  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+    return TypeTraitsBase::GetMismatchTypeInfo(src);
+  }
+
+  // Strict check whose dispatch mirrors IsNull()/IsNested()/IsLeaf(): accept None,
+  // an array whose elements all strictly match NestedMsg<T>, or a value strictly of type T.
+  static bool CheckAnyStrict(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) {
+      return true;
+    }
+    if (src->type_index == TypeIndex::kTVMFFIArray) {
+      const ffi::ArrayObj* array = reinterpret_cast<const ffi::ArrayObj*>(src->v_obj);
+      for (size_t i = 0; i < array->size(); ++i) {
+        const Any& any_v = (*array)[i];
+        if (!details::AnyUnsafe::CheckAnyStrict<relax::NestedMsg<T>>(any_v)) return false;
+      }
+      return true;
+    }
+    // Any other value is a leaf and must strictly be a T; do not accept arbitrary types.
+    return TypeTraits<T>::CheckAnyStrict(src);
+  }
+
+  TVM_FFI_INLINE static relax::NestedMsg<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return relax::NestedMsg<T>(Any(AnyView::CopyFromTVMFFIAny(*src)));
+  }
+
+  TVM_FFI_INLINE static relax::NestedMsg<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    return relax::NestedMsg<T>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
+  }
+
+  static std::optional<relax::NestedMsg<T>> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (CheckAnyStrict(src)) {
+      return CopyFromAnyViewAfterCheck(src);
+    }
+    // slow path: run conversion, dispatching in the same order as CheckAnyStrict
+    if (src->type_index == TypeIndex::kTVMFFINone) {
+      return relax::NestedMsg<T>(std::nullopt);
+    }
+    if (src->type_index == TypeIndex::kTVMFFIArray) {
+      const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
+      Array<relax::NestedMsg<T>> result;
+      result.reserve(n->size());
+      for (size_t i = 0; i < n->size(); i++) {
+        const Any& any_v = (*n)[i];
+        if (auto opt_v = any_v.try_cast<relax::NestedMsg<T>>()) {
+          result.push_back(*std::move(opt_v));
+        } else {
+          return std::nullopt;
+        }
+      }
+      return relax::NestedMsg<T>(std::move(result));
+    }
+    if (auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
+      return relax::NestedMsg<T>(*std::move(opt_value));
+    }
+    return std::nullopt;
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    return "NestedMsg<" + details::Type2Str<T>::v() + ">";
+  }
+};
+}  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_RELAX_NESTED_MSG_H_
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
index d552dae8f7..644a80664f 100644
--- a/tests/cpp/nested_msg_test.cc
+++ b/tests/cpp/nested_msg_test.cc
@@ -53,7 +53,7 @@ TEST(NestedMsg, Basic) {
   EXPECT_ANY_THROW(msg.LeafValue());

   auto arr = msg.NestedArray();
-  EXPECT_TRUE(arr[0].same_as(x));
+  EXPECT_TRUE(arr[0].LeafValue().same_as(x));
   EXPECT_TRUE(arr[1] == nullptr);
   EXPECT_TRUE(arr[1].IsNull());

@@ -72,13 +72,28 @@ TEST(NestedMsg, Basic) {
   EXPECT_TRUE(a0.IsNested());
   auto t0 = a0.NestedArray()[1];
   EXPECT_TRUE(t0.IsNested());
-  EXPECT_TRUE(t0.NestedArray()[2].same_as(y));
+  EXPECT_TRUE(t0.NestedArray()[2].LeafValue().same_as(y));

   // assign leaf
   a0 = x;

   EXPECT_TRUE(a0.IsLeaf());
-  EXPECT_TRUE(a0.same_as(x));
+  EXPECT_TRUE(a0.LeafValue().same_as(x));
+
+  // A non-null, non-array, non-int64 value must not convert to NestedMsg<int64_t>.
+  Any any_x = x;
+  EXPECT_FALSE(any_x.try_cast<NestedMsg<int64_t>>().has_value());
+}
+
+TEST(NestedMsg, IntAndAny) {
+  NestedMsg<int64_t> msg({1, std::nullopt, 2});
+  Any any_msg = msg;
+  NestedMsg<int64_t> msg2 = any_msg.cast<NestedMsg<int64_t>>();
+
+  EXPECT_TRUE(msg2.IsNested());
+  EXPECT_EQ(msg2.NestedArray()[0].LeafValue(), 1);
+  EXPECT_TRUE(msg2.NestedArray()[1].IsNull());
+  EXPECT_EQ(msg2.NestedArray()[2].LeafValue(), 2);
 }

 TEST(NestedMsg, ForEachLeaf) {
@@ -174,13 +189,13 @@ TEST(NestedMsg, MapAndDecompose) {

   DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg<Integer> msg) {
     if (value.same_as(x)) {
-      EXPECT_TRUE(msg.same_as(c0));
+      EXPECT_TRUE(msg.LeafValue().same_as(c0));
       ++x_count;
     } else if (value.same_as(y)) {
-      EXPECT_TRUE(msg.same_as(c1));
+      EXPECT_TRUE(msg.LeafValue().same_as(c1));
       ++y_count;
     } else {
-      EXPECT_TRUE(msg.same_as(c2));
+      EXPECT_TRUE(msg.LeafValue().same_as(c2));
       ++z_count;
     }
   });

Reply via email to