This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new fbb2b5c5 [FFI] Add ref-qualified strict ObjectRef casts (#639)
fbb2b5c5 is described below
commit fbb2b5c5f6e4f5f962db2ae74c9c1fc4c412843e
Author: Tianqi Chen
AuthorDate: Sun Jun 21 06:49:40 2026 -0400
[FFI] Add ref-qualified strict ObjectRef casts (#639)
This PR refactors strict ObjectRef casting to use the Any TypeTraits
strict-check path and adds strict throwing APIs.
Main benefits:
- `Any` and `ObjectRef` now have consistent strict `as` / `as_or_throw`
APIs: both rely on the same `TypeTraits` strict-check semantics, both
return optional values for probe-style `as<T>()`, and both provide
throwing `as_or_throw<T>()` variants for required casts.
- ObjectRef casts can now support richer compatibility checks through
`TypeTraits<ObjectRefType>::CheckAnyStrict`, instead of being limited to
the target ref's canonical `ContainerType` instance check.
- Object ref types now carry an explicit `_type_container_is_exact`
flag that states whether `ObjectRefType::ContainerType` alone determines
which objects the ref accepts. Ordinary refs inherit `true` from
`ObjectRef`, while richer parameterized refs such as `Array<T>`,
`List<T>`, `Map<K,V>`, `Dict<K,V>`, `Tuple<...>`, and object-backed
`Variant<...>` opt out because their accepted values are determined by
`TypeTraits`, not only by the backing container. `Optional<T>` forwards
the flag of `T`. `GetRef` is guarded by a `static_assert` so that it only
compiles for exact-container refs.
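A minimal sketch of the exact-container invariant and the `GetRef`-style guard, assuming simplified stand-in types. `ObjectModel`, `IntObjModel`, `ArrayObjModel`, `ObjectRefModel`, `IntRefModel`, `ArrayRefModel`, and `GetRefLike` are hypothetical illustrations, not tvm-ffi classes. The sketch shows why an unchecked pointer-to-ref wrapper is only sound when the container type alone decides membership.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for object containers; not the real tvm-ffi classes.
struct ObjectModel {
  virtual ~ObjectModel() = default;
};
struct IntObjModel : ObjectModel {
  int value = 0;
};
struct ArrayObjModel : ObjectModel {};

// Base ref: its container type (any object) exactly describes what it can hold.
struct ObjectRefModel {
  using ContainerType = ObjectModel;
  static constexpr bool _type_container_is_exact = true;
  const ObjectModel* ptr = nullptr;
};

// Ordinary ref: inherits `true`; every IntObjModel is a valid IntRefModel.
struct IntRefModel : ObjectRefModel {
  using ContainerType = IntObjModel;
};

// Parameterized ref: an ArrayObjModel is only valid if its elements are T,
// which the container type cannot express, so the ref opts out.
template <typename T>
struct ArrayRefModel : ObjectRefModel {
  using ContainerType = ArrayObjModel;
  static constexpr bool _type_container_is_exact = false;
};

// GetRef-like helper: wraps a raw pointer without any runtime check, so it is
// only sound when ContainerType is an exact acceptance predicate.
template <typename RefType, typename ObjectType>
RefType GetRefLike(const ObjectType* ptr) {
  static_assert(RefType::_type_container_is_exact,
                "use a checked cast for refs with richer acceptance rules");
  static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
                "can only wrap objects of the ref's container type");
  RefType ref;
  ref.ptr = ptr;
  return ref;
}
```

Calling `GetRefLike<ArrayRefModel<int>>(&some_array_obj)` would be rejected by the first `static_assert`; in the real API, such refs are expected to go through the checked `ObjectRef::as<RefType>()` path instead, as the new `GetRef` assertion message suggests.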
API behavior:
- `ObjectRef::as<ObjectRefType>() const&` returns
`std::optional<ObjectRefType>`. For a non-null source, it succeeds only
when `TypeTraits<ObjectRefType>::CheckAnyStrict` accepts the object
through a compact temporary `TVMFFIAny` view. On success, it returns a
ref that shares the original object pointer. On type mismatch, it
returns `std::nullopt`.
- `ObjectRef::as<ObjectRefType>() &&` has the same strict-check and
`std::nullopt` behavior, but moves the object pointer into the returned
ref on success and clears the source ref.
- For null `ObjectRef` sources, `ObjectRef::as<ObjectRefType>()` returns
a successful null ref when the target ref type is nullable, and returns
`std::nullopt` for non-nullable target refs. The null path is explicit
and is not treated as a globally cold failure path.
- `ObjectRef::as_or_throw<ObjectRefType>() const&` and `&&` are the
throwing forms of the same strict ObjectRef cast. They return
`ObjectRefType` on success, preserve/move the object pointer according
to the receiver qualifier, return a null ref for nullable null targets,
and throw `TypeError` for non-nullable null or type-mismatch cases.
- `Any::as_or_throw<T>() const&` and `&&` are strict reinterpretation
helpers for `Any`. They call the corresponding strict `as<T>()` path,
return `T` on success, and throw `TypeError` on mismatch. They do not
run fallback conversions; use `cast<T>()` when conversion is intended.
- Optional-returning `as` APIs intentionally do not use success/failure
prediction annotations because `std::nullopt` can be a normal probe
result. Throwing APIs mark only the final failure/throw path as cold,
and `ObjectRef::as_or_throw` marks the strict-check success path as
likely because mismatch throws.
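The overload set described above can be modeled in plain C++17 to make the receiver-qualifier and null-handling rules concrete. This is a hypothetical sketch: `Node`, `Ref`, `IntRef`, `NonNullIntRef`, `Check`, and `kNullable` are stand-ins for the object header, `ObjectRef`, target refs, `TypeTraits<T>::CheckAnyStrict`, and `_type_is_nullable`, and `std::runtime_error` replaces tvm-ffi's `TypeError`.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>

// Hypothetical object node; `kind` plays the role of a runtime type index.
struct Node {
  explicit Node(int k) : kind(k) {}
  int kind;
};

// Hypothetical target refs. `Check` stands in for TypeTraits<T>::CheckAnyStrict
// and `kNullable` for `_type_is_nullable`.
struct IntRef {
  static constexpr bool kNullable = true;
  static bool Check(const Node& n) { return n.kind == 1; }
  std::shared_ptr<Node> data;
};
struct NonNullIntRef {
  static constexpr bool kNullable = false;
  static bool Check(const Node& n) { return n.kind == 1; }
  std::shared_ptr<Node> data;
};

// Hypothetical source ref with a ref-qualified strict-cast overload set.
struct Ref {
  std::shared_ptr<Node> data;

  // Probe from an lvalue: on success the result shares the pointer.
  template <typename T>
  std::optional<T> as() const& {
    if (data != nullptr) {
      if (T::Check(*data)) return T{data};
      return std::nullopt;  // type mismatch
    }
    if constexpr (T::kNullable) return T{};  // null source, nullable target
    return std::nullopt;                     // null source, non-nullable target
  }

  // Probe from an rvalue: on success the pointer is moved and the source cleared.
  template <typename T>
  std::optional<T> as() && {
    if (data != nullptr) {
      if (T::Check(*data)) return T{std::move(data)};
      return std::nullopt;  // mismatch leaves the source untouched
    }
    if constexpr (T::kNullable) return T{};
    return std::nullopt;
  }

  // Throwing forms: same strict check, but failure raises instead of returning nullopt.
  template <typename T>
  T as_or_throw() const& {
    std::optional<T> result = as<T>();
    if (!result.has_value()) throw std::runtime_error("TypeError: strict cast failed");
    return *std::move(result);
  }

  template <typename T>
  T as_or_throw() && {
    std::optional<T> result = std::move(*this).as<T>();
    if (!result.has_value()) throw std::runtime_error("TypeError: strict cast failed");
    return *std::move(result);
  }
};
```

One detail the sketch mirrors from the diff: a failed `std::move(ref).as<T>()` does not consume the source, so a caller can still try another probe afterwards.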
Implementation notes:
- `ObjectRef::as<T>()` no longer treats an instance check against
`T::ContainerType` as the complete runtime compatibility test. Instead it
piggybacks on the Any `TypeTraits` strict checks, so ref types with richer
strict compatibility rules behave consistently with `Any`.
- The temporary `TVMFFIAny` setup is deliberately kept inline in the
non-null ObjectRef paths. This preserves explicit null behavior and lets
`as_or_throw` call `TypeTraits<T>::GetMismatchTypeInfo` for richer
diagnostics instead of reducing to `as<T>()` and losing the synthesized
Any view.
- The compact temporary Any view is expected to optimize away on hot
paths; GCC/Clang assembly probes confirmed direct object-header
loads/type-index comparisons for the checked ObjectRef paths.
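The temporary-view technique can be sketched outside tvm-ffi as follows, assuming a simplified view struct. `ObjectHeader`, `AnyViewModel`, the `*Traits` structs, `StrictCheckObject`, and the type-index constants are hypothetical; the real `TVMFFIAny` also has padding fields that the diff zeroes explicitly. The point is that a strict check written once against the view can be reused from the object side, including composed rules like the `TIntOrFloatRef` traits in the tests.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical type-index layout; not tvm-ffi's real values.
constexpr int32_t kStaticObjectBegin = 64;
constexpr int32_t kIntIndex = kStaticObjectBegin + 1;
constexpr int32_t kFloatIndex = kStaticObjectBegin + 2;

// Hypothetical object header carrying a runtime type index.
struct ObjectHeader {
  int32_t type_index;
};

// Hypothetical compact view, playing the role of the temporary TVMFFIAny.
struct AnyViewModel {
  int32_t type_index;
  const ObjectHeader* obj;
};

// Strict checks written once against the view type.
struct IntTraits {
  static bool CheckAnyStrict(const AnyViewModel* v) { return v->type_index == kIntIndex; }
};
struct FloatTraits {
  static bool CheckAnyStrict(const AnyViewModel* v) { return v->type_index == kFloatIndex; }
};
// A richer rule composed from simpler ones.
struct IntOrFloatTraits {
  static bool CheckAnyStrict(const AnyViewModel* v) {
    return IntTraits::CheckAnyStrict(v) || FloatTraits::CheckAnyStrict(v);
  }
};

// Object-side strict check: synthesize a view on the stack, then defer to Traits.
// Null is rejected here only to keep the sketch short; the real ObjectRef::as
// handles null sources on a separate, explicit path.
template <typename Traits>
bool StrictCheckObject(const ObjectHeader* obj) {
  if (obj == nullptr) return false;
  AnyViewModel view{obj->type_index, obj};
  return Traits::CheckAnyStrict(&view);
}
```

Because the view is a small aggregate whose address does not escape once the check is inlined, optimizers can usually break it into scalars; the commit reports confirming this for GCC/Clang on the real code, while the sketch by itself guarantees nothing about codegen.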
Changes:
- Add ref-qualified `ObjectRef::as<T>()` and
`ObjectRef::as_or_throw<T>()` overloads for const and rvalue receivers.
- Add `Any::as_or_throw<T>()` const/rvalue helpers.
- Keep ObjectRef null handling explicit while using temporary
`TVMFFIAny` views for rich `TypeTraits` checks and mismatch messages.
- Add `_type_container_is_exact` and guard `GetRef` so it is only used
for refs whose `ContainerType` is an exact acceptance predicate.
- Add focused tests for const and move variants of `as<T>()` and
`as_or_throw<T>()`, plus compile-time coverage for exact/non-exact
container-ref flags.
- Update `TVM_FFI_UNSAFE_ASSUME` lowering to be more robust for GCC.
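The GCC-side lowering of `TVM_FFI_UNSAFE_ASSUME` (an unreachable edge instead of `__attribute__((assume(...)))`) can be illustrated with a standalone macro. `SKETCH_UNSAFE_ASSUME`, the value of `kStaticObjectBegin`, and `ObjectSlot` are hypothetical names for this example. As with any assumption hint, calling the function with a value that violates the condition is undefined behavior.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical assume macro lowered to an unreachable edge, mirroring the GCC
// branch of the diff. Violating `cond` at runtime is undefined behavior.
#if defined(__GNUC__) || defined(__clang__)
#define SKETCH_UNSAFE_ASSUME(cond)        \
  do {                                    \
    if (!(cond)) __builtin_unreachable(); \
  } while (0)
#else
#define SKETCH_UNSAFE_ASSUME(cond) static_cast<void>(0)
#endif

// Hypothetical stand-in for the first object type index.
constexpr int32_t kStaticObjectBegin = 64;

// The defensive branch below contradicts the assumption, so the optimizer is
// allowed to delete it and emit only the subtraction.
inline int32_t ObjectSlot(int32_t type_index) {
  SKETCH_UNSAFE_ASSUME(type_index >= kStaticObjectBegin);
  if (type_index < kStaticObjectBegin) return -1;
  return type_index - kStaticObjectBegin;
}
```

The `do { ... } while (0)` wrapper makes the macro expand to a single statement, so it composes safely with unbraced `if`/`else` at the call site.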
---
include/tvm/ffi/any.h | 115 +++++++++++++++++++++++++++++++++-
include/tvm/ffi/base_details.h | 14 ++---
include/tvm/ffi/cast.h | 3 +
include/tvm/ffi/container/array.h | 4 +-
include/tvm/ffi/container/dict.h | 4 +-
include/tvm/ffi/container/list.h | 4 +-
include/tvm/ffi/container/map.h | 4 +-
include/tvm/ffi/container/tuple.h | 4 +-
include/tvm/ffi/container/variant.h | 3 +
include/tvm/ffi/object.h | 93 ++++++++++++++++++++++++----
include/tvm/ffi/optional.h | 1 +
tests/cpp/test_any.cc | 50 +++++++++++++++
tests/cpp/test_object.cc | 119 ++++++++++++++++++++++++++++++++++++
13 files changed, 392 insertions(+), 26 deletions(-)
diff --git a/include/tvm/ffi/any.h b/include/tvm/ffi/any.h
index 5d754f7d..7c680ec1 100644
--- a/include/tvm/ffi/any.h
+++ b/include/tvm/ffi/any.h
@@ -141,7 +141,7 @@ class AnyView {
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const {
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
- if (!opt.has_value()) {
+ if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_)
<< "` to `"
<< TypeTraits<T>::TypeStr() << "`";
@@ -361,6 +361,29 @@ class Any {
}
}
+ /**
+ * \brief Strictly reinterpret the Any as a type T or throw.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value.
+ * \note This function will not run fallback conversions.
+ */
+ template <typename T,
+ typename = std::enable_if_t<TypeTraits<T>::storage_enabled || std::is_same_v<T, Any>>>
+ TVM_FFI_INLINE T as_or_throw() && {
+ if constexpr (std::is_same_v<T, Any>) {
+ return std::move(*this);
+ } else {
+ std::optional<T> result = std::move(*this).template as<T>();
+ if (TVM_FFI_PREDICT_FALSE(!result.has_value())) {
+ TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+ << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` as type `"
+ << TypeTraits<T>::TypeStr() << "`";
+ }
+ return *std::move(result);
+ }
+ }
+
/**
* \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible.
*
@@ -382,6 +405,29 @@ class Any {
}
}
+ /**
+ * \brief Strictly reinterpret the Any as a type T or throw.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value.
+ * \note This function will not run fallback conversions.
+ */
+ template <typename T,
+ typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
+ TVM_FFI_INLINE T as_or_throw() const& {
+ if constexpr (std::is_same_v<T, Any>) {
+ return *this;
+ } else {
+ std::optional<T> result = this->as<T>();
+ if (TVM_FFI_PREDICT_FALSE(!result.has_value())) {
+ TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+ << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` as type `"
+ << TypeTraits<T>::TypeStr() << "`";
+ }
+ return *std::move(result);
+ }
+ }
+
/*!
* \brief Shortcut of as Object to cast to a const pointer when T is an Object.
*
@@ -401,7 +447,7 @@ class Any {
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const& {
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
- if (!opt.has_value()) {
+ if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_)
<< "` to `"
<< TypeTraits<T>::TypeStr() << "`";
@@ -421,7 +467,7 @@ class Any {
}
// slow path, try to do fallback convert
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
- if (!opt.has_value()) {
+ if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_)
<< "` to `"
<< TypeTraits<T>::TypeStr() << "`";
@@ -824,6 +870,69 @@ struct AnyEqual {
}
};
+// Defer this definition until any.h so the throwing path can depend on
+// TVM_FFI_THROW(TypeError), while object.h stays below the error layer.
+//! \cond Doxygen_Suppress
+template <typename ObjectRefType, typename>
+TVM_FFI_INLINE ObjectRefType ObjectRef::as_or_throw() const& {
+ if (data_ != nullptr) {
+ // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+ TVMFFIAny any_data;
+ any_data.type_index = data_->type_index();
+ TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+ any_data.zero_padding = 0;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+ any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+ if (TVM_FFI_PREDICT_TRUE(TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data))) {
+ ObjectRefType result(UnsafeInit{});
+ result.data_ = data_;
+ return result;
+ } else {
+ TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+ << TypeTraits<ObjectRefType>::GetMismatchTypeInfo(&any_data)
+ << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+ }
+ } else {
+ if constexpr (ObjectRefType::_type_is_nullable) {
+ return ObjectRefType(UnsafeInit{});
+ } else {
+ TVM_FFI_THROW(TypeError) << "Cannot treat type `" << StaticTypeKey::kTVMFFINone
+ << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+ }
+ }
+}
+
+template <typename ObjectRefType, typename>
+TVM_FFI_INLINE ObjectRefType ObjectRef::as_or_throw() && {
+ if (data_ != nullptr) {
+ // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+ TVMFFIAny any_data;
+ any_data.type_index = data_->type_index();
+ TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+ any_data.zero_padding = 0;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+ any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+ if (TVM_FFI_PREDICT_TRUE(TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data))) {
+ ObjectRefType result(UnsafeInit{});
+ result.data_ = std::move(data_);
+ data_ = nullptr;
+ return result;
+ } else {
+ TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+ << TypeTraits<ObjectRefType>::GetMismatchTypeInfo(&any_data)
+ << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+ }
+ } else {
+ if constexpr (ObjectRefType::_type_is_nullable) {
+ return ObjectRefType(UnsafeInit{});
+ } else {
+ TVM_FFI_THROW(TypeError) << "Cannot treat type `" << StaticTypeKey::kTVMFFINone
+ << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+ }
+ }
+}
+//! \endcond
+
// Placed near the end because this specialization depends on error handling.
template <>
struct TypeTraits<uint64_t> : public TypeTraitsIntBase<uint64_t> {
diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index 27431b2b..4c2226d8 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -125,13 +125,13 @@
#if defined(__clang__)
#define TVM_FFI_UNSAFE_ASSUME(cond) __builtin_assume(cond)
#elif defined(__GNUC__)
-// GCC 13+ supports __attribute__((assume(...))); fall back to the void-cast
-// no-op for older GCC where __builtin_assume is absent.
-#if __GNUC__ >= 13
-#define TVM_FFI_UNSAFE_ASSUME(cond) __attribute__((assume(cond)))
-#else
-#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
-#endif
+// GCC does not reliably propagate __attribute__((assume(...))) through the
+// returned-aggregate/helper flows used in TVM_FFI hot paths. Lower to an
+// unreachable edge instead so GCC 11/14 recover the intended codegen.
+#define TVM_FFI_UNSAFE_ASSUME(cond) \
+ do { \
+ if (!(cond)) __builtin_unreachable(); \
+ } while (0)
#elif defined(_MSC_VER)
#define TVM_FFI_UNSAFE_ASSUME(cond) __assume(cond)
#else
diff --git a/include/tvm/ffi/cast.h b/include/tvm/ffi/cast.h
index 66abd664..655f1891 100644
--- a/include/tvm/ffi/cast.h
+++ b/include/tvm/ffi/cast.h
@@ -45,6 +45,9 @@ namespace ffi {
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr) {
using ContainerType = typename RefType::ContainerType;
+ static_assert(RefType::_type_container_is_exact,
+ "GetRef requires RefType::ContainerType to exactly describe all objects the ref "
+ "can hold; use ObjectRef::as<RefType>() for richer TypeTraits-based refs");
static_assert(std::is_base_of_v<ContainerType, ObjectType>,
"Can only cast to the ref of same container type");
diff --git a/include/tvm/ffi/container/array.h b/include/tvm/ffi/container/array.h
index 523aee2d..82380c0a 100644
--- a/include/tvm/ffi/container/array.h
+++ b/include/tvm/ffi/container/array.h
@@ -668,8 +668,10 @@ class Array : public ObjectRef {
return static_cast<ArrayObj*>(data_.get());
}
- /*! \brief specify container node */
+ /// \cond Doxygen_Suppress
using ContainerType = ArrayObj;
+ static constexpr bool _type_container_is_exact = false;
+ /// \endcond
/*!
* \brief Agregate arguments into a single Array<T>
diff --git a/include/tvm/ffi/container/dict.h b/include/tvm/ffi/container/dict.h
index d28f39de..1887154f 100644
--- a/include/tvm/ffi/container/dict.h
+++ b/include/tvm/ffi/container/dict.h
@@ -266,8 +266,10 @@ class Dict : public ObjectRef {
}
}
- /*! \brief specify container node */
+ /// \cond Doxygen_Suppress
using ContainerType = DictObj;
+ static constexpr bool _type_container_is_exact = false;
+ /// \endcond
/// \cond Doxygen_Suppress
/*! \brief Iterator of the hash map */
diff --git a/include/tvm/ffi/container/list.h b/include/tvm/ffi/container/list.h
index 6e52be69..f2292ce9 100644
--- a/include/tvm/ffi/container/list.h
+++ b/include/tvm/ffi/container/list.h
@@ -460,8 +460,10 @@ class List : public ObjectRef {
}
}
- /*! \brief specify container node */
+ /// \cond Doxygen_Suppress
using ContainerType = ListObj;
+ static constexpr bool _type_container_is_exact = false;
+ /// \endcond
private:
/*!
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index 20153a06..a9003e20 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -272,8 +272,10 @@ class Map : public ObjectRef {
}
return GetMapObj();
}
- /*! \brief specify container node */
+ /// \cond Doxygen_Suppress
using ContainerType = MapObj;
+ static constexpr bool _type_container_is_exact = false;
+ /// \endcond
/// \cond Doxygen_Suppress
/*! \brief Iterator of the hash map */
diff --git a/include/tvm/ffi/container/tuple.h b/include/tvm/ffi/container/tuple.h
index 79e402eb..090f99f0 100644
--- a/include/tvm/ffi/container/tuple.h
+++ b/include/tvm/ffi/container/tuple.h
@@ -190,8 +190,10 @@ class Tuple : public ObjectRef {
*ptr = T(std::forward<U>(item));
}
- /*! \brief specify container node */
+ /// \cond Doxygen_Suppress
using ContainerType = ArrayObj;
+ static constexpr bool _type_container_is_exact = false;
+ /// \endcond
private:
static ObjectPtr<ArrayObj> MakeDefaultTupleNode() {
diff --git a/include/tvm/ffi/container/variant.h b/include/tvm/ffi/container/variant.h
index 08dc764d..9082e0b2 100644
--- a/include/tvm/ffi/container/variant.h
+++ b/include/tvm/ffi/container/variant.h
@@ -107,6 +107,9 @@ class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
static_assert(details::all_storage_enabled_v<V...>,
"All types used in Variant<...> must be compatible with Any");
+ /// \cond Doxygen_Suppress
+ static constexpr bool _type_container_is_exact = false;
+ /// \endcond
/*
* \brief Helper utility to check if the type can be contained in the variant
*/
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 2048cc5b..2e8a8f40 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -79,6 +79,7 @@ constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast<uint64_t>(1) << 32
*/
template <typename TargetType>
TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index);
+
} // namespace details
/*!
@@ -747,7 +748,7 @@ class ObjectRef {
* \return The pointer to the requested type.
*/
template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
- const ObjectType* as() const {
+ const ObjectType* as() const& {
if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
} else {
@@ -758,27 +759,95 @@ class ObjectRef {
/*!
* \brief Try to downcast the ObjectRef to Optional<T> of the requested type.
*
- * The function will return a std::nullopt if the cast or if the pointer is nullptr.
+ * If the cast fails, returns std::nullopt. If this ObjectRef is null, returns
+ * a null ref for nullable target refs and std::nullopt for non-nullable targets.
*
- * \tparam ObjectRefType the target type, must be a subtype of ObjectRef'
+ * \tparam ObjectRefType the target type, must be a subtype of ObjectRef.
* \return The optional value of the requested type.
*/
template <typename ObjectRefType,
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
- TVM_FFI_INLINE std::optional<ObjectRefType> as() const {
+ TVM_FFI_INLINE std::optional<ObjectRefType> as() const& {
if (data_ != nullptr) {
- if (data_->IsInstance<typename ObjectRefType::ContainerType>()) {
- ObjectRefType ref(UnsafeInit{});
- ref.data_ = data_;
- return ref;
- } else {
- return std::nullopt;
+ // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+ TVMFFIAny any_data;
+ any_data.type_index = data_->type_index();
+ TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+ any_data.zero_padding = 0;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+ any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+ if (TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data)) {
+ ObjectRefType result(UnsafeInit{});
+ result.data_ = data_;
+ return result;
+ }
+ return std::nullopt;
+ }
+ if constexpr (ObjectRefType::_type_is_nullable) {
+ return ObjectRefType(UnsafeInit{});
+ }
+ return std::nullopt;
+ }
+
+ /*!
+ * \brief Try to move-downcast the ObjectRef to Optional<T> of the requested type.
+ *
+ * If the cast succeeds, moves the internal object pointer to the returned
+ * ObjectRefType. If the cast fails, returns std::nullopt. If this ObjectRef
+ * is null, returns a null ref for nullable target refs and std::nullopt for
+ * non-nullable targets.
+ *
+ * \tparam ObjectRefType the target type, must be a subtype of ObjectRef.
+ * \return The optional value of the requested type.
+ */
+ template <typename ObjectRefType,
+ typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+ TVM_FFI_INLINE std::optional<ObjectRefType> as() && {
+ if (data_ != nullptr) {
+ // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+ TVMFFIAny any_data;
+ any_data.type_index = data_->type_index();
+ TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+ any_data.zero_padding = 0;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+ any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+ if (TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data)) {
+ ObjectRefType result(UnsafeInit{});
+ result.data_ = std::move(data_);
+ data_ = nullptr;
+ return result;
}
- } else {
return std::nullopt;
}
+ if constexpr (ObjectRefType::_type_is_nullable) {
+ return ObjectRefType(UnsafeInit{});
+ }
+ return std::nullopt;
}
+ /*!
+ * \brief Downcast the ObjectRef to the requested type or throw.
+ *
+ * \tparam ObjectRefType the target type, must be a subtype of ObjectRef
+ * \return The requested value.
+ */
+ template <typename ObjectRefType,
+ typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+ TVM_FFI_INLINE ObjectRefType as_or_throw() const&;
+
+ /*!
+ * \brief Move-downcast the ObjectRef to the requested type or throw.
+ *
+ * If the cast succeeds, moves the internal object pointer to the returned
+ * ObjectRefType.
+ *
+ * \tparam ObjectRefType the target type, must be a subtype of ObjectRef
+ * \return The requested value.
+ */
+ template <typename ObjectRefType,
+ typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+ TVM_FFI_INLINE ObjectRefType as_or_throw() &&;
+
/*!
* \brief Get the type index of the ObjectRef
* \return The type index of the ObjectRef
@@ -797,6 +866,8 @@ class ObjectRef {
/*! \brief type indicate the container type. */
using ContainerType = Object;
+ /*! \brief Whether ContainerType exactly describes the objects this ref can hold. */
+ static constexpr bool _type_container_is_exact = true;
/*! \brief Whether the reference can point to nullptr */
static constexpr bool _type_is_nullable = true;
diff --git a/include/tvm/ffi/optional.h b/include/tvm/ffi/optional.h
index b28f7105..eb63fafb 100644
--- a/include/tvm/ffi/optional.h
+++ b/include/tvm/ffi/optional.h
@@ -261,6 +261,7 @@ template <typename T>
class Optional<T, std::enable_if_t<use_ptr_based_optional_v<T>>> : public ObjectRef {
public:
using ContainerType = typename T::ContainerType;
+ static constexpr bool _type_container_is_exact = T::_type_container_is_exact;
Optional() = default;
// NOLINTBEGIN(google-explicit-constructor)
Optional(const Optional<T>& other) : ObjectRef(other) {}
diff --git a/tests/cpp/test_any.cc b/tests/cpp/test_any.cc
index bc8d36d6..698a22f4 100644
--- a/tests/cpp/test_any.cc
+++ b/tests/cpp/test_any.cc
@@ -370,6 +370,56 @@ TEST(Any, ObjectRefWithFallbackTraits) {
EXPECT_EQ(v9->value, 0);
}
+TEST(Any, AsOrThrow) {
+ Any any_int = 1;
+ EXPECT_EQ(any_int.as_or_throw<int>(), 1);
+ EXPECT_EQ(std::move(any_int).as_or_throw<int>(), 1);
+
+ Any any_obj = TInt(11);
+ EXPECT_EQ(any_obj.as_or_throw<const TIntObj*>()->value, 11);
+ EXPECT_EQ(any_obj.as_or_throw<TInt>()->value, 11);
+
+ const Any const_any_obj = TInt(12);
+ auto const_as_obj = const_any_obj.as<TInt>();
+ ASSERT_TRUE(const_as_obj.has_value()) << "Expected const Any as<TInt>() to succeed";
+ EXPECT_EQ((*const_as_obj).get()->value, 12); // NOLINT(bugprone-unchecked-optional-access)
+ EXPECT_EQ(const_any_obj.as_or_throw<TInt>()->value, 12);
+
+ auto moved_as_obj = Any(TInt(13)).as<TInt>();
+ ASSERT_TRUE(moved_as_obj.has_value()) << "Expected rvalue Any as<TInt>() to succeed";
+ EXPECT_EQ((*moved_as_obj).get()->value, 13); // NOLINT(bugprone-unchecked-optional-access)
+ EXPECT_EQ(Any(TInt(13)).as_or_throw<TInt>()->value, 13);
+
+ Any any_float = 1;
+ EXPECT_THROW(
+ {
+ try {
+ [[maybe_unused]] auto value = any_float.as_or_throw<double>();
+ } catch (const Error& error) {
+ EXPECT_EQ(error.kind(), "TypeError");
+ std::string what = error.what();
+ EXPECT_NE(what.find("Cannot treat type `int` as type `float`"), std::string::npos);
+ throw;
+ }
+ },
+ ::tvm::ffi::Error);
+
+ Any any_number = TFloat(2.5);
+ EXPECT_THROW(
+ {
+ try {
+ [[maybe_unused]] auto value = any_number.as_or_throw<TInt>();
+ } catch (const Error& error) {
+ EXPECT_EQ(error.kind(), "TypeError");
+ std::string what = error.what();
+ EXPECT_NE(what.find("Cannot treat type `test.Float` as type `test.Int`"),
+ std::string::npos);
+ throw;
+ }
+ },
+ ::tvm::ffi::Error);
+}
+
TEST(Any, CastVsAs) {
AnyView view0 = 1;
// as only runs strict check
diff --git a/tests/cpp/test_object.cc b/tests/cpp/test_object.cc
index fc621e77..255087a7 100644
--- a/tests/cpp/test_object.cc
+++ b/tests/cpp/test_object.cc
@@ -17,16 +17,63 @@
* under the License.
*/
#include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/container/list.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
+#include <tvm/ffi/optional.h>
#include "./testing_object.h"
+namespace tvm {
+namespace ffi {
+namespace testing {
+
+class TIntOrFloatRef : public ObjectRef {
+ public:
+ TIntOrFloatRef() = default;
+ explicit TIntOrFloatRef(UnsafeInit tag) : ObjectRef(tag) {}
+
+ static constexpr bool _type_is_nullable = true;
+ static constexpr bool _type_container_is_exact = false;
+ using ContainerType = Object;
+};
+
+} // namespace testing
+
+template <>
+struct TypeTraits<testing::TIntOrFloatRef>
+ : public ObjectRefTypeTraitsBase<testing::TIntOrFloatRef> {
+ TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+ return TypeTraits<testing::TInt>::CheckAnyStrict(src) ||
+ TypeTraits<testing::TFloat>::CheckAnyStrict(src);
+ }
+};
+
+} // namespace ffi
+} // namespace tvm
+
namespace {
using namespace tvm::ffi;
using namespace tvm::ffi::testing;
+static_assert(ObjectRef::_type_container_is_exact);
+static_assert(TNumber::_type_container_is_exact);
+static_assert(TInt::_type_container_is_exact);
+static_assert(Optional<TInt>::_type_container_is_exact);
+static_assert(!TIntOrFloatRef::_type_container_is_exact);
+static_assert(!Array<TInt>::_type_container_is_exact);
+static_assert(!List<TInt>::_type_container_is_exact);
+static_assert(!Map<TInt, TFloat>::_type_container_is_exact);
+static_assert(!Dict<TInt, TFloat>::_type_container_is_exact);
+static_assert(!Tuple<TInt, TFloat>::_type_container_is_exact);
+static_assert(!Variant<TInt, TFloat>::_type_container_is_exact);
+
template <typename T>
class CRTPObject : public Object {
public:
@@ -133,11 +180,83 @@ TEST(ObjectRef, as) {
EXPECT_TRUE(c.as<TIntObj>() == nullptr);
EXPECT_TRUE(c.as<TFloatObj>() == nullptr);
EXPECT_TRUE(c.as<TNumberObj>() == nullptr);
+ auto null_number = c.as<TNumber>();
+ ASSERT_TRUE(null_number.has_value()) << "Expected nullable null ObjectRef cast to succeed";
+ EXPECT_TRUE(!(*null_number).defined()); // NOLINT(bugprone-unchecked-optional-access)
+ EXPECT_TRUE(!c.as<TInt>().has_value());
EXPECT_EQ(a.as<TIntObj>()->value, 10);
EXPECT_EQ(b.as<TFloatObj>()->value, 20);
}
+TEST(ObjectRef, AsUsesTypeTraitsCheckAnyStrict) {
+ ObjectRef a = TInt(10);
+ ObjectRef b = TFloat(20);
+
+ auto int_like = a.as<TIntOrFloatRef>();
+ ASSERT_TRUE(int_like.has_value()) << "Expected TIntOrFloatRef cast from TInt to succeed";
+ EXPECT_TRUE((*int_like).as<TIntObj>() != nullptr); // NOLINT(bugprone-unchecked-optional-access)
+
+ auto float_like = b.as<TIntOrFloatRef>();
+ ASSERT_TRUE(float_like.has_value()) << "Expected TIntOrFloatRef cast from TFloat to succeed";
+ EXPECT_NE((*float_like).as<TFloatObj>(), nullptr); // NOLINT(bugprone-unchecked-optional-access)
+}
+
+TEST(ObjectRef, AsOrThrow) {
+ ObjectRef a = TInt(10);
+ ObjectRef b = TFloat(20);
+ ObjectRef c(nullptr);
+ const ObjectRef const_a = TInt(30);
+ ObjectRef movable_as = TInt(40);
+ ObjectRef movable_as_or_throw = TInt(50);
+
+ EXPECT_EQ(a.as<TIntObj>()->value, 10);
+ EXPECT_EQ(a.as_or_throw<TInt>()->value, 10);
+ EXPECT_EQ(b.as<TFloatObj>()->value, 20);
+ EXPECT_TRUE(!c.as_or_throw<TNumber>().defined());
+ auto const_as = const_a.as<TInt>();
+ ASSERT_TRUE(const_as.has_value()) << "Expected const ObjectRef as<TInt>() to succeed";
+ EXPECT_EQ((*const_as).get()->value, 30); // NOLINT(bugprone-unchecked-optional-access)
+ EXPECT_EQ(const_a.as_or_throw<TInt>()->value, 30);
+
+ auto moved_as = std::move(movable_as).as<TInt>();
+ ASSERT_TRUE(moved_as.has_value()) << "Expected rvalue ObjectRef as<TInt>() to succeed";
+ EXPECT_EQ((*moved_as).get()->value, 40); // NOLINT(bugprone-unchecked-optional-access)
+ // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
+ EXPECT_FALSE(movable_as.defined());
+
+ EXPECT_EQ(std::move(movable_as_or_throw).as_or_throw<TInt>()->value, 50);
+ // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
+ EXPECT_FALSE(movable_as_or_throw.defined());
+
+ EXPECT_THROW(
+ {
+ try {
+ [[maybe_unused]] auto value = a.as_or_throw<TFloat>();
+ } catch (const Error& error) {
+ EXPECT_EQ(error.kind(), "TypeError");
+ std::string what = error.what();
+ EXPECT_NE(what.find("Cannot treat type `test.Int` as type `test.Float`"),
+ std::string::npos);
+ throw;
+ }
+ },
+ ::tvm::ffi::Error);
+
+ EXPECT_THROW(
+ {
+ try {
+ [[maybe_unused]] auto value = c.as_or_throw<TInt>();
+ } catch (const Error& error) {
+ EXPECT_EQ(error.kind(), "TypeError");
+ std::string what = error.what();
+ EXPECT_NE(what.find("Cannot treat type `None` as type `test.Int`"), std::string::npos);
+ throw;
+ }
+ },
+ ::tvm::ffi::Error);
+}
+
TEST(ObjectRef, UnsafeInit) {
ObjectRef a(UnsafeInit{});
EXPECT_TRUE(a.get() == nullptr);