This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 92ea12d7e2 [FFI] Variant specialize for all ObjectRef (#17943)
92ea12d7e2 is described below

commit 92ea12d7e2e9b8e585c94708d5b91b968e728eaf
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat May 10 13:03:54 2025 -0400

    [FFI] Variant specialize for all ObjectRef (#17943)
---
 ffi/include/tvm/ffi/any.h                         |  9 +++
 ffi/include/tvm/ffi/base_details.h                |  6 +-
 ffi/include/tvm/ffi/container/container_details.h |  8 ++
 ffi/include/tvm/ffi/container/variant.h           | 98 ++++++++++++++++++-----
 ffi/tests/cpp/test_any.cc                         |  1 +
 ffi/tests/cpp/test_map.cc                         |  2 +-
 ffi/tests/cpp/test_variant.cc                     | 27 +++++++
 7 files changed, 127 insertions(+), 24 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 8274904037..4ec72d6846 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -425,6 +425,15 @@ struct AnyUnsafe : public ObjectUnsafe {
     }
   }
 
+  template <typename T>
+  static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) {
+    if constexpr (!std::is_same_v<T, Any>) {
+      return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&(ref.data_));
+    } else {
+      return std::move(ref);
+    }
+  }
+
   static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) {
     return reinterpret_cast<Object*>(ref.data_.v_obj);
   }
diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h
index 18cc3ecb72..eeb892eff6 100644
--- a/ffi/include/tvm/ffi/base_details.h
+++ b/ffi/include/tvm/ffi/base_details.h
@@ -123,9 +123,9 @@
  * This macro is used to clear the padding parts for hash and equality check
  * in 32bit platform.
  */
-#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result)                \
-  if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \
-    result->v_int64 = 0;                                            \
+#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result)                    \
+  if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \
+    (result)->v_int64 = 0;                                              \
   }
 
 namespace tvm {
diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h
index 51e130f373..cfc5590f54 100644
--- a/ffi/include/tvm/ffi/container/container_details.h
+++ b/ffi/include/tvm/ffi/container/container_details.h
@@ -284,6 +284,14 @@ inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T
 template <typename... T>
 inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...);
 
+/*!
+ * \brief Check if all T are compatible with Any.
+ *
+ * \tparam T The type to check.
+ * \return True if T is compatible with Any, false otherwise.
+ */
+template <typename... T>
+inline constexpr bool all_object_ref_v = (std::is_base_of_v<ObjectRef, T> && ...);
 /**
  * \brief Check if Any storage of Derived can always be directly used as Base.
  *
diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h
index f134be8331..c2b0688900 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -34,15 +34,73 @@
 
 namespace tvm {
 namespace ffi {
+namespace details {
+/*!
+ * \brief Base class for Variant.
+ *
+ * \tparam all_storage_object Whether all types are derived from ObjectRef.
+ */
+template <bool all_storage_object = false>
+class VariantBase {
+ public:
+  TVM_FFI_INLINE bool same_as(const VariantBase<all_storage_object>& other) const {
+    return data_.same_as(other.data_);
+  }
+
+ protected:
+  template <typename T>
+  explicit VariantBase(T other) : data_(std::move(other)) {}
+
+  TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); }
+
+  TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); }
+
+  TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); }
+
+  Any data_;
+};
+
+// Specialization for all object ref case, backed by ObjectRef.
+template <>
+class VariantBase<true> : public ObjectRef {
+ protected:
+  template <typename T>
+  explicit VariantBase(const T& other) : ObjectRef(other) {}
+  template <typename T>
+  explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
+  explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+  explicit VariantBase(Any other)
+      : ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck<ObjectRef>(std::move(other))) {}
+
+  TVM_FFI_INLINE void SetData(ObjectPtr<Object> other) { data_ = std::move(other); }
+
+  TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); }
+
+  TVM_FFI_INLINE AnyView ToAnyView() const {
+    TVMFFIAny any_data;
+    if (data_ == nullptr) {
+      any_data.type_index = TypeIndex::kTVMFFINone;
+      any_data.v_int64 = 0;
+    } else {
+      TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+      any_data.type_index = data_->type_index();
+      any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
+    }
+    return AnyView::CopyFromTVMFFIAny(any_data);
+  }
+};
+}  // namespace details
 
 /*!
  * \brief A typed variant container.
  *
- * A Variant is backed by Any container, with strong checks during construction.
+ * When all values are ObjectRef, Variant is backed by ObjectRef,
+ * otherwise it is backed by Any.
  */
 template <typename... V>
-class Variant {
+class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
  public:
+  using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
   static_assert(details::all_storage_enabled_v<V...>,
                 "All types used in Variant<...> must be compatible with Any");
   /*
@@ -54,31 +112,30 @@ class Variant {
   template <typename T>
   using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;
 
-  Variant(const Variant<V...>& other) : data_(other.data_) {}
-  Variant(Variant<V...>&& other) : data_(std::move(other.data_)) {}
+  Variant(const Variant<V...>& other) : TParent(other.data_) {}
+  Variant(Variant<V...>&& other) : TParent(std::move(other.data_)) {}
 
   TVM_FFI_INLINE Variant& operator=(const Variant<V...>& other) {
-    data_ = other.data_;
+    this->SetData(other.data_);
     return *this;
   }
 
   TVM_FFI_INLINE Variant& operator=(Variant<V...>&& other) {
-    data_ = std::move(other.data_);
+    this->SetData(std::move(other.data_));
     return *this;
   }
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
-  Variant(T other) : data_(std::move(other)) {}  // NOLINT(*)
+  Variant(T other) : TParent(std::move(other)) {}  // NOLINT(*)
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE Variant& operator=(T other) {
-    data_ = std::move(other);
-    return *this;
+    return operator=(Variant(std::move(other)));
   }
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE std::optional<T> as() const {
-    return data_.as<T>();
+    return this->TParent::ToAnyView().template as<T>();
   }
 
   /*
@@ -89,29 +146,27 @@ class Variant {
    */
  template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
   TVM_FFI_INLINE const T* as() const {
-    return data_.as<const T*>().value_or(nullptr);
+    return this->TParent::ToAnyView().template as<const T*>().value_or(nullptr);
   }
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE T get() const& {
-    return data_.template cast<T>();
+    return this->TParent::ToAnyView().template cast<T>();
   }
 
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE T get() && {
-    return std::move(data_).template cast<T>();
+    return std::move(*this).TParent::MoveToAny().template cast<T>();
   }
 
-  TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); }
+  TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); }
 
  private:
   friend struct TypeTraits<Variant<V...>>;
   friend struct ObjectPtrHash;
   friend struct ObjectPtrEqual;
   // constructor from any
-  explicit Variant(Any data) : data_(std::move(data)) {}
-  // internal data is backed by Any
-  Any data_;
+  explicit Variant(Any data) : TParent(std::move(data)) {}
   /*!
    * \brief Get the object pointer from the variant
    * \note This function is only available if all types used in Variant<...> are derived from
@@ -122,8 +177,11 @@ class Variant {
     static_assert(all_object_v,
                   "All types used in Variant<...> must be derived from ObjectRef "
                   "to enable ObjectPtrHash/ObjectPtrEqual");
-    return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_);
+    return this->data_.get();
   }
+  // rexpose to friend class
+  using TParent::MoveToAny;
+  using TParent::ToAnyView;
 };
 
 template <typename... V>
@@ -132,11 +190,11 @@ inline constexpr bool use_default_type_traits_v<Variant<V...>> = false;
 template <typename... V>
 struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
   static TVM_FFI_INLINE void CopyToAnyView(const Variant<V...>& src, TVMFFIAny* result) {
-    *result = AnyView(src.data_).CopyToTVMFFIAny();
+    *result = src.ToAnyView().CopyToTVMFFIAny();
   }
 
   static TVM_FFI_INLINE void MoveToAny(Variant<V...> src, TVMFFIAny* result) {
-    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
+    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny());
   }
 
   static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index 816ae28e0e..d84cc64ae4 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -337,6 +337,7 @@ TEST(Any, ObjectMove) {
   auto v0 = std::move(any1).cast<TPrimExpr>();
   EXPECT_EQ(v0->value, 3.14);
   EXPECT_EQ(v0.use_count(), 1);
+  EXPECT_TRUE(any1 == nullptr);
 }
 
 }  // namespace
diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc
index bd0b58b7c4..b7c977fd34 100644
--- a/ffi/tests/cpp/test_map.cc
+++ b/ffi/tests/cpp/test_map.cc
@@ -243,7 +243,7 @@ TEST(Map, AnyConvertCheck) {
       ::tvm::ffi::Error);
 }
 
-TEST(Map, ffi::FunctionGetItem) {
+TEST(Map, FunctionGetItem) {
   Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); },
                                    "map_get_item");
   Map<String, int64_t> map{{"x", 1}, {"y", 2}};
diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc
index ee49ac75d1..17a1129087 100644
--- a/ffi/tests/cpp/test_variant.cc
+++ b/ffi/tests/cpp/test_variant.cc
@@ -134,4 +134,31 @@ TEST(Variant, Upcast) {
   EXPECT_EQ(a1[0].get<int>(), 1);
 }
 
+TEST(Variant, AllObjectRef) {
+  Variant<TInt, Array<TInt>> v0 = TInt(1);
+  EXPECT_EQ(v0.get<TInt>()->value, 1);
+  static_assert(std::is_base_of_v<ObjectRef, decltype(v0)>);
+  Any any0 = v0;
+  EXPECT_EQ(any0.cast<TInt>()->value, 1);
+  auto v2 = any0.cast<Variant<TInt, Array<TInt>>>();
+  EXPECT_TRUE(v0.same_as(v2));
+  // assignment operator
+  v0 = Array<TInt>({TInt(2), TInt(3)});
+  EXPECT_EQ(v0.get<Array<TInt>>().size(), 2);
+  EXPECT_EQ(v0.get<Array<TInt>>()[0]->value, 2);
+  EXPECT_EQ(v0.get<Array<TInt>>()[1]->value, 3);
+  EXPECT_EQ(sizeof(v0), sizeof(ObjectRef));
+}
+
+TEST(Variant, PODSameAs) {
+  Variant<String, int> v0 = 1;
+  Variant<String, int> v1 = 1;
+  EXPECT_TRUE(v0.same_as(v1));
+  String s = String("hello");
+  v0 = s;
+  v1 = s;
+  EXPECT_TRUE(v0.same_as(v1));
+  v1 = String("hello");
+  EXPECT_TRUE(!v0.same_as(v1));
+}
 }  // namespace

Reply via email to