This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new fbb2b5c5 [FFI] Add ref-qualified strict ObjectRef casts (#639)
fbb2b5c5 is described below

commit fbb2b5c5f6e4f5f962db2ae74c9c1fc4c412843e
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 21 06:49:40 2026 -0400

    [FFI] Add ref-qualified strict ObjectRef casts (#639)
    
    This PR refactors strict ObjectRef casting to use the Any TypeTraits
    strict-check path and adds strict throwing APIs.
    
    Main benefit:
    - `Any` and `ObjectRef` now have consistent strict `as` / `as_or_throw`
    APIs: both rely on the same `TypeTraits` strict-check semantics, both
    return optional values for probe-style `as<T>()`, and both provide
    throwing `as_or_throw<T>()` variants for required casts.
    - ObjectRef casts can now support richer compatibility checks through
    `TypeTraits<ObjectRefType>::CheckAnyStrict`, instead of being limited to
    the target ref's canonical `ContainerType` instance check.
    - `ObjectRefType::ContainerType` now has an explicit
    `_type_container_is_exact` invariant. Ordinary refs inherit `true` from
    `ObjectRef`, while richer parameterized refs such as `Array<T>`,
    `Map<K,V>`, `Tuple<...>`, and object-backed `Variant<...>` opt out
    because their accepted values are determined by `TypeTraits`, not only
    by the backing container. `GetRef` is guarded to only work for
    exact-container refs.
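
The split between exact and non-exact container refs can be pictured with a small standalone model. This is a hypothetical sketch, not tvm-ffi code. `Object`, `IntObj`, `FloatObj`, `IntRef`, `IntOrFloatRef`, `Traits<Ref>::CheckStrict`, `MakeRef`, `TryAs`, and the type indices 70/71 are all invented for illustration. `MakeRef` plays the role of a `GetRef`-style constructor guarded by `_type_container_is_exact`. `TryAs` plays the role of a traits-driven `as<Ref>()` probe.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <type_traits>

// Hypothetical miniature object model; names and type indices are illustrative only.
struct Object {
  explicit Object(int32_t idx) : type_index(idx) {}
  virtual ~Object() = default;
  int32_t type_index;
};
struct IntObj : Object {
  IntObj() : Object(70) {}
};
struct FloatObj : Object {
  FloatObj() : Object(71) {}
};

// Exact ref: "derives from ContainerType" is the whole acceptance rule.
struct IntRef {
  using ContainerType = IntObj;
  static constexpr bool _type_container_is_exact = true;
  const Object* ptr = nullptr;
};

// Non-exact ref: accepts IntObj or FloatObj, so no single ContainerType describes it.
struct IntOrFloatRef {
  using ContainerType = Object;
  static constexpr bool _type_container_is_exact = false;
  const Object* ptr = nullptr;
};

// Per-ref strict check, loosely standing in for TypeTraits<Ref>::CheckAnyStrict.
template <typename Ref>
struct Traits;
template <>
struct Traits<IntRef> {
  static bool CheckStrict(const Object* obj) { return obj->type_index == 70; }
};
template <>
struct Traits<IntOrFloatRef> {
  static bool CheckStrict(const Object* obj) {
    return obj->type_index == 70 || obj->type_index == 71;
  }
};

// GetRef-style helper: only sound when ContainerType is an exact predicate.
template <typename Ref, typename ObjType>
Ref MakeRef(const ObjType* obj) {
  static_assert(Ref::_type_container_is_exact,
                "non-exact refs must go through the traits-checked TryAs");
  static_assert(std::is_base_of_v<typename Ref::ContainerType, ObjType>,
                "object must derive from Ref::ContainerType");
  Ref ref;
  ref.ptr = obj;
  return ref;
}

// as<Ref>()-style probe: acceptance is delegated to Traits<Ref>.
// This sketch assumes obj is non-null.
template <typename Ref>
std::optional<Ref> TryAs(const Object* obj) {
  if (!Traits<Ref>::CheckStrict(obj)) return std::nullopt;
  Ref ref;
  ref.ptr = obj;
  return ref;
}
```

With these definitions, `MakeRef<IntOrFloatRef>(&some_int_obj)` is rejected at compile time by the first `static_assert`, mirroring how the guarded `GetRef` pushes richer refs toward the traits-checked cast.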
    
    API behavior:
    - `ObjectRef::as<ObjectRefType>() const&` returns
    `std::optional<ObjectRefType>`. For a non-null source, it succeeds only
    when `TypeTraits<ObjectRefType>::CheckAnyStrict` accepts the object
    through a compact temporary `TVMFFIAny` view. On success, it returns a
    ref that shares the original object pointer. On type mismatch, it
    returns `std::nullopt`.
    - `ObjectRef::as<ObjectRefType>() &&` has the same strict-check and
    `std::nullopt` behavior, but moves the object pointer into the returned
    ref on success and clears the source ref.
    - For null `ObjectRef` sources, `ObjectRef::as<ObjectRefType>()` returns
    a successful null ref when the target ref type is nullable, and returns
    `std::nullopt` for non-nullable target refs. The null path is explicit
    and is not treated as a globally cold failure path.
    - `ObjectRef::as_or_throw<ObjectRefType>() const&` and `&&` are the
    throwing forms of the same strict ObjectRef cast. They return
    `ObjectRefType` on success, preserve/move the object pointer according
    to the receiver qualifier, return a null ref for nullable null targets,
    and throw `TypeError` for non-nullable null or type-mismatch cases.
    - `Any::as_or_throw<T>() const&` and `&&` are strict reinterpretation
    helpers for `Any`. They call the corresponding strict `as<T>()` path,
    return `T` on success, and throw `TypeError` on mismatch. They do not
    run fallback conversions; use `cast<T>()` when conversion is intended.
    - Optional-returning `as` APIs intentionally do not use success/failure
    prediction annotations because `std::nullopt` can be a normal probe
    result. Throwing APIs mark only the final failure/throw path as cold,
    and `ObjectRef::as_or_throw` marks the strict-check success path as
    likely because mismatch throws.
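
The receiver-qualified behavior listed above can be sketched in standalone C++17. That behavior covers sharing on `const&`, stealing on `&&`, leaving the source untouched on mismatch, and handling null for nullable versus non-nullable targets. The names `Obj`, `Ref`, `IntRef`, `NonNullIntRef`, `kIndex`, and `kNullable` are hypothetical. `std::shared_ptr` stands in for the intrusive object pointer, and `std::runtime_error` stands in for `TypeError`. For brevity, the throwing form here simply reuses the probe. That is closer to the `Any::as_or_throw` pattern than to the inline-view `ObjectRef::as_or_throw` path.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>

// Hypothetical object: type_index 1 means "Int", anything else is a mismatch.
struct Obj {
  int type_index;
};

// Hypothetical base ref; std::shared_ptr stands in for the intrusive pointer.
struct Ref {
  std::shared_ptr<Obj> data;
  bool defined() const { return data != nullptr; }

  template <typename T>
  std::optional<T> as() const&;  // probe, shares the pointer
  template <typename T>
  std::optional<T> as() &&;  // probe, steals the pointer on success
  template <typename T>
  T as_or_throw() const&;  // required cast, throws on failure
};

// Hypothetical target refs that differ only in nullability.
struct IntRef : Ref {
  static constexpr int kIndex = 1;
  static constexpr bool kNullable = true;
};
struct NonNullIntRef : Ref {
  static constexpr int kIndex = 1;
  static constexpr bool kNullable = false;
};

template <typename T>
std::optional<T> Ref::as() const& {
  if (data == nullptr) {
    // Null source: a null ref for nullable targets, nullopt otherwise.
    if constexpr (T::kNullable) {
      return T{};
    } else {
      return std::nullopt;
    }
  }
  if (data->type_index != T::kIndex) return std::nullopt;
  T result;
  result.data = data;  // copy: *this keeps its reference
  return result;
}

template <typename T>
std::optional<T> Ref::as() && {
  if (data == nullptr) {
    if constexpr (T::kNullable) {
      return T{};
    } else {
      return std::nullopt;
    }
  }
  // On mismatch nothing is moved, so the source stays intact.
  if (data->type_index != T::kIndex) return std::nullopt;
  T result;
  result.data = std::move(data);  // move: *this becomes null
  return result;
}

template <typename T>
T Ref::as_or_throw() const& {
  std::optional<T> result = as<T>();
  if (!result.has_value()) {
    throw std::runtime_error(data == nullptr ? "null source for non-nullable target"
                                             : "type mismatch");
  }
  return *std::move(result);
}
```

Calling `std::move(r).as<IntRef>()` on a mismatched object returns `std::nullopt` and leaves `r` defined, because the move only happens after the check passes.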
    
    Implementation notes:
    - `ObjectRef::as<T>()` no longer relies on `T::ContainerType` as the
    complete runtime compatibility test. It piggy-backs on Any `TypeTraits`
    strict checks, so ref types with richer strict compatibility rules work
    consistently with `Any`.
    - The temporary `TVMFFIAny` setup is deliberately kept inline in the
    non-null ObjectRef paths. This preserves explicit null behavior and lets
    `as_or_throw` call `TypeTraits<T>::GetMismatchTypeInfo` for richer
    diagnostics instead of reducing to `as<T>()` and losing the synthesized
    Any view.
    - The compact temporary Any view is expected to be optimized away on hot
    paths; GCC/Clang assembly probes confirmed direct object-header
    loads/type-index comparisons for the checked ObjectRef paths.
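
The following is a rough sketch of why the throwing path keeps the temporary view inline. The view is built once, and both the strict check and the mismatch message read from it. Everything here is hypothetical: `AnyView`, `IntTraits` (with `MismatchInfo` standing in for `GetMismatchTypeInfo`), `TypeKeyOf`, `CastOrThrow`, and the 70/71 type indices. The real `TVMFFIAny` layout and `TypeTraits` hooks are not reproduced. Only the message format is copied from the tests in this commit.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical compact view: a (type_index, padding, object pointer) triple.
struct AnyView {
  int32_t type_index;
  uint32_t zero_padding;
  const void* v_obj;
};

// Hypothetical object header carrying only a type index.
struct Object {
  int32_t type_index;
};

// Hypothetical type-key table used for diagnostics.
inline std::string TypeKeyOf(int32_t idx) {
  switch (idx) {
    case 70: return "test.Int";
    case 71: return "test.Float";
    default: return "unknown";
  }
}

// Hypothetical traits for an Int ref: a strict check plus mismatch reporting.
struct IntTraits {
  static bool CheckStrict(const AnyView* v) { return v->type_index == 70; }
  static std::string MismatchInfo(const AnyView* v) { return TypeKeyOf(v->type_index); }
  static std::string TypeStr() { return "test.Int"; }
};

// Non-null path only: build the view once, then reuse it for check and error.
template <typename Traits>
const Object* CastOrThrow(const Object* obj) {
  AnyView view;
  view.type_index = obj->type_index;
  view.zero_padding = 0;
  view.v_obj = obj;
  if (Traits::CheckStrict(&view)) return obj;
  throw std::runtime_error("Cannot treat type `" + Traits::MismatchInfo(&view) +
                           "` as type `" + Traits::TypeStr() + "`");
}
```

Reducing the throwing form to a plain probe would discard `view` before the message is built. That is the loss the implementation note above describes.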
    
    Changes:
    - Add ref-qualified `ObjectRef::as<T>()` and
    `ObjectRef::as_or_throw<T>()` overloads for const and rvalue receivers.
    - Add `Any::as_or_throw<T>()` const/rvalue helpers.
    - Keep ObjectRef null handling explicit while using temporary
    `TVMFFIAny` views for rich `TypeTraits` checks and mismatch messages.
    - Add `_type_container_is_exact` and guard `GetRef` so it is only used
    for refs whose `ContainerType` is an exact acceptance predicate.
    - Add focused tests for const and move variants of `as<T>()` and
    `as_or_throw<T>()`, plus compile-time coverage for exact/non-exact
    container-ref flags.
    - Lower `TVM_FFI_UNSAFE_ASSUME` on GCC to an `if (!(cond)) __builtin_unreachable()`
    edge, which GCC propagates more reliably than `__attribute__((assume(...)))`.
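
The per-compiler lowering of `TVM_FFI_UNSAFE_ASSUME` can be mirrored outside the library to see how a call site uses it. `MY_UNSAFE_ASSUME`, `kObjectBegin`, and `IsObjectIndex` are hypothetical names. The value 64 is an illustrative threshold, not a claim about `kTVMFFIStaticObjectBegin`. The final `#else` no-op is an assumption, because the diff does not show that branch. As with the real macro, passing a value that violates the assumption is undefined behavior.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical macro mirroring the per-compiler lowering in base_details.h.
#if defined(__clang__)
#define MY_UNSAFE_ASSUME(cond) __builtin_assume(cond)
#elif defined(__GNUC__)
// GCC: branch to an unreachable edge so the false path is treated as impossible.
#define MY_UNSAFE_ASSUME(cond)            \
  do {                                    \
    if (!(cond)) __builtin_unreachable(); \
  } while (0)
#elif defined(_MSC_VER)
#define MY_UNSAFE_ASSUME(cond) __assume(cond)
#else
// Assumed fallback: no optimization hint at all.
#define MY_UNSAFE_ASSUME(cond) static_cast<void>(0)
#endif

// Illustrative threshold only.
constexpr int32_t kObjectBegin = 64;

// Caller contract: type_index >= kObjectBegin (violating it is undefined behavior).
inline bool IsObjectIndex(int32_t type_index) {
  MY_UNSAFE_ASSUME(type_index >= kObjectBegin);
  // With the hint in place, the optimizer may fold this comparison to `true`.
  return type_index >= kObjectBegin;
}
```

The GCC form expands to a `do { ... } while (0)` statement rather than an expression. The macro therefore only works in statement position, and every call site in this diff uses it that way.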
---
 include/tvm/ffi/any.h               | 115 +++++++++++++++++++++++++++++++++-
 include/tvm/ffi/base_details.h      |  14 ++---
 include/tvm/ffi/cast.h              |   3 +
 include/tvm/ffi/container/array.h   |   4 +-
 include/tvm/ffi/container/dict.h    |   4 +-
 include/tvm/ffi/container/list.h    |   4 +-
 include/tvm/ffi/container/map.h     |   4 +-
 include/tvm/ffi/container/tuple.h   |   4 +-
 include/tvm/ffi/container/variant.h |   3 +
 include/tvm/ffi/object.h            |  93 ++++++++++++++++++++++++----
 include/tvm/ffi/optional.h          |   1 +
 tests/cpp/test_any.cc               |  50 +++++++++++++++
 tests/cpp/test_object.cc            | 119 ++++++++++++++++++++++++++++++++++++
 13 files changed, 392 insertions(+), 26 deletions(-)

diff --git a/include/tvm/ffi/any.h b/include/tvm/ffi/any.h
index 5d754f7d..7c680ec1 100644
--- a/include/tvm/ffi/any.h
+++ b/include/tvm/ffi/any.h
@@ -141,7 +141,7 @@ class AnyView {
   template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
   TVM_FFI_INLINE T cast() const {
     std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
-    if (!opt.has_value()) {
+    if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
       TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                                << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                                << TypeTraits<T>::TypeStr() << "`";
@@ -361,6 +361,29 @@ class Any {
     }
   }
 
+  /**
+   * \brief Strictly reinterpret the Any as a type T or throw.
+   *
+   * \tparam T The type to cast to.
+   * \return The casted value.
+   * \note This function will not run fallback conversions.
+   */
+  template <typename T,
+            typename = std::enable_if_t<TypeTraits<T>::storage_enabled || std::is_same_v<T, Any>>>
+  TVM_FFI_INLINE T as_or_throw() && {
+    if constexpr (std::is_same_v<T, Any>) {
+      return std::move(*this);
+    } else {
+      std::optional<T> result = std::move(*this).template as<T>();
+      if (TVM_FFI_PREDICT_FALSE(!result.has_value())) {
+        TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+                                 << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` as type `"
+                                 << TypeTraits<T>::TypeStr() << "`";
+      }
+      return *std::move(result);
+    }
+  }
+
   /**
    * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible.
    *
@@ -382,6 +405,29 @@ class Any {
     }
   }
 
+  /**
+   * \brief Strictly reinterpret the Any as a type T or throw.
+   *
+   * \tparam T The type to cast to.
+   * \return The casted value.
+   * \note This function will not run fallback conversions.
+   */
+  template <typename T,
+            typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
+  TVM_FFI_INLINE T as_or_throw() const& {
+    if constexpr (std::is_same_v<T, Any>) {
+      return *this;
+    } else {
+      std::optional<T> result = this->as<T>();
+      if (TVM_FFI_PREDICT_FALSE(!result.has_value())) {
+        TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+                                 << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` as type `"
+                                 << TypeTraits<T>::TypeStr() << "`";
+      }
+      return *std::move(result);
+    }
+  }
+
   /*!
    * \brief Shortcut of as Object to cast to a const pointer when T is an Object.
    *
@@ -401,7 +447,7 @@ class Any {
   template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
   TVM_FFI_INLINE T cast() const& {
     std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
-    if (!opt.has_value()) {
+    if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
       TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                                << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                                << TypeTraits<T>::TypeStr() << "`";
@@ -421,7 +467,7 @@ class Any {
     }
     // slow path, try to do fallback convert
     std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
-    if (!opt.has_value()) {
+    if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
       TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                                << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                                << TypeTraits<T>::TypeStr() << "`";
@@ -824,6 +870,69 @@ struct AnyEqual {
   }
 };
 
+// Defer this definition until any.h so the throwing path can depend on
+// TVM_FFI_THROW(TypeError), while object.h stays below the error layer.
+//! \cond Doxygen_Suppress
+template <typename ObjectRefType, typename>
+TVM_FFI_INLINE ObjectRefType ObjectRef::as_or_throw() const& {
+  if (data_ != nullptr) {
+    // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+    TVMFFIAny any_data;
+    any_data.type_index = data_->type_index();
+    TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+    any_data.zero_padding = 0;
+    TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+    any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+    if (TVM_FFI_PREDICT_TRUE(TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data))) {
+      ObjectRefType result(UnsafeInit{});
+      result.data_ = data_;
+      return result;
+    } else {
+      TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+                               << TypeTraits<ObjectRefType>::GetMismatchTypeInfo(&any_data)
+                               << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+    }
+  } else {
+    if constexpr (ObjectRefType::_type_is_nullable) {
+      return ObjectRefType(UnsafeInit{});
+    } else {
+      TVM_FFI_THROW(TypeError) << "Cannot treat type `" << StaticTypeKey::kTVMFFINone
+                               << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+    }
+  }
+}
+
+template <typename ObjectRefType, typename>
+TVM_FFI_INLINE ObjectRefType ObjectRef::as_or_throw() && {
+  if (data_ != nullptr) {
+    // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+    TVMFFIAny any_data;
+    any_data.type_index = data_->type_index();
+    TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+    any_data.zero_padding = 0;
+    TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+    any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+    if (TVM_FFI_PREDICT_TRUE(TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data))) {
+      ObjectRefType result(UnsafeInit{});
+      result.data_ = std::move(data_);
+      data_ = nullptr;
+      return result;
+    } else {
+      TVM_FFI_THROW(TypeError) << "Cannot treat type `"
+                               << TypeTraits<ObjectRefType>::GetMismatchTypeInfo(&any_data)
+                               << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+    }
+  } else {
+    if constexpr (ObjectRefType::_type_is_nullable) {
+      return ObjectRefType(UnsafeInit{});
+    } else {
+      TVM_FFI_THROW(TypeError) << "Cannot treat type `" << StaticTypeKey::kTVMFFINone
+                               << "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
+    }
+  }
+}
+//! \endcond
+
 // Placed near the end because this specialization depends on error handling.
 template <>
 struct TypeTraits<uint64_t> : public TypeTraitsIntBase<uint64_t> {
diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index 27431b2b..4c2226d8 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -125,13 +125,13 @@
 #if defined(__clang__)
 #define TVM_FFI_UNSAFE_ASSUME(cond) __builtin_assume(cond)
 #elif defined(__GNUC__)
-// GCC 13+ supports __attribute__((assume(...))); fall back to the void-cast
-// no-op for older GCC where __builtin_assume is absent.
-#if __GNUC__ >= 13
-#define TVM_FFI_UNSAFE_ASSUME(cond) __attribute__((assume(cond)))
-#else
-#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
-#endif
+// GCC does not reliably propagate __attribute__((assume(...))) through the
+// returned-aggregate/helper flows used in TVM_FFI hot paths. Lower to an
+// unreachable edge instead so GCC 11/14 recover the intended codegen.
+#define TVM_FFI_UNSAFE_ASSUME(cond)       \
+  do {                                    \
+    if (!(cond)) __builtin_unreachable(); \
+  } while (0)
 #elif defined(_MSC_VER)
 #define TVM_FFI_UNSAFE_ASSUME(cond) __assume(cond)
 #else
diff --git a/include/tvm/ffi/cast.h b/include/tvm/ffi/cast.h
index 66abd664..655f1891 100644
--- a/include/tvm/ffi/cast.h
+++ b/include/tvm/ffi/cast.h
@@ -45,6 +45,9 @@ namespace ffi {
 template <typename RefType, typename ObjectType>
 inline RefType GetRef(const ObjectType* ptr) {
   using ContainerType = typename RefType::ContainerType;
+  static_assert(RefType::_type_container_is_exact,
+                "GetRef requires RefType::ContainerType to exactly describe all objects the ref "
+                "can hold; use ObjectRef::as<RefType>() for richer TypeTraits-based refs");
   static_assert(std::is_base_of_v<ContainerType, ObjectType>,
                 "Can only cast to the ref of same container type");
 
diff --git a/include/tvm/ffi/container/array.h b/include/tvm/ffi/container/array.h
index 523aee2d..82380c0a 100644
--- a/include/tvm/ffi/container/array.h
+++ b/include/tvm/ffi/container/array.h
@@ -668,8 +668,10 @@ class Array : public ObjectRef {
     return static_cast<ArrayObj*>(data_.get());
   }
 
-  /*! \brief specify container node */
+  /// \cond Doxygen_Suppress
   using ContainerType = ArrayObj;
+  static constexpr bool _type_container_is_exact = false;
+  /// \endcond
 
   /*!
    * \brief Agregate arguments into a single Array<T>
diff --git a/include/tvm/ffi/container/dict.h b/include/tvm/ffi/container/dict.h
index d28f39de..1887154f 100644
--- a/include/tvm/ffi/container/dict.h
+++ b/include/tvm/ffi/container/dict.h
@@ -266,8 +266,10 @@ class Dict : public ObjectRef {
     }
   }
 
-  /*! \brief specify container node */
+  /// \cond Doxygen_Suppress
   using ContainerType = DictObj;
+  static constexpr bool _type_container_is_exact = false;
+  /// \endcond
 
   /// \cond Doxygen_Suppress
   /*! \brief Iterator of the hash map */
diff --git a/include/tvm/ffi/container/list.h b/include/tvm/ffi/container/list.h
index 6e52be69..f2292ce9 100644
--- a/include/tvm/ffi/container/list.h
+++ b/include/tvm/ffi/container/list.h
@@ -460,8 +460,10 @@ class List : public ObjectRef {
     }
   }
 
-  /*! \brief specify container node */
+  /// \cond Doxygen_Suppress
   using ContainerType = ListObj;
+  static constexpr bool _type_container_is_exact = false;
+  /// \endcond
 
  private:
   /*!
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index 20153a06..a9003e20 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -272,8 +272,10 @@ class Map : public ObjectRef {
     }
     return GetMapObj();
   }
-  /*! \brief specify container node */
+  /// \cond Doxygen_Suppress
   using ContainerType = MapObj;
+  static constexpr bool _type_container_is_exact = false;
+  /// \endcond
 
   /// \cond Doxygen_Suppress
   /*! \brief Iterator of the hash map */
diff --git a/include/tvm/ffi/container/tuple.h b/include/tvm/ffi/container/tuple.h
index 79e402eb..090f99f0 100644
--- a/include/tvm/ffi/container/tuple.h
+++ b/include/tvm/ffi/container/tuple.h
@@ -190,8 +190,10 @@ class Tuple : public ObjectRef {
     *ptr = T(std::forward<U>(item));
   }
 
-  /*! \brief specify container node */
+  /// \cond Doxygen_Suppress
   using ContainerType = ArrayObj;
+  static constexpr bool _type_container_is_exact = false;
+  /// \endcond
 
  private:
   static ObjectPtr<ArrayObj> MakeDefaultTupleNode() {
diff --git a/include/tvm/ffi/container/variant.h b/include/tvm/ffi/container/variant.h
index 08dc764d..9082e0b2 100644
--- a/include/tvm/ffi/container/variant.h
+++ b/include/tvm/ffi/container/variant.h
@@ -107,6 +107,9 @@ class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
   using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
   static_assert(details::all_storage_enabled_v<V...>,
                 "All types used in Variant<...> must be compatible with Any");
+  /// \cond Doxygen_Suppress
+  static constexpr bool _type_container_is_exact = false;
+  /// \endcond
   /*
    * \brief Helper utility to check if the type can be contained in the variant
    */
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 2048cc5b..2e8a8f40 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -79,6 +79,7 @@ constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast<uint64_t>(1) << 32
  */
 template <typename TargetType>
 TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index);
+
 }  // namespace details
 
 /*!
@@ -747,7 +748,7 @@ class ObjectRef {
    * \return The pointer to the requested type.
    */
   template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
-  const ObjectType* as() const {
+  const ObjectType* as() const& {
     if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
       return static_cast<ObjectType*>(data_.get());
     } else {
@@ -758,27 +759,95 @@ class ObjectRef {
   /*!
    * \brief Try to downcast the ObjectRef to Optional<T> of the requested type.
    *
-   *  The function will return a std::nullopt if the cast or if the pointer is nullptr.
+   * If the cast fails, returns std::nullopt. If this ObjectRef is null, returns
+   * a null ref for nullable target refs and std::nullopt for non-nullable targets.
    *
-   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef'
+   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef.
    * \return The optional value of the requested type.
    */
   template <typename ObjectRefType,
             typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
-  TVM_FFI_INLINE std::optional<ObjectRefType> as() const {
+  TVM_FFI_INLINE std::optional<ObjectRefType> as() const& {
     if (data_ != nullptr) {
-      if (data_->IsInstance<typename ObjectRefType::ContainerType>()) {
-        ObjectRefType ref(UnsafeInit{});
-        ref.data_ = data_;
-        return ref;
-      } else {
-        return std::nullopt;
+      // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+      TVMFFIAny any_data;
+      any_data.type_index = data_->type_index();
+      TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+      any_data.zero_padding = 0;
+      TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+      any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+      if (TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data)) {
+        ObjectRefType result(UnsafeInit{});
+        result.data_ = data_;
+        return result;
+      }
+      return std::nullopt;
+    }
+    if constexpr (ObjectRefType::_type_is_nullable) {
+      return ObjectRefType(UnsafeInit{});
+    }
+    return std::nullopt;
+  }
+
+   * \brief Try to move-downcast the ObjectRef to Optional<T> of the requested type.
type.
+   *
+   * If the cast succeeds, moves the internal object pointer to the returned
+   * ObjectRefType. If the cast fails, returns std::nullopt. If this ObjectRef
+   * is null, returns a null ref for nullable target refs and std::nullopt for
+   * non-nullable targets.
+   *
+   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef.
+   * \return The optional value of the requested type.
+   */
+  template <typename ObjectRefType,
+            typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+  TVM_FFI_INLINE std::optional<ObjectRefType> as() && {
+    if (data_ != nullptr) {
+      // Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
+      TVMFFIAny any_data;
+      any_data.type_index = data_->type_index();
+      TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
+      any_data.zero_padding = 0;
+      TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
+      any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
+      if (TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data)) {
+        ObjectRefType result(UnsafeInit{});
+        result.data_ = std::move(data_);
+        data_ = nullptr;
+        return result;
       }
-    } else {
       return std::nullopt;
     }
+    if constexpr (ObjectRefType::_type_is_nullable) {
+      return ObjectRefType(UnsafeInit{});
+    }
+    return std::nullopt;
   }
 
+  /*!
+   * \brief Downcast the ObjectRef to the requested type or throw.
+   *
+   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef
+   * \return The requested value.
+   */
+  template <typename ObjectRefType,
+            typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+  TVM_FFI_INLINE ObjectRefType as_or_throw() const&;
+
+  /*!
+   * \brief Move-downcast the ObjectRef to the requested type or throw.
+   *
+   * If the cast succeeds, moves the internal object pointer to the returned
+   * ObjectRefType.
+   *
+   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef
+   * \return The requested value.
+   */
+  template <typename ObjectRefType,
+            typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+  TVM_FFI_INLINE ObjectRefType as_or_throw() &&;
+
   /*!
    * \brief Get the type index of the ObjectRef
    * \return The type index of the ObjectRef
@@ -797,6 +866,8 @@ class ObjectRef {
 
   /*! \brief type indicate the container type. */
   using ContainerType = Object;
+  /*! \brief Whether ContainerType exactly describes the objects this ref can hold. */
+  static constexpr bool _type_container_is_exact = true;
   /*! \brief Whether the reference can point to nullptr */
   static constexpr bool _type_is_nullable = true;
 
diff --git a/include/tvm/ffi/optional.h b/include/tvm/ffi/optional.h
index b28f7105..eb63fafb 100644
--- a/include/tvm/ffi/optional.h
+++ b/include/tvm/ffi/optional.h
@@ -261,6 +261,7 @@ template <typename T>
 class Optional<T, std::enable_if_t<use_ptr_based_optional_v<T>>> : public ObjectRef {
  public:
   using ContainerType = typename T::ContainerType;
+  static constexpr bool _type_container_is_exact = T::_type_container_is_exact;
   Optional() = default;
   // NOLINTBEGIN(google-explicit-constructor)
   Optional(const Optional<T>& other) : ObjectRef(other) {}
diff --git a/tests/cpp/test_any.cc b/tests/cpp/test_any.cc
index bc8d36d6..698a22f4 100644
--- a/tests/cpp/test_any.cc
+++ b/tests/cpp/test_any.cc
@@ -370,6 +370,56 @@ TEST(Any, ObjectRefWithFallbackTraits) {
   EXPECT_EQ(v9->value, 0);
 }
 
+TEST(Any, AsOrThrow) {
+  Any any_int = 1;
+  EXPECT_EQ(any_int.as_or_throw<int>(), 1);
+  EXPECT_EQ(std::move(any_int).as_or_throw<int>(), 1);
+
+  Any any_obj = TInt(11);
+  EXPECT_EQ(any_obj.as_or_throw<const TIntObj*>()->value, 11);
+  EXPECT_EQ(any_obj.as_or_throw<TInt>()->value, 11);
+
+  const Any const_any_obj = TInt(12);
+  auto const_as_obj = const_any_obj.as<TInt>();
+  ASSERT_TRUE(const_as_obj.has_value()) << "Expected const Any as<TInt>() to succeed";
+  EXPECT_EQ((*const_as_obj).get()->value, 12);  // NOLINT(bugprone-unchecked-optional-access)
+  EXPECT_EQ(const_any_obj.as_or_throw<TInt>()->value, 12);
+
+  auto moved_as_obj = Any(TInt(13)).as<TInt>();
+  ASSERT_TRUE(moved_as_obj.has_value()) << "Expected rvalue Any as<TInt>() to succeed";
+  EXPECT_EQ((*moved_as_obj).get()->value, 13);  // NOLINT(bugprone-unchecked-optional-access)
+  EXPECT_EQ(Any(TInt(13)).as_or_throw<TInt>()->value, 13);
+
+  Any any_float = 1;
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] auto value = any_float.as_or_throw<double>();
+        } catch (const Error& error) {
+          EXPECT_EQ(error.kind(), "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot treat type `int` as type `float`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  Any any_number = TFloat(2.5);
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] auto value = any_number.as_or_throw<TInt>();
+        } catch (const Error& error) {
+          EXPECT_EQ(error.kind(), "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot treat type `test.Float` as type `test.Int`"),
+                    std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+}
+
 TEST(Any, CastVsAs) {
   AnyView view0 = 1;
   // as only runs strict check
diff --git a/tests/cpp/test_object.cc b/tests/cpp/test_object.cc
index fc621e77..255087a7 100644
--- a/tests/cpp/test_object.cc
+++ b/tests/cpp/test_object.cc
@@ -17,16 +17,63 @@
  * under the License.
  */
 #include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/container/list.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/container/variant.h>
 #include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
+#include <tvm/ffi/optional.h>
 
 #include "./testing_object.h"
 
+namespace tvm {
+namespace ffi {
+namespace testing {
+
+class TIntOrFloatRef : public ObjectRef {
+ public:
+  TIntOrFloatRef() = default;
+  explicit TIntOrFloatRef(UnsafeInit tag) : ObjectRef(tag) {}
+
+  static constexpr bool _type_is_nullable = true;
+  static constexpr bool _type_container_is_exact = false;
+  using ContainerType = Object;
+};
+
+}  // namespace testing
+
+template <>
+struct TypeTraits<testing::TIntOrFloatRef>
+    : public ObjectRefTypeTraitsBase<testing::TIntOrFloatRef> {
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    return TypeTraits<testing::TInt>::CheckAnyStrict(src) ||
+           TypeTraits<testing::TFloat>::CheckAnyStrict(src);
+  }
+};
+
+}  // namespace ffi
+}  // namespace tvm
+
 namespace {
 
 using namespace tvm::ffi;
 using namespace tvm::ffi::testing;
 
+static_assert(ObjectRef::_type_container_is_exact);
+static_assert(TNumber::_type_container_is_exact);
+static_assert(TInt::_type_container_is_exact);
+static_assert(Optional<TInt>::_type_container_is_exact);
+static_assert(!TIntOrFloatRef::_type_container_is_exact);
+static_assert(!Array<TInt>::_type_container_is_exact);
+static_assert(!List<TInt>::_type_container_is_exact);
+static_assert(!Map<TInt, TFloat>::_type_container_is_exact);
+static_assert(!Dict<TInt, TFloat>::_type_container_is_exact);
+static_assert(!Tuple<TInt, TFloat>::_type_container_is_exact);
+static_assert(!Variant<TInt, TFloat>::_type_container_is_exact);
+
 template <typename T>
 class CRTPObject : public Object {
  public:
@@ -133,11 +180,83 @@ TEST(ObjectRef, as) {
   EXPECT_TRUE(c.as<TIntObj>() == nullptr);
   EXPECT_TRUE(c.as<TFloatObj>() == nullptr);
   EXPECT_TRUE(c.as<TNumberObj>() == nullptr);
+  auto null_number = c.as<TNumber>();
+  ASSERT_TRUE(null_number.has_value()) << "Expected nullable null ObjectRef cast to succeed";
+  EXPECT_TRUE(!(*null_number).defined());  // NOLINT(bugprone-unchecked-optional-access)
+  EXPECT_TRUE(!c.as<TInt>().has_value());
 
   EXPECT_EQ(a.as<TIntObj>()->value, 10);
   EXPECT_EQ(b.as<TFloatObj>()->value, 20);
 }
 
+TEST(ObjectRef, AsUsesTypeTraitsCheckAnyStrict) {
+  ObjectRef a = TInt(10);
+  ObjectRef b = TFloat(20);
+
+  auto int_like = a.as<TIntOrFloatRef>();
+  ASSERT_TRUE(int_like.has_value()) << "Expected TIntOrFloatRef cast from TInt to succeed";
+  EXPECT_TRUE((*int_like).as<TIntObj>() != nullptr);  // NOLINT(bugprone-unchecked-optional-access)
+
+  auto float_like = b.as<TIntOrFloatRef>();
+  ASSERT_TRUE(float_like.has_value()) << "Expected TIntOrFloatRef cast from TFloat to succeed";
+  EXPECT_NE((*float_like).as<TFloatObj>(), nullptr);  // NOLINT(bugprone-unchecked-optional-access)
+}
+
+TEST(ObjectRef, AsOrThrow) {
+  ObjectRef a = TInt(10);
+  ObjectRef b = TFloat(20);
+  ObjectRef c(nullptr);
+  const ObjectRef const_a = TInt(30);
+  ObjectRef movable_as = TInt(40);
+  ObjectRef movable_as_or_throw = TInt(50);
+
+  EXPECT_EQ(a.as<TIntObj>()->value, 10);
+  EXPECT_EQ(a.as_or_throw<TInt>()->value, 10);
+  EXPECT_EQ(b.as<TFloatObj>()->value, 20);
+  EXPECT_TRUE(!c.as_or_throw<TNumber>().defined());
+  auto const_as = const_a.as<TInt>();
+  ASSERT_TRUE(const_as.has_value()) << "Expected const ObjectRef as<TInt>() to succeed";
+  EXPECT_EQ((*const_as).get()->value, 30);  // NOLINT(bugprone-unchecked-optional-access)
+  EXPECT_EQ(const_a.as_or_throw<TInt>()->value, 30);
+
+  auto moved_as = std::move(movable_as).as<TInt>();
+  ASSERT_TRUE(moved_as.has_value()) << "Expected rvalue ObjectRef as<TInt>() to succeed";
+  EXPECT_EQ((*moved_as).get()->value, 40);  // NOLINT(bugprone-unchecked-optional-access)
+  // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
+  EXPECT_FALSE(movable_as.defined());
+
+  EXPECT_EQ(std::move(movable_as_or_throw).as_or_throw<TInt>()->value, 50);
+  // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
+  EXPECT_FALSE(movable_as_or_throw.defined());
+
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] auto value = a.as_or_throw<TFloat>();
+        } catch (const Error& error) {
+          EXPECT_EQ(error.kind(), "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot treat type `test.Int` as type `test.Float`"),
+                    std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+
+  EXPECT_THROW(
+      {
+        try {
+          [[maybe_unused]] auto value = c.as_or_throw<TInt>();
+        } catch (const Error& error) {
+          EXPECT_EQ(error.kind(), "TypeError");
+          std::string what = error.what();
+          EXPECT_NE(what.find("Cannot treat type `None` as type `test.Int`"), std::string::npos);
+          throw;
+        }
+      },
+      ::tvm::ffi::Error);
+}
+
 TEST(ObjectRef, UnsafeInit) {
   ObjectRef a(UnsafeInit{});
   EXPECT_TRUE(a.get() == nullptr);
