This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a06906cfd2 [REFACTOR] Upgrade NestedMsg<T> to use new ffi::Any
mechanism (#18181)
a06906cfd2 is described below
commit a06906cfd273c9af657e0ada59ec33fe8b1ec3bc
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Aug 1 15:49:28 2025 -0400
[REFACTOR] Upgrade NestedMsg<T> to use new ffi::Any mechanism (#18181)
This PR upgrades NestedMsg<T> to store its value in the new ffi::Any
mechanism instead of deriving from ObjectRef. This lets NestedMsg be
converted to and from ffi::Any and enables NestedMsg for POD types
(e.g. NestedMsg<int64_t>).
---
include/tvm/relax/nested_msg.h | 122 +++++++++++++++++++++++++++++++++--------
tests/cpp/nested_msg_test.cc | 23 ++++++--
2 files changed, 116 insertions(+), 29 deletions(-)
diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index af2db582d6..8620ad80bd 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -33,6 +33,7 @@
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
+#include <string>
#include <utility>
#include <vector>
@@ -115,7 +116,7 @@ namespace relax {
* use this class or logic of a similar kind.
*/
template <typename T>
-class NestedMsg : public ObjectRef {
+class NestedMsg {
public:
// default constructors.
NestedMsg() = default;
@@ -123,12 +124,6 @@ class NestedMsg : public ObjectRef {
NestedMsg(NestedMsg<T>&&) = default;
NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
NestedMsg<T>& operator=(NestedMsg<T>&&) = default;
- /*!
- * \brief Construct from an ObjectPtr
- * whose type already satisfies the constraint
- * \param ptr
- */
- explicit NestedMsg(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief Nullopt handling */
NestedMsg(std::nullopt_t) {} // NOLINT(*)
// nullptr handling.
@@ -140,16 +135,17 @@ class NestedMsg : public ObjectRef {
}
// normal value handling.
NestedMsg(T other) // NOLINT(*)
- : ObjectRef(std::move(other)) {}
+ : data_(std::move(other)) {}
NestedMsg<T>& operator=(T other) {
- ObjectRef::operator=(std::move(other));
+ data_ = std::move(other);
return *this;
}
// Array<NestedMsg<T>> handling
NestedMsg(Array<NestedMsg<T>, void> other) // NOLINT(*)
- : ObjectRef(std::move(other)) {}
+ : data_(std::move(other)) {}
+
NestedMsg<T>& operator=(Array<NestedMsg<T>, void> other) {
- ObjectRef::operator=(std::move(other));
+ data_ = std::move(other);
return *this;
}
@@ -170,13 +166,16 @@ class NestedMsg : public ObjectRef {
bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
/*! \return Whether the nested message is not-null leaf value */
- bool IsLeaf() const { return data_ != nullptr &&
data_->IsInstance<LeafContainerType>(); }
+ bool IsLeaf() const {
+ return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
+ data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
+ }
/*! \return Whether the nested message is null */
- bool IsNull() const { return data_ == nullptr; }
+ bool IsNull() const { return data_.type_index() ==
ffi::TypeIndex::kTVMFFINone; }
/*! \return Whether the nested message is nested */
- bool IsNested() const { return data_ != nullptr &&
data_->IsInstance<ffi::ArrayObj>(); }
+ bool IsNested() const { return data_.type_index() ==
ffi::TypeIndex::kTVMFFIArray; }
/*!
* \return The underlying leaf value.
@@ -184,7 +183,7 @@ class NestedMsg : public ObjectRef {
*/
T LeafValue() const {
ICHECK(IsLeaf());
- return T(data_);
+ return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
}
/*!
@@ -192,16 +191,17 @@ class NestedMsg : public ObjectRef {
* \note This checks if the underlying data type is array.
*/
Array<NestedMsg<T>, void> NestedArray() const {
ICHECK(IsNested());
- return Array<NestedMsg<T>, void>(data_);
+ return
ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<NestedMsg<T>,
void>>(data_);
}
- using ContainerType = Object;
- using LeafContainerType = typename T::ContainerType;
-
- static_assert(std::is_base_of<ObjectRef, T>::value, "NestedMsg is only
defined for ObjectRef.");
-
- static constexpr bool _type_is_nullable = true;
+ private:
+ /*! \brief None when null, an Array<NestedMsg<T>> when nested, otherwise the leaf value. */
+ ffi::Any data_;
+ /*! \brief Wrap an Any without checking it; only the ffi::TypeTraits friend may call this. */
+ explicit NestedMsg(ffi::Any data) : data_(std::move(data)) {}
+ template <typename, typename>
+ friend struct ffi::TypeTraits;
};
/*!
@@ -598,5 +596,88 @@ StructInfo TransformTupleLeaf(StructInfo sinfo,
std::array<NestedMsg<T>, N> msgs
}
} // namespace relax
+
+namespace ffi {
+
+// Opt NestedMsg<T> out of the default TypeTraits; the specialization below is used instead.
+template <typename T>
+inline constexpr bool use_default_type_traits_v<relax::NestedMsg<T>> = false;
+
+/*!
+ * \brief FFI type traits for NestedMsg<T>.
+ * In an Any, a NestedMsg<T> is None (null), an Array whose elements are all
+ * NestedMsg<T> (nested), or a value of type T (leaf). Anything else is rejected.
+ */
+template <typename T>
+struct TypeTraits<relax::NestedMsg<T>> : public TypeTraitsBase {
+ TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg<T>& src,
TVMFFIAny* result) {
+ *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
+ }
+
+ TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg<T> src, TVMFFIAny*
result) {
+ *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
+ }
+
+ TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+ return TypeTraitsBase::GetMismatchTypeInfo(src);
+ }
+
+ // Strict check (no conversion): None, an Array of strict NestedMsg<T>, or a strict T.
+ static bool CheckAnyStrict(const TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFINone) {
+ return true;
+ }
+ if (src->type_index == TypeIndex::kTVMFFIArray) {
+ // nested: every element must itself strictly be a NestedMsg<T>
+ const ffi::ArrayObj* array = reinterpret_cast<const ffi::ArrayObj*>(src->v_obj);
+ for (size_t i = 0; i < array->size(); ++i) {
+ const Any& any_v = (*array)[i];
+ if (!details::AnyUnsafe::CheckAnyStrict<relax::NestedMsg<T>>(any_v)) return false;
+ }
+ return true;
+ }
+ // leaf: must strictly be a T; any other type is not a valid NestedMsg<T>
+ return TypeTraits<T>::CheckAnyStrict(src);
+ }
+
+ TVM_FFI_INLINE static relax::NestedMsg<T> CopyFromAnyViewAfterCheck(const
TVMFFIAny* src) {
+ return relax::NestedMsg<T>(Any(AnyView::CopyFromTVMFFIAny(*src)));
+ }
+
+ TVM_FFI_INLINE static relax::NestedMsg<T> MoveFromAnyAfterCheck(TVMFFIAny*
src) {
+ return
relax::NestedMsg<T>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
+ }
+
+ static std::optional<relax::NestedMsg<T>> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (CheckAnyStrict(src)) {
+ return CopyFromAnyViewAfterCheck(src);
+ }
+ // Slow path: convert with TryCast. None was already accepted by the strict check.
+ // Arrays are always nested messages (see IsNested), so they are tried before T.
+ if (src->type_index == TypeIndex::kTVMFFIArray) {
+ const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
+ Array<relax::NestedMsg<T>> result;
+ result.reserve(n->size());
+ for (size_t i = 0; i < n->size(); i++) {
+ const Any& any_v = (*n)[i];
+ if (auto opt_v = any_v.try_cast<relax::NestedMsg<T>>()) {
+ result.push_back(*std::move(opt_v));
+ } else {
+ return std::nullopt;
+ }
+ }
+ return relax::NestedMsg<T>(std::move(result));
+ }
+ if (auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
+ return relax::NestedMsg<T>(*std::move(opt_value));
+ }
+ return std::nullopt;
+ }
+
+ TVM_FFI_INLINE static std::string TypeStr() {
+ return "NestedMsg<" + details::Type2Str<T>::v() + ">";
+ }
+};
+} // namespace ffi
} // namespace tvm
#endif // TVM_RELAX_NESTED_MSG_H_
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
index d552dae8f7..644a80664f 100644
--- a/tests/cpp/nested_msg_test.cc
+++ b/tests/cpp/nested_msg_test.cc
@@ -53,7 +53,7 @@ TEST(NestedMsg, Basic) {
EXPECT_ANY_THROW(msg.LeafValue());
auto arr = msg.NestedArray();
- EXPECT_TRUE(arr[0].same_as(x));
+ EXPECT_TRUE(arr[0].LeafValue().same_as(x)); // arr[0] is a NestedMsg (not an ObjectRef); compare its leaf
EXPECT_TRUE(arr[1] == nullptr);
EXPECT_TRUE(arr[1].IsNull());
@@ -72,13 +72,27 @@ TEST(NestedMsg, Basic) {
EXPECT_TRUE(a0.IsNested());
auto t0 = a0.NestedArray()[1];
EXPECT_TRUE(t0.IsNested());
- EXPECT_TRUE(t0.NestedArray()[2].same_as(y));
+ EXPECT_TRUE(t0.NestedArray()[2].LeafValue().same_as(y));
// assign leaf
a0 = x;
EXPECT_TRUE(a0.IsLeaf());
- EXPECT_TRUE(a0.same_as(x));
+ EXPECT_TRUE(a0.LeafValue().same_as(x));
+}
+
+// NestedMsg with a POD leaf type (int64_t) must round-trip through ffi::Any.
+TEST(NestedMsg, IntAndAny) {
+ NestedMsg<int64_t> msg({1, std::nullopt, 2});
+ Any any_msg = msg;
+ NestedMsg<int64_t> msg2 = any_msg.cast<NestedMsg<int64_t>>();
+
+ // Stop early if nesting was lost: the NestedArray() accesses below require a nested message.
+ ASSERT_TRUE(msg2.IsNested());
+ auto arr = msg2.NestedArray();
+ EXPECT_EQ(arr[0].LeafValue(), 1);
+ EXPECT_TRUE(arr[1].IsNull());
+ EXPECT_EQ(arr[2].LeafValue(), 2);
}
TEST(NestedMsg, ForEachLeaf) {
@@ -174,13 +185,13 @@ TEST(NestedMsg, MapAndDecompose) {
DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg<Integer> msg) {
if (value.same_as(x)) {
- EXPECT_TRUE(msg.same_as(c0));
+ EXPECT_TRUE(msg.LeafValue().same_as(c0)); // NestedMsg is no longer an ObjectRef; LeafValue() ICHECKs IsLeaf()
++x_count;
} else if (value.same_as(y)) {
- EXPECT_TRUE(msg.same_as(c1));
+ EXPECT_TRUE(msg.LeafValue().same_as(c1));
++y_count;
} else {
- EXPECT_TRUE(msg.same_as(c2));
+ EXPECT_TRUE(msg.LeafValue().same_as(c2));
++z_count;
}
});