This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 92ea12d7e2 [FFI] Variant specialize for all ObjectRef (#17943)
92ea12d7e2 is described below
commit 92ea12d7e2e9b8e585c94708d5b91b968e728eaf
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat May 10 13:03:54 2025 -0400
[FFI] Variant specialize for all ObjectRef (#17943)
---
ffi/include/tvm/ffi/any.h | 9 +++
ffi/include/tvm/ffi/base_details.h | 6 +-
ffi/include/tvm/ffi/container/container_details.h | 8 ++
ffi/include/tvm/ffi/container/variant.h | 98 ++++++++++++++++++-----
ffi/tests/cpp/test_any.cc | 1 +
ffi/tests/cpp/test_map.cc | 2 +-
ffi/tests/cpp/test_variant.cc | 27 +++++++
7 files changed, 127 insertions(+), 24 deletions(-)
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 8274904037..4ec72d6846 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -425,6 +425,15 @@ struct AnyUnsafe : public ObjectUnsafe {
}
}
+ template <typename T>
+ static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) {
+ if constexpr (!std::is_same_v<T, Any>) {
+ return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&(ref.data_));
+ } else {
+ return std::move(ref);
+ }
+ }
+
static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) {
return reinterpret_cast<Object*>(ref.data_.v_obj);
}
diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h
index 18cc3ecb72..eeb892eff6 100644
--- a/ffi/include/tvm/ffi/base_details.h
+++ b/ffi/include/tvm/ffi/base_details.h
@@ -123,9 +123,9 @@
* This macro is used to clear the padding parts for hash and equality check
* in 32bit platform.
*/
-#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \
- if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \
- result->v_int64 = 0; \
+#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \
+ if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \
+ (result)->v_int64 = 0; \
}
namespace tvm {
diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h
index 51e130f373..cfc5590f54 100644
--- a/ffi/include/tvm/ffi/container/container_details.h
+++ b/ffi/include/tvm/ffi/container/container_details.h
@@ -284,6 +284,14 @@ inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T
template <typename... T>
inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...);
+/*!
+ * \brief Check if all T are derived from ObjectRef.
+ *
+ * \tparam T The types to check.
+ * \return True if every T is derived from ObjectRef, false otherwise.
+ */
+template <typename... T>
+inline constexpr bool all_object_ref_v = (std::is_base_of_v<ObjectRef, T> && ...);
/**
* \brief Check if Any storage of Derived can always be directly used as Base.
*
diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h
index f134be8331..c2b0688900 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -34,15 +34,73 @@
namespace tvm {
namespace ffi {
+namespace details {
+/*!
+ * \brief Base class for Variant.
+ *
+ * \tparam all_storage_object Whether all types are derived from ObjectRef.
+ */
+template <bool all_storage_object = false>
+class VariantBase {
+ public:
+ TVM_FFI_INLINE bool same_as(const VariantBase<all_storage_object>& other) const {
+ return data_.same_as(other.data_);
+ }
+
+ protected:
+ template <typename T>
+ explicit VariantBase(T other) : data_(std::move(other)) {}
+
+ TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); }
+
+ TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); }
+
+ TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); }
+
+ Any data_;
+};
+
+// Specialization for all object ref case, backed by ObjectRef.
+template <>
+class VariantBase<true> : public ObjectRef {
+ protected:
+ template <typename T>
+ explicit VariantBase(const T& other) : ObjectRef(other) {}
+ template <typename T>
+ explicit VariantBase(T&& other) : ObjectRef(std::forward<T>(other)) {} // forwarding ref: never move from an lvalue
+ explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(std::move(ptr)) {} // sink param: avoid extra refcount inc/dec
+ explicit VariantBase(Any other)
+ : ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck<ObjectRef>(std::move(other))) {}
+
+ TVM_FFI_INLINE void SetData(ObjectPtr<Object> other) { data_ = std::move(other); }
+
+ TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); }
+
+ TVM_FFI_INLINE AnyView ToAnyView() const {
+ TVMFFIAny any_data;
+ if (data_ == nullptr) {
+ any_data.type_index = TypeIndex::kTVMFFINone;
+ any_data.v_int64 = 0;
+ } else {
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+ any_data.type_index = data_->type_index();
+ any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
+ }
+ return AnyView::CopyFromTVMFFIAny(any_data);
+ }
+};
+} // namespace details
/*!
* \brief A typed variant container.
*
- * A Variant is backed by Any container, with strong checks during construction.
+ * When all values are ObjectRef, Variant is backed by ObjectRef,
+ * otherwise it is backed by Any.
*/
template <typename... V>
-class Variant {
+class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
public:
+ using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
static_assert(details::all_storage_enabled_v<V...>,
"All types used in Variant<...> must be compatible with Any");
/*
@@ -54,31 +112,30 @@ class Variant {
template <typename T>
using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;
- Variant(const Variant<V...>& other) : data_(other.data_) {}
- Variant(Variant<V...>&& other) : data_(std::move(other.data_)) {}
+ Variant(const Variant<V...>& other) : TParent(other.data_) {}
+ Variant(Variant<V...>&& other) : TParent(std::move(other.data_)) {}
TVM_FFI_INLINE Variant& operator=(const Variant<V...>& other) {
- data_ = other.data_;
+ this->SetData(other.data_);
return *this;
}
TVM_FFI_INLINE Variant& operator=(Variant<V...>&& other) {
- data_ = std::move(other.data_);
+ this->SetData(std::move(other.data_));
return *this;
}
template <typename T, typename = enable_if_variant_contains_t<T>>
- Variant(T other) : data_(std::move(other)) {} // NOLINT(*)
+ Variant(T other) : TParent(std::move(other)) {} // NOLINT(*)
template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE Variant& operator=(T other) {
- data_ = std::move(other);
- return *this;
+ return operator=(Variant(std::move(other)));
}
template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE std::optional<T> as() const {
- return data_.as<T>();
+ return this->TParent::ToAnyView().template as<T>();
}
/*
@@ -89,29 +146,27 @@ class Variant {
*/
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
TVM_FFI_INLINE const T* as() const {
- return data_.as<const T*>().value_or(nullptr);
+ return this->TParent::ToAnyView().template as<const T*>().value_or(nullptr);
}
template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE T get() const& {
- return data_.template cast<T>();
+ return this->TParent::ToAnyView().template cast<T>();
}
template <typename T, typename = enable_if_variant_contains_t<T>>
TVM_FFI_INLINE T get() && {
- return std::move(data_).template cast<T>();
+ return std::move(*this).TParent::MoveToAny().template cast<T>();
}
- TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); }
+ TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); }
private:
friend struct TypeTraits<Variant<V...>>;
friend struct ObjectPtrHash;
friend struct ObjectPtrEqual;
// constructor from any
- explicit Variant(Any data) : data_(std::move(data)) {}
- // internal data is backed by Any
- Any data_;
+ explicit Variant(Any data) : TParent(std::move(data)) {}
/*!
* \brief Get the object pointer from the variant
* \note This function is only available if all types used in Variant<...> are derived from
@@ -122,8 +177,11 @@ class Variant {
static_assert(all_object_v,
"All types used in Variant<...> must be derived from ObjectRef "
"to enable ObjectPtrHash/ObjectPtrEqual");
- return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_);
+ return this->data_.get();
}
+ // re-expose protected base members to friends (e.g. TypeTraits<Variant<V...>>)
+ using TParent::MoveToAny;
+ using TParent::ToAnyView;
};
template <typename... V>
@@ -132,11 +190,11 @@ inline constexpr bool use_default_type_traits_v<Variant<V...>> = false;
template <typename... V>
struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
static TVM_FFI_INLINE void CopyToAnyView(const Variant<V...>& src, TVMFFIAny* result) {
- *result = AnyView(src.data_).CopyToTVMFFIAny();
+ *result = src.ToAnyView().CopyToTVMFFIAny();
}
static TVM_FFI_INLINE void MoveToAny(Variant<V...> src, TVMFFIAny* result) {
- *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
+ *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny());
}
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index 816ae28e0e..d84cc64ae4 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -337,6 +337,7 @@ TEST(Any, ObjectMove) {
auto v0 = std::move(any1).cast<TPrimExpr>();
EXPECT_EQ(v0->value, 3.14);
EXPECT_EQ(v0.use_count(), 1);
+ EXPECT_TRUE(any1 == nullptr);
}
} // namespace
diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc
index bd0b58b7c4..b7c977fd34 100644
--- a/ffi/tests/cpp/test_map.cc
+++ b/ffi/tests/cpp/test_map.cc
@@ -243,7 +243,7 @@ TEST(Map, AnyConvertCheck) {
::tvm::ffi::Error);
}
-TEST(Map, ffi::FunctionGetItem) {
+TEST(Map, FunctionGetItem) {
Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any {
return n->at(k); },
"map_get_item");
Map<String, int64_t> map{{"x", 1}, {"y", 2}};
diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc
index ee49ac75d1..17a1129087 100644
--- a/ffi/tests/cpp/test_variant.cc
+++ b/ffi/tests/cpp/test_variant.cc
@@ -134,4 +134,31 @@ TEST(Variant, Upcast) {
EXPECT_EQ(a1[0].get<int>(), 1);
}
+TEST(Variant, AllObjectRef) {
+ Variant<TInt, Array<TInt>> v0 = TInt(1);
+ EXPECT_EQ(v0.get<TInt>()->value, 1);
+ static_assert(std::is_base_of_v<ObjectRef, decltype(v0)>);
+ Any any0 = v0;
+ EXPECT_EQ(any0.cast<TInt>()->value, 1);
+ auto v2 = any0.cast<Variant<TInt, Array<TInt>>>();
+ EXPECT_TRUE(v0.same_as(v2));
+ // assignment operator
+ v0 = Array<TInt>({TInt(2), TInt(3)});
+ EXPECT_EQ(v0.get<Array<TInt>>().size(), 2);
+ EXPECT_EQ(v0.get<Array<TInt>>()[0]->value, 2);
+ EXPECT_EQ(v0.get<Array<TInt>>()[1]->value, 3);
+ EXPECT_EQ(sizeof(v0), sizeof(ObjectRef));
+}
+
+TEST(Variant, PODSameAs) {
+ Variant<String, int> v0 = 1;
+ Variant<String, int> v1 = 1;
+ EXPECT_TRUE(v0.same_as(v1));
+ String s = String("hello");
+ v0 = s;
+ v1 = s;
+ EXPECT_TRUE(v0.same_as(v1));
+ // a freshly constructed String with equal content is a different object
+ v1 = String("hello");
+ EXPECT_FALSE(v0.same_as(v1));
+}
} // namespace