This is an automated email from the ASF dual-hosted git repository.

MasterJH5574 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 84ee1a07 [CORE] Add type subsumption and object-ref containment relations (#647)
84ee1a07 is described below

commit 84ee1a07f85645886d2eebc9ddd0ddf9488fd38a
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 28 18:51:06 2026 -0400

    [CORE] Add type subsumption and object-ref containment relations (#647)
---
 include/tvm/ffi/cast.h              | 31 +++++++++++++++++--------------
 include/tvm/ffi/container/array.h   | 15 ++++++++-------
 include/tvm/ffi/container/dict.h    | 21 +++++++++------------
 include/tvm/ffi/container/list.h    | 15 ++++++++-------
 include/tvm/ffi/container/map.h     | 21 +++++++++------------
 include/tvm/ffi/container/tuple.h   | 17 ++++++++++-------
 include/tvm/ffi/container/variant.h |  9 +++++----
 include/tvm/ffi/object.h            | 19 +++++++++++++++++++
 include/tvm/ffi/type_traits.h       | 29 ++++++++++++++++-------------
 tests/cpp/test_array.cc             |  5 +++--
 tests/cpp/test_map.cc               |  2 +-
 tests/cpp/test_object.cc            | 36 ++++++++++++++++++++++++++++++++++++
 tests/cpp/test_tuple.cc             |  8 +++++---
 tests/cpp/test_variant.cc           |  2 +-
 14 files changed, 147 insertions(+), 83 deletions(-)

diff --git a/include/tvm/ffi/cast.h b/include/tvm/ffi/cast.h
index 655f1891..44228fe2 100644
--- a/include/tvm/ffi/cast.h
+++ b/include/tvm/ffi/cast.h
@@ -27,6 +27,8 @@
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/optional.h>
 
+#include <type_traits>
+
 namespace tvm {
 namespace ffi {
 
@@ -44,23 +46,24 @@ namespace ffi {
  */
 template <typename RefType, typename ObjectType>
 inline RefType GetRef(const ObjectType* ptr) {
-  using ContainerType = typename RefType::ContainerType;
-  static_assert(RefType::_type_container_is_exact,
-                "GetRef requires RefType::ContainerType to exactly describe all objects the ref "
-                "can hold; use ObjectRef::as<RefType>() for richer TypeTraits-based refs");
-  static_assert(std::is_base_of_v<ContainerType, ObjectType>,
-                "Can only cast to the ref of same container type");
-
-  if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
-    if (ptr == nullptr) {
-      return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(nullptr);
+  if constexpr (object_ref_contains_v<RefType, ObjectType>) {
+    if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
+      if (ptr == nullptr) {
+        return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(nullptr);
+      }
+    } else {
+      TVM_FFI_ICHECK_NOTNULL(ptr);
     }
+    return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(
+        details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
+            const_cast<Object*>(static_cast<const Object*>(ptr))));
   } else {
-    TVM_FFI_ICHECK_NOTNULL(ptr);
+    static_assert(object_ref_contains_v<RefType, ObjectType>,
+                  "GetRef requires RefType to contain every ObjectType instance; specialize "
+                  "object_ref_contains_v for statically safe typed refs or use "
+                  "ObjectRef::as<RefType>() for runtime-dependent checks");
+    TVM_FFI_UNREACHABLE();
   }
-  return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(
-      details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
-          const_cast<Object*>(static_cast<const Object*>(ptr))));
 }
 
 /*!
diff --git a/include/tvm/ffi/container/array.h b/include/tvm/ffi/container/array.h
index 82380c0a..fd2c8233 100644
--- a/include/tvm/ffi/container/array.h
+++ b/include/tvm/ffi/container/array.h
@@ -241,7 +241,7 @@ class Array : public ObjectRef {
    * \param other The other array
    * \tparam U The value type of the other array
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   Array(Array<U>&& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(std::move(other.data_)) {}
   /*!
@@ -249,7 +249,7 @@ class Array : public ObjectRef {
    * \param other The other array
    * \tparam U The value type of the other array
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   Array(const Array<U>& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(other.data_) {}
 
@@ -274,7 +274,7 @@ class Array : public ObjectRef {
    * \param other The other array
    * \tparam U The value type of the other array
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   TVM_FFI_INLINE Array<T>& operator=(Array<U>&& other) {
     data_ = std::move(other.data_);
     return *this;
@@ -284,7 +284,7 @@ class Array : public ObjectRef {
    * \param other The other array
    * \tparam U The value type of the other array
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   TVM_FFI_INLINE Array<T>& operator=(const Array<U>& other) {
     data_ = other.data_;
     return *this;
@@ -895,10 +895,11 @@ struct TypeTraits<Array<T>> : public SeqTypeTraitsBase<TypeTraits<Array<T>>, Arr
   }
 };
 
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target Array storage subsumes source Array storage element-wise. */
 template <typename T, typename U>
-inline constexpr bool type_contains_v<Array<T>, Array<U>> = type_contains_v<T, U>;
-}  // namespace details
+inline constexpr bool type_subsumes_v<Array<T>, Array<U>> = type_subsumes_v<T, U>;
+/// \endcond
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/include/tvm/ffi/container/dict.h b/include/tvm/ffi/container/dict.h
index 1887154f..e75ea2cd 100644
--- a/include/tvm/ffi/container/dict.h
+++ b/include/tvm/ffi/container/dict.h
@@ -102,8 +102,7 @@ class Dict : public ObjectRef {
    * \tparam VU The mapped type of the other dict
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Dict(Dict<KU, VU>&& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(std::move(other.data_)) {}
 
@@ -114,8 +113,7 @@ class Dict : public ObjectRef {
    * \tparam VU The mapped type of the other dict
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   // NOLINTNEXTLINE(google-explicit-constructor)
   Dict(const Dict<KU, VU>& other) : ObjectRef(other.data_) {}
 
@@ -144,8 +142,7 @@ class Dict : public ObjectRef {
    * \tparam VU The mapped type of the other dict
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Dict<K, V>& operator=(Dict<KU, VU>&& other) {
     data_ = std::move(other.data_);
     return *this;
@@ -158,8 +155,7 @@ class Dict : public ObjectRef {
    * \tparam VU The mapped type of the other dict
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Dict<K, V>& operator=(const Dict<KU, VU>& other) {
     data_ = other.data_;
     return *this;
@@ -365,11 +361,12 @@ struct TypeTraits<Dict<K, V>> : public MapTypeTraitsBase<TypeTraits<Dict<K, V>>,
   }
 };
 
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target Dict storage subsumes source Dict storage key- and value-wise. */
 template <typename K, typename V, typename KU, typename VU>
-inline constexpr bool type_contains_v<Dict<K, V>, Dict<KU, VU>> =
-    type_contains_v<K, KU> && type_contains_v<V, VU>;
-}  // namespace details
+inline constexpr bool type_subsumes_v<Dict<K, V>, Dict<KU, VU>> =
+    type_subsumes_v<K, KU> && type_subsumes_v<V, VU>;
+/// \endcond
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/include/tvm/ffi/container/list.h b/include/tvm/ffi/container/list.h
index f2292ce9..be2bda81 100644
--- a/include/tvm/ffi/container/list.h
+++ b/include/tvm/ffi/container/list.h
@@ -159,7 +159,7 @@ class List : public ObjectRef {
    * \brief Constructor from another list
    * \tparam U The value type of the other list
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   List(List<U>&& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(std::move(other.data_)) {}
 
@@ -167,7 +167,7 @@ class List : public ObjectRef {
    * \brief Constructor from another list
    * \tparam U The value type of the other list
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   List(const List<U>& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(other.data_) {}
 
@@ -194,7 +194,7 @@ class List : public ObjectRef {
    * \param other The other list.
    * \tparam U The value type of the other list.
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   TVM_FFI_INLINE List<T>& operator=(List<U>&& other) {
     data_ = std::move(other.data_);
     return *this;
@@ -205,7 +205,7 @@ class List : public ObjectRef {
    * \param other The other list.
    * \tparam U The value type of the other list.
    */
-  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+  template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
   TVM_FFI_INLINE List<T>& operator=(const List<U>& other) {
     data_ = other.data_;
     return *this;
@@ -518,10 +518,11 @@ struct TypeTraits<List<T>> : public SeqTypeTraitsBase<TypeTraits<List<T>>, List<
   }
 };
 
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target List storage subsumes source List storage element-wise. */
 template <typename T, typename U>
-inline constexpr bool type_contains_v<List<T>, List<U>> = type_contains_v<T, U>;
-}  // namespace details
+inline constexpr bool type_subsumes_v<List<T>, List<U>> = type_subsumes_v<T, U>;
+/// \endcond
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index a9003e20..10013be9 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -98,8 +98,7 @@ class Map : public ObjectRef {
    * \tparam VU The mapped type of the other map
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Map(Map<KU, VU>&& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(std::move(other.data_)) {}
 
@@ -110,8 +109,7 @@ class Map : public ObjectRef {
    * \tparam VU The mapped type of the other map
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Map(const Map<KU, VU>& other) : ObjectRef(other.data_) {}  // NOLINT(google-explicit-constructor)
 
   /*!
@@ -139,8 +137,7 @@ class Map : public ObjectRef {
    * \tparam VU The mapped type of the other map
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Map<K, V>& operator=(Map<KU, VU>&& other) {
     data_ = std::move(other.data_);
     return *this;
@@ -153,8 +150,7 @@ class Map : public ObjectRef {
    * \tparam VU The mapped type of the other map
    */
   template <typename KU, typename VU,
-            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
-                                        details::type_contains_v<V, VU>>>
+            typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
   Map<K, V>& operator=(const Map<KU, VU>& other) {
     data_ = other.data_;
     return *this;
@@ -380,11 +376,12 @@ struct TypeTraits<Map<K, V>> : public MapTypeTraitsBase<TypeTraits<Map<K, V>>, M
   }
 };
 
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target Map storage subsumes source Map storage key- and value-wise. */
 template <typename K, typename V, typename KU, typename VU>
-inline constexpr bool type_contains_v<Map<K, V>, Map<KU, VU>> =
-    type_contains_v<K, KU> && type_contains_v<V, VU>;
-}  // namespace details
+inline constexpr bool type_subsumes_v<Map<K, V>, Map<KU, VU>> =
+    type_subsumes_v<K, KU> && type_subsumes_v<V, VU>;
+/// \endcond
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/include/tvm/ffi/container/tuple.h b/include/tvm/ffi/container/tuple.h
index 090f99f0..2d73f836 100644
--- a/include/tvm/ffi/container/tuple.h
+++ b/include/tvm/ffi/container/tuple.h
@@ -64,7 +64,7 @@ class Tuple : public ObjectRef {
    * \tparam The enable_if_t type
    */
   template <typename... UTypes,
-            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
+            typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...), int>>
   Tuple(const Tuple<UTypes...>& other) : ObjectRef(other) {}  // NOLINT(google-explicit-constructor)
 
   /*!
@@ -74,7 +74,7 @@ class Tuple : public ObjectRef {
    * \tparam The enable_if_t type
    */
   template <typename... UTypes,
-            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
+            typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...), int>>
   Tuple(Tuple<UTypes...>&& other)  // NOLINT(google-explicit-constructor)
       : ObjectRef(std::move(other)) {}
 
@@ -116,7 +116,7 @@ class Tuple : public ObjectRef {
    * \tparam The enable_if_t type
    */
   template <typename... UTypes,
-            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...)>>
+            typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...)>>
   TVM_FFI_INLINE Tuple& operator=(const Tuple<UTypes...>& other) {
     data_ = other.data_;
     return *this;
@@ -129,7 +129,7 @@ class Tuple : public ObjectRef {
    * \tparam The enable_if_t type
    */
   template <typename... UTypes,
-            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...)>>
+            typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...)>>
   TVM_FFI_INLINE Tuple& operator=(Tuple<UTypes...>&& other) {
     data_ = std::move(other.data_);
     return *this;
@@ -338,10 +338,13 @@ struct TypeTraits<Tuple<Types...>> : public ObjectRefTypeTraitsBase<Tuple<Types.
   }
 };
 
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target and source Tuple storage have equal arity and subsume element-wise. */
 template <typename... T, typename... U>
-inline constexpr bool type_contains_v<Tuple<T...>, Tuple<U...>> = (type_contains_v<T, U> && ...);
-}  // namespace details
+inline constexpr bool
+    type_subsumes_v<Tuple<T...>, Tuple<U...>, std::enable_if_t<sizeof...(T) == sizeof...(U)>> =
+        (type_subsumes_v<T, U> && ...);
+/// \endcond
 
 /// \cond Doxygen_Suppress
 
diff --git a/include/tvm/ffi/container/variant.h b/include/tvm/ffi/container/variant.h
index 9082e0b2..8ad8097f 100644
--- a/include/tvm/ffi/container/variant.h
+++ b/include/tvm/ffi/container/variant.h
@@ -114,7 +114,7 @@ class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
    * \brief Helper utility to check if the type can be contained in the variant
    */
   template <typename T>
-  static constexpr bool variant_contains_v = (details::type_contains_v<V, T> || ...);
+  static constexpr bool variant_contains_v = (type_subsumes_v<V, T> || ...);
   /* \brief Helper utility for SFINAE if the type is part of the variant */
   template <typename T>
   using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;
@@ -305,10 +305,11 @@ TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant<V...>& a,
   return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual();
 }
 
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether Variant storage subsumes a source type through one alternative. */
 template <typename... V, typename T>
-inline constexpr bool type_contains_v<Variant<V...>, T> = (type_contains_v<V, T> || ...);
-}  // namespace details
+inline constexpr bool type_subsumes_v<Variant<V...>, T> = (type_subsumes_v<V, T> || ...);
+/// \endcond
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_CONTAINER_VARIANT_H_
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 2e8a8f40..4868ab14 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -881,6 +881,25 @@ class ObjectRef {
   friend struct tvm::ffi::details::ObjectUnsafe;
 };
 
+/*!
+ * \brief Whether RefType contains every ObjectType instance.
+ *
+ * The containing reference type is first and the contained object type is
+ * second. The template is enabled only when RefType is an object reference and
+ * ObjectType is an object node. The default is true exactly when RefType has an
+ * exact container type and ObjectType derives from it. Direct specializations
+ * can provide the proof for non-exact reference types.
+ *
+ * \tparam RefType The object reference type.
+ * \tparam ObjectType The object node type.
+ */
+template <typename RefType, typename ObjectType,
+          typename = std::enable_if_t<std::is_base_of_v<ObjectRef, RefType> &&
+                                      std::is_base_of_v<Object, ObjectType>>>
+inline constexpr bool object_ref_contains_v =
+    RefType::_type_container_is_exact &&
+    std::is_base_of_v<typename RefType::ContainerType, ObjectType>;
+
 // forward delcare variant
 template <typename... V>
 class Variant;
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index 2810e3fe..1b75392a 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -118,22 +118,25 @@ inline std::string TypeIndexToTypeKey(int32_t type_index) {
   return std::string(type_info->type_key.data, type_info->type_key.size);
 }
 
-namespace details {
 /*!
- * \brief Check whether `Derived` can reuse `Base` storage directly.
+ * \brief Whether TargetType subsumes SourceType for direct storage reuse.
  *
- * \tparam Base The base type.
- * \tparam Derived The derived type.
- * \return True if Derived's storage can be used as Base's storage, false otherwise.
+ * The target type is first and the source type is second. The result is true
+ * exactly when every SourceType value can reuse TargetType storage without
+ * conversion.
+ *
+ * \tparam TargetType The target storage type.
+ * \tparam SourceType The source value type.
  */
-template <typename Base, typename Derived>
-inline constexpr bool type_contains_v =
-    std::is_base_of_v<Base, Derived> || std::is_same_v<Base, Derived>;
-
-// Special case for Any, which can store any compatible value directly.
-template <typename Derived>
-inline constexpr bool type_contains_v<Any, Derived> = true;
-}  // namespace details
+template <typename TargetType, typename SourceType, typename = void>
+inline constexpr bool type_subsumes_v =
+    std::is_base_of_v<TargetType, SourceType> || std::is_same_v<TargetType, SourceType>;
+
+/// \cond Doxygen_Suppress
+/*! \brief Whether Any subsumes SourceType for direct storage reuse. */
+template <typename SourceType>
+inline constexpr bool type_subsumes_v<Any, SourceType> = true;
+/// \endcond
 
 /*!
  * \brief TypeTraits that specifies the conversion behavior from/to FFI Any.
diff --git a/tests/cpp/test_array.cc b/tests/cpp/test_array.cc
index b7d1fa3d..31c13d43 100644
--- a/tests/cpp/test_array.cc
+++ b/tests/cpp/test_array.cc
@@ -289,8 +289,9 @@ TEST(Array, Upcast) {
   Array<Array<Any>> a3 = a2;
   Array<Array<Any>> a4 = a2;
 
-  static_assert(details::type_contains_v<Array<Any>, Array<int>>);
-  static_assert(details::type_contains_v<Any, Array<float>>);
+  static_assert(type_subsumes_v<Array<Any>, Array<int>>);
+  static_assert(!type_subsumes_v<Array<int>, Array<Any>>);
+  static_assert(type_subsumes_v<Any, Array<float>>);
 }
 
 TEST(Array, Contains) {
diff --git a/tests/cpp/test_map.cc b/tests/cpp/test_map.cc
index 2f323348..b91cb537 100644
--- a/tests/cpp/test_map.cc
+++ b/tests/cpp/test_map.cc
@@ -257,7 +257,7 @@ TEST(Map, Upcast) {
   Map<Any, Any> m1 = m0;
   EXPECT_EQ(m1[1].cast<int>(), 2);
   EXPECT_EQ(m1[3].cast<int>(), 4);
-  static_assert(details::type_contains_v<Map<Any, Any>, Map<String, int>>);
+  static_assert(type_subsumes_v<Map<Any, Any>, Map<String, int>>);
 
   Map<String, Array<int>> m2 = {{"x", {1}}, {"y", {2}}};
   Map<String, Array<Any>> m3 = m2;
diff --git a/tests/cpp/test_object.cc b/tests/cpp/test_object.cc
index 255087a7..25a8d8da 100644
--- a/tests/cpp/test_object.cc
+++ b/tests/cpp/test_object.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <gtest/gtest.h>
+#include <tvm/ffi/cast.h>
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/dict.h>
 #include <tvm/ffi/container/list.h>
@@ -54,6 +55,11 @@ struct TypeTraits<testing::TIntOrFloatRef>
   }
 };
 
+template <typename ObjectType>
+inline constexpr bool object_ref_contains_v<testing::TIntOrFloatRef, ObjectType> =
+    std::is_base_of_v<testing::TIntObj, ObjectType> ||
+    std::is_base_of_v<testing::TFloatObj, ObjectType>;
+
 }  // namespace ffi
 }  // namespace tvm
 
@@ -62,6 +68,13 @@ namespace {
 using namespace tvm::ffi;
 using namespace tvm::ffi::testing;
 
+template <typename RefType, typename ObjectType, typename = void>
+inline constexpr bool object_ref_contains_is_enabled_v = false;
+
+template <typename RefType, typename ObjectType>
+inline constexpr bool object_ref_contains_is_enabled_v<
+    RefType, ObjectType, std::void_t<decltype(object_ref_contains_v<RefType, ObjectType>)>> = true;
+
 static_assert(ObjectRef::_type_container_is_exact);
 static_assert(TNumber::_type_container_is_exact);
 static_assert(TInt::_type_container_is_exact);
@@ -74,6 +87,21 @@ static_assert(!Dict<TInt, TFloat>::_type_container_is_exact);
 static_assert(!Tuple<TInt, TFloat>::_type_container_is_exact);
 static_assert(!Variant<TInt, TFloat>::_type_container_is_exact);
 
+static_assert(object_ref_contains_v<TNumber, TIntObj>);
+static_assert(object_ref_contains_v<TInt, TIntObj>);
+static_assert(object_ref_contains_v<Optional<TInt>, TIntObj>);
+static_assert(!object_ref_contains_v<TInt, TFloatObj>);
+static_assert(!object_ref_contains_v<Array<TInt>, ArrayObj>);
+static_assert(object_ref_contains_v<TIntOrFloatRef, TIntObj>);
+static_assert(object_ref_contains_v<TIntOrFloatRef, TFloatObj>);
+static_assert(!object_ref_contains_v<TIntOrFloatRef, TNumberObj>);
+static_assert(object_ref_contains_is_enabled_v<TInt, TIntObj>);
+static_assert(object_ref_contains_is_enabled_v<TIntOrFloatRef, TIntObj>);
+static_assert(!object_ref_contains_is_enabled_v<int, TIntObj>);
+static_assert(!object_ref_contains_is_enabled_v<TIntObj, TIntObj>);
+static_assert(!object_ref_contains_is_enabled_v<TInt, int>);
+static_assert(!object_ref_contains_is_enabled_v<TIntOrFloatRef, int>);
+
 template <typename T>
 class CRTPObject : public Object {
  public:
@@ -202,6 +230,14 @@ TEST(ObjectRef, AsUsesTypeTraitsCheckAnyStrict) {
   EXPECT_NE((*float_like).as<TFloatObj>(), nullptr);  // NOLINT(bugprone-unchecked-optional-access)
 }
 
+TEST(ObjectRef, GetRefUsesObjectRefContainment) {
+  ObjectPtr<TIntObj> int_object = make_object<TIntObj>(10);
+  TIntOrFloatRef int_or_float = GetRef<TIntOrFloatRef>(int_object.get());
+
+  ASSERT_NE(int_or_float.as<TIntObj>(), nullptr);
+  EXPECT_EQ(int_or_float.as<TIntObj>()->value, 10);
+}
+
 TEST(ObjectRef, AsOrThrow) {
   ObjectRef a = TInt(10);
   ObjectRef b = TFloat(20);
diff --git a/tests/cpp/test_tuple.cc b/tests/cpp/test_tuple.cc
index 89f79c23..9aeb57c8 100644
--- a/tests/cpp/test_tuple.cc
+++ b/tests/cpp/test_tuple.cc
@@ -138,9 +138,11 @@ TEST(Tuple, Upcast) {
   Tuple<Any, Any> t1 = t0;
   EXPECT_EQ(t1.get<0>().cast<int>(), 1);
   EXPECT_EQ(t1.get<1>().cast<float>(), 2.0f);
-  static_assert(details::type_contains_v<Tuple<Any, Any>, Tuple<int, float>>);
-  static_assert(details::type_contains_v<Tuple<Any, float>, Tuple<int, float>>);
-  static_assert(details::type_contains_v<Tuple<TNumber, float>, Tuple<TInt, float>>);
+  static_assert(type_subsumes_v<Tuple<Any, Any>, Tuple<int, float>>);
+  static_assert(type_subsumes_v<Tuple<Any, float>, Tuple<int, float>>);
+  static_assert(type_subsumes_v<Tuple<TNumber, float>, Tuple<TInt, float>>);
+  static_assert(!type_subsumes_v<Tuple<Any>, Tuple<int, float>>);
+  static_assert(!type_subsumes_v<Tuple<Any, Any>, Tuple<int>>);
 }
 
 TEST(Tuple, ArrayIterForwarding) {
diff --git a/tests/cpp/test_variant.cc b/tests/cpp/test_variant.cc
index 939e4637..40a55769 100644
--- a/tests/cpp/test_variant.cc
+++ b/tests/cpp/test_variant.cc
@@ -130,7 +130,7 @@ TEST(Variant, FromTyped) {
 
 TEST(Variant, Upcast) {
   Array<int> a0 = {1, 2, 3};
-  static_assert(details::type_contains_v<Array<Variant<int, float>>, Array<int>>);
+  static_assert(type_subsumes_v<Array<Variant<int, float>>, Array<int>>);
   Array<Variant<int, float>> a1 = a0;
   EXPECT_EQ(a1[0].get<int>(), 1);
 }

Reply via email to