This is an automated email from the ASF dual-hosted git repository.
MasterJH5574 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 84ee1a07 [CORE] Add type subsumption and object-ref containment relations (#647)
84ee1a07 is described below
commit 84ee1a07f85645886d2eebc9ddd0ddf9488fd38a
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 28 18:51:06 2026 -0400
[CORE] Add type subsumption and object-ref containment relations (#647)
---
include/tvm/ffi/cast.h | 31 +++++++++++++++++--------------
include/tvm/ffi/container/array.h | 15 ++++++++-------
include/tvm/ffi/container/dict.h | 21 +++++++++------------
include/tvm/ffi/container/list.h | 15 ++++++++-------
include/tvm/ffi/container/map.h | 21 +++++++++------------
include/tvm/ffi/container/tuple.h | 17 ++++++++++-------
include/tvm/ffi/container/variant.h | 9 +++++----
include/tvm/ffi/object.h | 19 +++++++++++++++++++
include/tvm/ffi/type_traits.h | 29 ++++++++++++++++-------------
tests/cpp/test_array.cc | 5 +++--
tests/cpp/test_map.cc | 2 +-
tests/cpp/test_object.cc | 36 ++++++++++++++++++++++++++++++++++++
tests/cpp/test_tuple.cc | 8 +++++---
tests/cpp/test_variant.cc | 2 +-
14 files changed, 147 insertions(+), 83 deletions(-)
diff --git a/include/tvm/ffi/cast.h b/include/tvm/ffi/cast.h
index 655f1891..44228fe2 100644
--- a/include/tvm/ffi/cast.h
+++ b/include/tvm/ffi/cast.h
@@ -27,6 +27,8 @@
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>
+#include <type_traits>
+
namespace tvm {
namespace ffi {
@@ -44,23 +46,24 @@ namespace ffi {
*/
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr) {
- using ContainerType = typename RefType::ContainerType;
- static_assert(RefType::_type_container_is_exact,
- "GetRef requires RefType::ContainerType to exactly describe all objects the ref "
- "can hold; use ObjectRef::as<RefType>() for richer TypeTraits-based refs");
- static_assert(std::is_base_of_v<ContainerType, ObjectType>,
- "Can only cast to the ref of same container type");
-
- if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
- if (ptr == nullptr) {
- return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(nullptr);
+ if constexpr (object_ref_contains_v<RefType, ObjectType>) {
+ if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
+ if (ptr == nullptr) {
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(nullptr);
+ }
+ } else {
+ TVM_FFI_ICHECK_NOTNULL(ptr);
}
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(
+ details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
+ const_cast<Object*>(static_cast<const Object*>(ptr))));
} else {
- TVM_FFI_ICHECK_NOTNULL(ptr);
+ static_assert(object_ref_contains_v<RefType, ObjectType>,
+ "GetRef requires RefType to contain every ObjectType instance; specialize "
+ "object_ref_contains_v for statically safe typed refs or use "
+ "ObjectRef::as<RefType>() for runtime-dependent checks");
+ TVM_FFI_UNREACHABLE();
}
- return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(
- details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
- const_cast<Object*>(static_cast<const Object*>(ptr))));
}
/*!
diff --git a/include/tvm/ffi/container/array.h b/include/tvm/ffi/container/array.h
index 82380c0a..fd2c8233 100644
--- a/include/tvm/ffi/container/array.h
+++ b/include/tvm/ffi/container/array.h
@@ -241,7 +241,7 @@ class Array : public ObjectRef {
* \param other The other array
* \tparam U The value type of the other array
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
Array(Array<U>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other.data_)) {}
/*!
@@ -249,7 +249,7 @@ class Array : public ObjectRef {
* \param other The other array
* \tparam U The value type of the other array
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
Array(const Array<U>& other) // NOLINT(google-explicit-constructor)
: ObjectRef(other.data_) {}
@@ -274,7 +274,7 @@ class Array : public ObjectRef {
* \param other The other array
* \tparam U The value type of the other array
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
TVM_FFI_INLINE Array<T>& operator=(Array<U>&& other) {
data_ = std::move(other.data_);
return *this;
@@ -284,7 +284,7 @@ class Array : public ObjectRef {
* \param other The other array
* \tparam U The value type of the other array
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
TVM_FFI_INLINE Array<T>& operator=(const Array<U>& other) {
data_ = other.data_;
return *this;
@@ -895,10 +895,11 @@ struct TypeTraits<Array<T>> : public SeqTypeTraitsBase<TypeTraits<Array<T>>, Arr
}
};
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target Array storage subsumes source Array storage element-wise. */
template <typename T, typename U>
-inline constexpr bool type_contains_v<Array<T>, Array<U>> = type_contains_v<T, U>;
-} // namespace details
+inline constexpr bool type_subsumes_v<Array<T>, Array<U>> = type_subsumes_v<T, U>;
+/// \endcond
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/container/dict.h b/include/tvm/ffi/container/dict.h
index 1887154f..e75ea2cd 100644
--- a/include/tvm/ffi/container/dict.h
+++ b/include/tvm/ffi/container/dict.h
@@ -102,8 +102,7 @@ class Dict : public ObjectRef {
* \tparam VU The mapped type of the other dict
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Dict(Dict<KU, VU>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other.data_)) {}
@@ -114,8 +113,7 @@ class Dict : public ObjectRef {
* \tparam VU The mapped type of the other dict
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
// NOLINTNEXTLINE(google-explicit-constructor)
Dict(const Dict<KU, VU>& other) : ObjectRef(other.data_) {}
@@ -144,8 +142,7 @@ class Dict : public ObjectRef {
* \tparam VU The mapped type of the other dict
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Dict<K, V>& operator=(Dict<KU, VU>&& other) {
data_ = std::move(other.data_);
return *this;
@@ -158,8 +155,7 @@ class Dict : public ObjectRef {
* \tparam VU The mapped type of the other dict
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Dict<K, V>& operator=(const Dict<KU, VU>& other) {
data_ = other.data_;
return *this;
@@ -365,11 +361,12 @@ struct TypeTraits<Dict<K, V>> : public MapTypeTraitsBase<TypeTraits<Dict<K, V>>,
}
};
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target Dict storage subsumes source Dict storage key- and value-wise. */
template <typename K, typename V, typename KU, typename VU>
-inline constexpr bool type_contains_v<Dict<K, V>, Dict<KU, VU>> =
- type_contains_v<K, KU> && type_contains_v<V, VU>;
-} // namespace details
+inline constexpr bool type_subsumes_v<Dict<K, V>, Dict<KU, VU>> =
+ type_subsumes_v<K, KU> && type_subsumes_v<V, VU>;
+/// \endcond
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/container/list.h b/include/tvm/ffi/container/list.h
index f2292ce9..be2bda81 100644
--- a/include/tvm/ffi/container/list.h
+++ b/include/tvm/ffi/container/list.h
@@ -159,7 +159,7 @@ class List : public ObjectRef {
* \brief Constructor from another list
* \tparam U The value type of the other list
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
List(List<U>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other.data_)) {}
@@ -167,7 +167,7 @@ class List : public ObjectRef {
* \brief Constructor from another list
* \tparam U The value type of the other list
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
List(const List<U>& other) // NOLINT(google-explicit-constructor)
: ObjectRef(other.data_) {}
@@ -194,7 +194,7 @@ class List : public ObjectRef {
* \param other The other list.
* \tparam U The value type of the other list.
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
TVM_FFI_INLINE List<T>& operator=(List<U>&& other) {
data_ = std::move(other.data_);
return *this;
@@ -205,7 +205,7 @@ class List : public ObjectRef {
* \param other The other list.
* \tparam U The value type of the other list.
*/
- template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
+ template <typename U, typename = std::enable_if_t<type_subsumes_v<T, U>>>
TVM_FFI_INLINE List<T>& operator=(const List<U>& other) {
data_ = other.data_;
return *this;
@@ -518,10 +518,11 @@ struct TypeTraits<List<T>> : public SeqTypeTraitsBase<TypeTraits<List<T>>, List<
}
};
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target List storage subsumes source List storage element-wise. */
template <typename T, typename U>
-inline constexpr bool type_contains_v<List<T>, List<U>> = type_contains_v<T, U>;
-} // namespace details
+inline constexpr bool type_subsumes_v<List<T>, List<U>> = type_subsumes_v<T, U>;
+/// \endcond
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index a9003e20..10013be9 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -98,8 +98,7 @@ class Map : public ObjectRef {
* \tparam VU The mapped type of the other map
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Map(Map<KU, VU>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other.data_)) {}
@@ -110,8 +109,7 @@ class Map : public ObjectRef {
* \tparam VU The mapped type of the other map
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Map(const Map<KU, VU>& other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor)
/*!
@@ -139,8 +137,7 @@ class Map : public ObjectRef {
* \tparam VU The mapped type of the other map
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Map<K, V>& operator=(Map<KU, VU>&& other) {
data_ = std::move(other.data_);
return *this;
@@ -153,8 +150,7 @@ class Map : public ObjectRef {
* \tparam VU The mapped type of the other map
*/
template <typename KU, typename VU,
- typename = std::enable_if_t<details::type_contains_v<K, KU> &&
- details::type_contains_v<V, VU>>>
+ typename = std::enable_if_t<type_subsumes_v<K, KU> && type_subsumes_v<V, VU>>>
Map<K, V>& operator=(const Map<KU, VU>& other) {
data_ = other.data_;
return *this;
@@ -380,11 +376,12 @@ struct TypeTraits<Map<K, V>> : public MapTypeTraitsBase<TypeTraits<Map<K, V>>, M
}
};
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target Map storage subsumes source Map storage key- and value-wise. */
template <typename K, typename V, typename KU, typename VU>
-inline constexpr bool type_contains_v<Map<K, V>, Map<KU, VU>> =
- type_contains_v<K, KU> && type_contains_v<V, VU>;
-} // namespace details
+inline constexpr bool type_subsumes_v<Map<K, V>, Map<KU, VU>> =
+ type_subsumes_v<K, KU> && type_subsumes_v<V, VU>;
+/// \endcond
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/container/tuple.h b/include/tvm/ffi/container/tuple.h
index 090f99f0..2d73f836 100644
--- a/include/tvm/ffi/container/tuple.h
+++ b/include/tvm/ffi/container/tuple.h
@@ -64,7 +64,7 @@ class Tuple : public ObjectRef {
* \tparam The enable_if_t type
*/
template <typename... UTypes,
- typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
+ typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...), int>>
Tuple(const Tuple<UTypes...>& other) : ObjectRef(other) {} // NOLINT(google-explicit-constructor)
/*!
@@ -74,7 +74,7 @@ class Tuple : public ObjectRef {
* \tparam The enable_if_t type
*/
template <typename... UTypes,
- typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
+ typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...), int>>
Tuple(Tuple<UTypes...>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other)) {}
@@ -116,7 +116,7 @@ class Tuple : public ObjectRef {
* \tparam The enable_if_t type
*/
template <typename... UTypes,
- typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...)>>
+ typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...)>>
TVM_FFI_INLINE Tuple& operator=(const Tuple<UTypes...>& other) {
data_ = other.data_;
return *this;
@@ -129,7 +129,7 @@ class Tuple : public ObjectRef {
* \tparam The enable_if_t type
*/
template <typename... UTypes,
- typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...)>>
+ typename = std::enable_if_t<(type_subsumes_v<Types, UTypes> && ...)>>
TVM_FFI_INLINE Tuple& operator=(Tuple<UTypes...>&& other) {
data_ = std::move(other.data_);
return *this;
@@ -338,10 +338,13 @@ struct TypeTraits<Tuple<Types...>> : public ObjectRefTypeTraitsBase<Tuple<Types.
}
};
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether target and source Tuple storage have equal arity and subsume element-wise. */
template <typename... T, typename... U>
-inline constexpr bool type_contains_v<Tuple<T...>, Tuple<U...>> = (type_contains_v<T, U> && ...);
-} // namespace details
+inline constexpr bool
+ type_subsumes_v<Tuple<T...>, Tuple<U...>, std::enable_if_t<sizeof...(T) == sizeof...(U)>> =
+ (type_subsumes_v<T, U> && ...);
+/// \endcond
/// \cond Doxygen_Suppress
diff --git a/include/tvm/ffi/container/variant.h b/include/tvm/ffi/container/variant.h
index 9082e0b2..8ad8097f 100644
--- a/include/tvm/ffi/container/variant.h
+++ b/include/tvm/ffi/container/variant.h
@@ -114,7 +114,7 @@ class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
* \brief Helper utility to check if the type can be contained in the variant
*/
template <typename T>
- static constexpr bool variant_contains_v = (details::type_contains_v<V, T> || ...);
+ static constexpr bool variant_contains_v = (type_subsumes_v<V, T> || ...);
/* \brief Helper utility for SFINAE if the type is part of the variant */
template <typename T>
using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;
@@ -305,10 +305,11 @@ TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant<V...>& a,
return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual();
}
-namespace details {
+/// \cond Doxygen_Suppress
+/*! \brief Whether Variant storage subsumes a source type through one alternative. */
template <typename... V, typename T>
-inline constexpr bool type_contains_v<Variant<V...>, T> = (type_contains_v<V, T> || ...);
-} // namespace details
+inline constexpr bool type_subsumes_v<Variant<V...>, T> = (type_subsumes_v<V, T> || ...);
+/// \endcond
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_CONTAINER_VARIANT_H_
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 2e8a8f40..4868ab14 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -881,6 +881,25 @@ class ObjectRef {
friend struct tvm::ffi::details::ObjectUnsafe;
};
+/*!
+ * \brief Whether RefType contains every ObjectType instance.
+ *
+ * The containing reference type is first and the contained object type is
+ * second. The template is enabled only when RefType is an object reference and
+ * ObjectType is an object node. The default is true exactly when RefType has an
+ * exact container type and ObjectType derives from it. Direct specializations
+ * can provide the proof for non-exact reference types.
+ *
+ * \tparam RefType The object reference type.
+ * \tparam ObjectType The object node type.
+ */
+template <typename RefType, typename ObjectType,
+ typename = std::enable_if_t<std::is_base_of_v<ObjectRef, RefType> &&
+ std::is_base_of_v<Object, ObjectType>>>
+inline constexpr bool object_ref_contains_v =
+ RefType::_type_container_is_exact &&
+ std::is_base_of_v<typename RefType::ContainerType, ObjectType>;
+
// forward delcare variant
template <typename... V>
class Variant;
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index 2810e3fe..1b75392a 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -118,22 +118,25 @@ inline std::string TypeIndexToTypeKey(int32_t type_index) {
return std::string(type_info->type_key.data, type_info->type_key.size);
}
-namespace details {
/*!
- * \brief Check whether `Derived` can reuse `Base` storage directly.
+ * \brief Whether TargetType subsumes SourceType for direct storage reuse.
*
- * \tparam Base The base type.
- * \tparam Derived The derived type.
- * \return True if Derived's storage can be used as Base's storage, false otherwise.
+ * The target type is first and the source type is second. The result is true
+ * exactly when every SourceType value can reuse TargetType storage without
+ * conversion.
+ *
+ * \tparam TargetType The target storage type.
+ * \tparam SourceType The source value type.
*/
-template <typename Base, typename Derived>
-inline constexpr bool type_contains_v =
- std::is_base_of_v<Base, Derived> || std::is_same_v<Base, Derived>;
-
-// Special case for Any, which can store any compatible value directly.
-template <typename Derived>
-inline constexpr bool type_contains_v<Any, Derived> = true;
-} // namespace details
+template <typename TargetType, typename SourceType, typename = void>
+inline constexpr bool type_subsumes_v =
+ std::is_base_of_v<TargetType, SourceType> || std::is_same_v<TargetType, SourceType>;
+
+/// \cond Doxygen_Suppress
+/*! \brief Whether Any subsumes SourceType for direct storage reuse. */
+template <typename SourceType>
+inline constexpr bool type_subsumes_v<Any, SourceType> = true;
+/// \endcond
/*!
* \brief TypeTraits that specifies the conversion behavior from/to FFI Any.
diff --git a/tests/cpp/test_array.cc b/tests/cpp/test_array.cc
index b7d1fa3d..31c13d43 100644
--- a/tests/cpp/test_array.cc
+++ b/tests/cpp/test_array.cc
@@ -289,8 +289,9 @@ TEST(Array, Upcast) {
Array<Array<Any>> a3 = a2;
Array<Array<Any>> a4 = a2;
- static_assert(details::type_contains_v<Array<Any>, Array<int>>);
- static_assert(details::type_contains_v<Any, Array<float>>);
+ static_assert(type_subsumes_v<Array<Any>, Array<int>>);
+ static_assert(!type_subsumes_v<Array<int>, Array<Any>>);
+ static_assert(type_subsumes_v<Any, Array<float>>);
}
TEST(Array, Contains) {
diff --git a/tests/cpp/test_map.cc b/tests/cpp/test_map.cc
index 2f323348..b91cb537 100644
--- a/tests/cpp/test_map.cc
+++ b/tests/cpp/test_map.cc
@@ -257,7 +257,7 @@ TEST(Map, Upcast) {
Map<Any, Any> m1 = m0;
EXPECT_EQ(m1[1].cast<int>(), 2);
EXPECT_EQ(m1[3].cast<int>(), 4);
- static_assert(details::type_contains_v<Map<Any, Any>, Map<String, int>>);
+ static_assert(type_subsumes_v<Map<Any, Any>, Map<String, int>>);
Map<String, Array<int>> m2 = {{"x", {1}}, {"y", {2}}};
Map<String, Array<Any>> m3 = m2;
diff --git a/tests/cpp/test_object.cc b/tests/cpp/test_object.cc
index 255087a7..25a8d8da 100644
--- a/tests/cpp/test_object.cc
+++ b/tests/cpp/test_object.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <gtest/gtest.h>
+#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
@@ -54,6 +55,11 @@ struct TypeTraits<testing::TIntOrFloatRef>
}
};
+template <typename ObjectType>
+inline constexpr bool object_ref_contains_v<testing::TIntOrFloatRef, ObjectType> =
+ std::is_base_of_v<testing::TIntObj, ObjectType> ||
+ std::is_base_of_v<testing::TFloatObj, ObjectType>;
+
} // namespace ffi
} // namespace tvm
@@ -62,6 +68,13 @@ namespace {
using namespace tvm::ffi;
using namespace tvm::ffi::testing;
+template <typename RefType, typename ObjectType, typename = void>
+inline constexpr bool object_ref_contains_is_enabled_v = false;
+
+template <typename RefType, typename ObjectType>
+inline constexpr bool object_ref_contains_is_enabled_v<
+ RefType, ObjectType, std::void_t<decltype(object_ref_contains_v<RefType, ObjectType>)>> = true;
+
static_assert(ObjectRef::_type_container_is_exact);
static_assert(TNumber::_type_container_is_exact);
static_assert(TInt::_type_container_is_exact);
@@ -74,6 +87,21 @@ static_assert(!Dict<TInt, TFloat>::_type_container_is_exact);
static_assert(!Tuple<TInt, TFloat>::_type_container_is_exact);
static_assert(!Variant<TInt, TFloat>::_type_container_is_exact);
+static_assert(object_ref_contains_v<TNumber, TIntObj>);
+static_assert(object_ref_contains_v<TInt, TIntObj>);
+static_assert(object_ref_contains_v<Optional<TInt>, TIntObj>);
+static_assert(!object_ref_contains_v<TInt, TFloatObj>);
+static_assert(!object_ref_contains_v<Array<TInt>, ArrayObj>);
+static_assert(object_ref_contains_v<TIntOrFloatRef, TIntObj>);
+static_assert(object_ref_contains_v<TIntOrFloatRef, TFloatObj>);
+static_assert(!object_ref_contains_v<TIntOrFloatRef, TNumberObj>);
+static_assert(object_ref_contains_is_enabled_v<TInt, TIntObj>);
+static_assert(object_ref_contains_is_enabled_v<TIntOrFloatRef, TIntObj>);
+static_assert(!object_ref_contains_is_enabled_v<int, TIntObj>);
+static_assert(!object_ref_contains_is_enabled_v<TIntObj, TIntObj>);
+static_assert(!object_ref_contains_is_enabled_v<TInt, int>);
+static_assert(!object_ref_contains_is_enabled_v<TIntOrFloatRef, int>);
+
template <typename T>
class CRTPObject : public Object {
public:
@@ -202,6 +230,14 @@ TEST(ObjectRef, AsUsesTypeTraitsCheckAnyStrict) {
EXPECT_NE((*float_like).as<TFloatObj>(), nullptr); // NOLINT(bugprone-unchecked-optional-access)
}
+TEST(ObjectRef, GetRefUsesObjectRefContainment) {
+ ObjectPtr<TIntObj> int_object = make_object<TIntObj>(10);
+ TIntOrFloatRef int_or_float = GetRef<TIntOrFloatRef>(int_object.get());
+
+ ASSERT_NE(int_or_float.as<TIntObj>(), nullptr);
+ EXPECT_EQ(int_or_float.as<TIntObj>()->value, 10);
+}
+
TEST(ObjectRef, AsOrThrow) {
ObjectRef a = TInt(10);
ObjectRef b = TFloat(20);
diff --git a/tests/cpp/test_tuple.cc b/tests/cpp/test_tuple.cc
index 89f79c23..9aeb57c8 100644
--- a/tests/cpp/test_tuple.cc
+++ b/tests/cpp/test_tuple.cc
@@ -138,9 +138,11 @@ TEST(Tuple, Upcast) {
Tuple<Any, Any> t1 = t0;
EXPECT_EQ(t1.get<0>().cast<int>(), 1);
EXPECT_EQ(t1.get<1>().cast<float>(), 2.0f);
- static_assert(details::type_contains_v<Tuple<Any, Any>, Tuple<int, float>>);
- static_assert(details::type_contains_v<Tuple<Any, float>, Tuple<int, float>>);
- static_assert(details::type_contains_v<Tuple<TNumber, float>, Tuple<TInt, float>>);
+ static_assert(type_subsumes_v<Tuple<Any, Any>, Tuple<int, float>>);
+ static_assert(type_subsumes_v<Tuple<Any, float>, Tuple<int, float>>);
+ static_assert(type_subsumes_v<Tuple<TNumber, float>, Tuple<TInt, float>>);
+ static_assert(!type_subsumes_v<Tuple<Any>, Tuple<int, float>>);
+ static_assert(!type_subsumes_v<Tuple<Any, Any>, Tuple<int>>);
}
TEST(Tuple, ArrayIterForwarding) {
diff --git a/tests/cpp/test_variant.cc b/tests/cpp/test_variant.cc
index 939e4637..40a55769 100644
--- a/tests/cpp/test_variant.cc
+++ b/tests/cpp/test_variant.cc
@@ -130,7 +130,7 @@ TEST(Variant, FromTyped) {
TEST(Variant, Upcast) {
Array<int> a0 = {1, 2, 3};
- static_assert(details::type_contains_v<Array<Variant<int, float>>, Array<int>>);
+ static_assert(type_subsumes_v<Array<Variant<int, float>>, Array<int>>);
Array<Variant<int, float>> a1 = a0;
EXPECT_EQ(a1[0].get<int>(), 1);
}