This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9aa5bf7ad8 [FFI][REFACTOR] Update to distinguish as and cast (#17979)
9aa5bf7ad8 is described below
commit 9aa5bf7ad8defbd7c4ddd933abdc2896513d1d3a
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed May 14 19:52:30 2025 -0700
[FFI][REFACTOR] Update to distinguish as and cast (#17979)
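This PR updates the Any system to distinguish `as` from `cast`:
- `as<T>()` runs a strict type check and never performs type conversion; it returns std::nullopt (or nullptr for the Object-pointer overload) on mismatch.
- `try_cast<T>()` / `cast<T>()` attempt type conversion when the stored type does not match exactly; `try_cast` returns std::nullopt on failure, while `cast` throws a TypeError.
We also renamed the type-trait hooks to match the new naming (e.g. `CheckAnyStorage` -> `CheckAnyStrict`, `CopyFromAnyStorageAfterCheck` -> `CopyFromAnyViewAfterCheck`, `MoveFromAnyStorageAfterCheck` -> `MoveFromAnyAfterCheck`, `TryConvertFromAnyView` -> `TryCastFromAnyView`).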
---
ffi/include/tvm/ffi/any.h | 136 ++++++++++++++----
ffi/include/tvm/ffi/container/array.h | 38 +++---
ffi/include/tvm/ffi/container/map.h | 37 ++---
ffi/include/tvm/ffi/container/tuple.h | 27 ++--
ffi/include/tvm/ffi/container/variant.h | 18 +--
ffi/include/tvm/ffi/dtype.h | 8 +-
ffi/include/tvm/ffi/function.h | 10 +-
ffi/include/tvm/ffi/function_details.h | 2 +-
ffi/include/tvm/ffi/rvalue_ref.h | 9 +-
ffi/include/tvm/ffi/string.h | 7 +-
ffi/include/tvm/ffi/type_traits.h | 152 +++++++++++----------
ffi/src/ffi/ndarray.cc | 2 +-
ffi/tests/cpp/test_any.cc | 33 +++++
ffi/tests/cpp/test_dtype.cc | 4 +-
ffi/tests/cpp/test_string.cc | 31 +++--
ffi/tests/cpp/test_variant.cc | 1 -
include/tvm/ir/attrs.h | 10 +-
include/tvm/ir/transform.h | 4 +-
include/tvm/runtime/data_type.h | 13 +-
include/tvm/topi/utils.h | 2 +-
src/arith/analyzer.cc | 2 +-
src/meta_schedule/database/database_utils.cc | 6 +-
.../postproc/rewrite_parallel_vectorize_unroll.cc | 8 +-
src/meta_schedule/utils.h | 6 +-
src/node/serialization.cc | 4 +-
src/runtime/c_runtime_api.cc | 2 +-
src/runtime/contrib/json/json_runtime.h | 2 +-
src/runtime/relax_vm/builtin.cc | 8 +-
src/runtime/relax_vm/ndarray_cache_support.cc | 2 +-
src/runtime/relax_vm/vm.cc | 2 +-
src/runtime/rpc/rpc_channel.cc | 2 +-
src/script/ir_builder/tir/ir.cc | 2 +-
src/support/ffi_testing.cc | 2 +-
src/target/target.cc | 22 +--
src/target/target_kind.cc | 2 +-
src/te/operation/create_primfunc.cc | 2 +-
src/tir/ir/function.cc | 2 -
src/tir/ir/stmt.cc | 5 +-
src/tir/ir/utils.cc | 68 ---------
src/tir/ir/utils.h | 47 -------
src/tir/op/op.cc | 4 +-
src/tir/schedule/concrete_schedule.cc | 8 +-
src/tir/schedule/instruction_traits.h | 4 +-
src/tir/schedule/primitive/block_annotate.cc | 2 +-
src/tir/schedule/trace.cc | 4 +-
src/tir/transforms/inject_permuted_layout.cc | 2 +-
src/tir/transforms/lower_opaque_block.cc | 4 +-
src/topi/transform.cc | 2 +-
48 files changed, 384 insertions(+), 386 deletions(-)
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 4ec72d6846..7897d62898 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -99,14 +99,41 @@ class AnyView {
return *this;
}
+ /*!
+ * \brief Try to see if we can reinterpret the AnyView to as T object.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value, or std::nullopt if the cast is not possible.
+ * \note This function won't try run type conversion (use try_cast for that
purpose).
+ */
template <typename T, typename =
std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE std::optional<T> as() const {
- return TypeTraits<T>::TryConvertFromAnyView(&data_);
+ if (TypeTraits<T>::CheckAnyStrict(&data_)) {
+ return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
+ } else {
+ return std::optional<T>(std::nullopt);
+ }
+ }
+ /*
+ * \brief Shortcut of as Object to cast to a const pointer when T is an
Object.
+ *
+ * \tparam T The object type.
+ * \return The requested pointer, returns nullptr if type mismatches.
+ */
+ template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object,
T>>>
+ TVM_FFI_INLINE const T* as() const {
+ return this->as<const T*>().value_or(nullptr);
}
+ /**
+ * \brief Cast to a type T.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value, or throws an exception if the cast is not
possible.
+ */
template <typename T, typename =
std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const {
- std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
+ std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_)
<< "` to `"
@@ -115,16 +142,17 @@ class AnyView {
return *std::move(opt);
}
- /*
- * \brief Shortcut of as Object to cast to a const pointer when T is an
Object.
+ /*!
+ * \brief Try to cast to a type T, return std::nullopt if the cast is not
possible.
*
- * \tparam T The object type.
- * \return The requested pointer, returns nullptr if type mismatches.
+ * \tparam T The type to cast to.
+ * \return The casted value, or std::nullopt if the cast is not possible.
*/
- template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object,
T>>>
- TVM_FFI_INLINE const T* as() const {
- return this->as<const T*>().value_or(nullptr);
+ template <typename T, typename =
std::enable_if_t<TypeTraits<T>::convert_enabled>>
+ TVM_FFI_INLINE std::optional<T> try_cast() const {
+ return TypeTraits<T>::TryCastFromAnyView(&data_);
}
+
// comparison with nullptr
TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept {
return data_.type_index == TypeIndex::kTVMFFINone;
@@ -269,13 +297,45 @@ class Any {
return *this;
}
+ /**
+ * \brief Try to reinterpret the Any as a type T, return std::nullopt if it
is not possible.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value, or std::nullopt if the cast is not possible.
+ * \note This function won't try to run type conversion (use try_cast for
that purpose).
+ */
+ template <typename T,
+ typename = std::enable_if_t<TypeTraits<T>::storage_enabled ||
std::is_same_v<T, Any>>>
+ TVM_FFI_INLINE std::optional<T> as() && {
+ if constexpr (std::is_same_v<T, Any>) {
+ return std::move(*this);
+ } else {
+ if (TypeTraits<T>::CheckAnyStrict(&data_)) {
+ return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
+ } else {
+ return std::optional<T>(std::nullopt);
+ }
+ }
+ }
+
+ /**
+ * \brief Try to reinterpret the Any as a type T, return std::nullopt if it
is not possible.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value, or std::nullopt if the cast is not possible.
+ * \note This function won't try to run type conversion (use try_cast for
that purpose).
+ */
template <typename T,
typename = std::enable_if_t<TypeTraits<T>::convert_enabled ||
std::is_same_v<T, Any>>>
- TVM_FFI_INLINE std::optional<T> as() const {
+ TVM_FFI_INLINE std::optional<T> as() const& {
if constexpr (std::is_same_v<T, Any>) {
return *this;
} else {
- return TypeTraits<T>::TryConvertFromAnyView(&data_);
+ if (TypeTraits<T>::CheckAnyStrict(&data_)) {
+ return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
+ } else {
+ return std::optional<T>(std::nullopt);
+ }
}
}
@@ -286,13 +346,18 @@ class Any {
* \return The requested pointer, returns nullptr if type mismatches.
*/
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object,
T>>>
- TVM_FFI_INLINE const T* as() const {
+ TVM_FFI_INLINE const T* as() const& {
return this->as<const T*>().value_or(nullptr);
}
+ /**
+ * \brief Cast to a type T, throw an exception if the cast is not possible.
+ *
+ * \tparam T The type to cast to.
+ */
template <typename T, typename =
std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const& {
- std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
+ std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_)
<< "` to `"
@@ -301,13 +366,18 @@ class Any {
return *std::move(opt);
}
+ /**
+ * \brief Cast to a type T, throw an exception if the cast is not possible.
+ *
+ * \tparam T The type to cast to.
+ */
template <typename T, typename =
std::enable_if_t<TypeTraits<T>::storage_enabled>>
TVM_FFI_INLINE T cast() && {
- if (TypeTraits<T>::CheckAnyStorage(&data_)) {
- return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&data_);
+ if (TypeTraits<T>::CheckAnyStrict(&data_)) {
+ return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
}
// slow path, try to do fallback convert
- std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_);
+ std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_)
<< "` to `"
@@ -316,6 +386,22 @@ class Any {
return *std::move(opt);
}
+ /**
+ * \brief Try to cast to a type T.
+ *
+ * \tparam T The type to cast to.
+ * \return The casted value, or std::nullopt if the cast is not possible.
+ * \note use STL name since it to be more consistent with cast API.
+ */
+ template <typename T,
+ typename = std::enable_if_t<TypeTraits<T>::convert_enabled ||
std::is_same_v<T, Any>>>
+ TVM_FFI_INLINE std::optional<T> try_cast() const {
+ if constexpr (std::is_same_v<T, Any>) {
+ return *this;
+ } else {
+ return TypeTraits<T>::TryCastFromAnyView(&data_);
+ }
+ }
/*
* \brief Check if the two Any are same type and value in shallow comparison.
* \param other The other Any
@@ -412,23 +498,23 @@ struct AnyUnsafe : public ObjectUnsafe {
}
template <typename T>
- static TVM_FFI_INLINE bool CheckAnyStorage(const Any& ref) {
- return TypeTraits<T>::CheckAnyStorage(&(ref.data_));
+ static TVM_FFI_INLINE bool CheckAnyStrict(const Any& ref) {
+ return TypeTraits<T>::CheckAnyStrict(&(ref.data_));
}
template <typename T>
- static TVM_FFI_INLINE T CopyFromAnyStorageAfterCheck(const Any& ref) {
+ static TVM_FFI_INLINE T CopyFromAnyViewAfterCheck(const Any& ref) {
if constexpr (!std::is_same_v<T, Any>) {
- return TypeTraits<T>::CopyFromAnyStorageAfterCheck(&(ref.data_));
+ return TypeTraits<T>::CopyFromAnyViewAfterCheck(&(ref.data_));
} else {
return ref;
}
}
template <typename T>
- static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) {
+ static TVM_FFI_INLINE T MoveFromAnyAfterCheck(Any&& ref) {
if constexpr (!std::is_same_v<T, Any>) {
- return TypeTraits<T>::MoveFromAnyStorageAfterCheck(&(ref.data_));
+ return TypeTraits<T>::MoveFromAnyAfterCheck(&(ref.data_));
} else {
return std::move(ref);
}
@@ -461,7 +547,7 @@ struct AnyHash {
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
const BytesObjBase* src_str =
- details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const
BytesObjBase*>(src);
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
BytesObjBase*>(src);
return details::StableHashBytes(src_str->data, src_str->size);
} else {
return src.data_.v_uint64;
@@ -487,9 +573,9 @@ struct AnyEqual {
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
const BytesObjBase* lhs_str =
- details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const
BytesObjBase*>(lhs);
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
BytesObjBase*>(lhs);
const BytesObjBase* rhs_str =
- details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const
BytesObjBase*>(rhs);
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
BytesObjBase*>(rhs);
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size) == 0;
}
return false;
diff --git a/ffi/include/tvm/ffi/container/array.h
b/ffi/include/tvm/ffi/container/array.h
index ce00436043..30402d9ae6 100644
--- a/ffi/include/tvm/ffi/container/array.h
+++ b/ffi/include/tvm/ffi/container/array.h
@@ -386,9 +386,7 @@ class Array : public ObjectRef {
// iterators
struct ValueConverter {
using ResultType = T;
- static T convert(const Any& n) {
- return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(n);
- }
+ static T convert(const Any& n) { return
details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(n); }
};
using iterator = details::IterAdapter<ValueConverter, const Any*>;
@@ -427,7 +425,7 @@ class Array : public ObjectRef {
if (i < 0 || i >= p->size_) {
TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size "
<< p->size_;
}
- return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*(p->begin() +
i));
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin() + i));
}
/*! \return The size of the array */
@@ -451,7 +449,7 @@ class Array : public ObjectRef {
if (p == nullptr || p->size_ == 0) {
TVM_FFI_THROW(IndexError) << "cannot index a empty array";
}
- return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*(p->begin()));
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin()));
}
/*! \return The last element of the array */
@@ -460,7 +458,7 @@ class Array : public ObjectRef {
if (p == nullptr || p->size_ == 0) {
TVM_FFI_THROW(IndexError) << "cannot index a empty array";
}
- return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*(p->end() -
1));
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->end() - 1));
}
public:
@@ -835,7 +833,7 @@ class Array : public ObjectRef {
// no other shared copies of the array.
auto arr = static_cast<ArrayObj*>(data.get());
for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
- T value = details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*it);
+ T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it);
// reset the original value to nullptr, to ensure unique ownership
it->reset();
T mapped = fmap(std::move(value));
@@ -860,7 +858,7 @@ class Array : public ObjectRef {
// `T`.
bool all_identical = true;
for (; it != arr->end(); it++) {
- U mapped =
fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*it));
+ U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it));
if (!(*it).same_as(mapped)) {
// At least one mapped element is different than the
// original. Therefore, prepare the output array,
@@ -914,7 +912,7 @@ class Array : public ObjectRef {
// so we can either start or resume the iteration from that point,
// with no further checks on the result.
for (; it != arr->end(); it++) {
- U mapped =
fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck<T>(*it));
+ U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it));
output->SetItem(it - arr->begin(), std::move(mapped));
}
@@ -952,7 +950,7 @@ inline constexpr bool use_default_type_traits_v<Array<T>> =
false;
template <typename T>
struct TypeTraits<Array<T>> : public ObjectRefTypeTraitsBase<Array<T>> {
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
- using ObjectRefTypeTraitsBase<Array<T>>::CopyFromAnyStorageAfterCheck;
+ using ObjectRefTypeTraitsBase<Array<T>>::CopyFromAnyViewAfterCheck;
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) {
@@ -962,10 +960,10 @@ struct TypeTraits<Array<T>> : public
ObjectRefTypeTraitsBase<Array<T>> {
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
for (size_t i = 0; i < n->size(); i++) {
const Any& any_v = (*n)[i];
- // CheckAnyStorage is cheaper than as<T>
- if (details::AnyUnsafe::CheckAnyStorage<T>(any_v)) continue;
+ // CheckAnyStrict is cheaper than try_cast<T>
+ if (details::AnyUnsafe::CheckAnyStrict<T>(any_v)) continue;
// try see if p is convertible to T
- if (any_v.as<T>()) continue;
+ if (any_v.try_cast<T>()) continue;
// now report the accurate mismatch information
return "Array[index " + std::to_string(i) + ": " +
details::AnyUnsafe::GetMismatchTypeInfo<T>(any_v) + "]";
@@ -975,7 +973,7 @@ struct TypeTraits<Array<T>> : public
ObjectRefTypeTraitsBase<Array<T>> {
TVM_FFI_UNREACHABLE();
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) return false;
if constexpr (std::is_same_v<T, Any>) {
return true;
@@ -983,13 +981,13 @@ struct TypeTraits<Array<T>> : public
ObjectRefTypeTraitsBase<Array<T>> {
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
for (size_t i = 0; i < n->size(); i++) {
const Any& any_v = (*n)[i];
- if (!details::AnyUnsafe::CheckAnyStorage<T>(any_v)) return false;
+ if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v)) return false;
}
return true;
}
}
- static TVM_FFI_INLINE std::optional<Array<T>> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<Array<T>> TryCastFromAnyView(const
TVMFFIAny* src) {
// try to run conversion.
if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt;
if constexpr (!std::is_same_v<T, Any>) {
@@ -997,20 +995,20 @@ struct TypeTraits<Array<T>> : public
ObjectRefTypeTraitsBase<Array<T>> {
bool storage_check = [&]() {
for (size_t i = 0; i < n->size(); i++) {
const Any& any_v = (*n)[i];
- if (!details::AnyUnsafe::CheckAnyStorage<T>(any_v)) return false;
+ if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v)) return false;
}
return true;
}();
// fast path, if storage check passes, we can return the array directly.
if (storage_check) {
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
// slow path, try to run a conversion to Array<T>
Array<T> result;
result.reserve(n->size());
for (size_t i = 0; i < n->size(); i++) {
const Any& any_v = (*n)[i];
- if (auto opt_v = any_v.as<T>()) {
+ if (auto opt_v = any_v.try_cast<T>()) {
result.push_back(*std::move(opt_v));
} else {
return std::nullopt;
@@ -1018,7 +1016,7 @@ struct TypeTraits<Array<T>> : public
ObjectRefTypeTraitsBase<Array<T>> {
}
return result;
} else {
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
}
diff --git a/ffi/include/tvm/ffi/container/map.h
b/ffi/include/tvm/ffi/container/map.h
index d2dd1a54d7..a738c1e229 100644
--- a/ffi/include/tvm/ffi/container/map.h
+++ b/ffi/include/tvm/ffi/container/map.h
@@ -1354,7 +1354,7 @@ class Map : public ObjectRef {
* \return the corresonding element.
*/
const V at(const K& key) const {
- return
details::AnyUnsafe::CopyFromAnyStorageAfterCheck<V>(GetMapObj()->at(key));
+ return
details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(GetMapObj()->at(key));
}
/*!
* \brief Read element from map.
@@ -1402,7 +1402,7 @@ class Map : public ObjectRef {
if (iter == GetMapObj()->end()) {
return std::nullopt;
}
- return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<V>(iter->second);
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(iter->second);
}
void erase(const K& key) { CopyOnWrite()->erase(key); }
@@ -1445,8 +1445,8 @@ class Map : public ObjectRef {
/*! \brief De-reference iterators */
reference operator*() const {
auto& kv = *itr;
- return
std::make_pair(details::AnyUnsafe::CopyFromAnyStorageAfterCheck<K>(kv.first),
-
details::AnyUnsafe::CopyFromAnyStorageAfterCheck<V>(kv.second));
+ return
std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck<K>(kv.first),
+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(kv.second));
}
/*! \brief Prefix self increment, e.g. ++iter */
iterator& operator++() {
@@ -1513,7 +1513,7 @@ inline constexpr bool use_default_type_traits_v<Map<K,
V>> = false;
template <typename K, typename V>
struct TypeTraits<Map<K, V>> : public ObjectRefTypeTraitsBase<Map<K, V>> {
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap;
- using ObjectRefTypeTraitsBase<Map<K, V>>::CopyFromAnyStorageAfterCheck;
+ using ObjectRefTypeTraitsBase<Map<K, V>>::CopyFromAnyViewAfterCheck;
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIMap) {
@@ -1523,14 +1523,15 @@ struct TypeTraits<Map<K, V>> : public
ObjectRefTypeTraitsBase<Map<K, V>> {
const MapObj* n = reinterpret_cast<const MapObj*>(src->v_obj);
for (const auto& kv : *n) {
if constexpr (!std::is_same_v<K, Any>) {
- if (!details::AnyUnsafe::CheckAnyStorage<K>(kv.first) &&
!kv.first.as<K>().has_value()) {
+ if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first) &&
+ !kv.first.try_cast<K>().has_value()) {
return "Map[some key is " +
details::AnyUnsafe::GetMismatchTypeInfo<K>(kv.first) +
", V]";
}
}
if constexpr (!std::is_same_v<V, Any>) {
- if (!details::AnyUnsafe::CheckAnyStorage<V>(kv.second) &&
- !kv.second.as<V>().has_value()) {
+ if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second) &&
+ !kv.second.try_cast<V>().has_value()) {
return "Map[K, some value is " +
details::AnyUnsafe::GetMismatchTypeInfo<V>(kv.second) +
"]";
}
@@ -1541,7 +1542,7 @@ struct TypeTraits<Map<K, V>> : public
ObjectRefTypeTraitsBase<Map<K, V>> {
TVM_FFI_UNREACHABLE();
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIMap) return false;
if constexpr (std::is_same_v<K, Any> && std::is_same_v<V, Any>) {
return true;
@@ -1549,44 +1550,44 @@ struct TypeTraits<Map<K, V>> : public
ObjectRefTypeTraitsBase<Map<K, V>> {
const MapObj* n = reinterpret_cast<const MapObj*>(src->v_obj);
for (const auto& kv : *n) {
if constexpr (!std::is_same_v<K, Any>) {
- if (!details::AnyUnsafe::CheckAnyStorage<K>(kv.first)) return false;
+ if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first)) return false;
}
if constexpr (!std::is_same_v<V, Any>) {
- if (!details::AnyUnsafe::CheckAnyStorage<V>(kv.second)) return false;
+ if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second)) return false;
}
}
return true;
}
}
- static TVM_FFI_INLINE std::optional<Map<K, V>> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<Map<K, V>> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt;
if constexpr (!std::is_same_v<K, Any> || !std::is_same_v<V, Any>) {
const MapObj* n = reinterpret_cast<const MapObj*>(src->v_obj);
bool storage_check = [&]() {
for (const auto& kv : *n) {
if constexpr (!std::is_same_v<K, Any>) {
- if (!details::AnyUnsafe::CheckAnyStorage<K>(kv.first)) return
false;
+ if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first)) return false;
}
if constexpr (!std::is_same_v<V, Any>) {
- if (!details::AnyUnsafe::CheckAnyStorage<V>(kv.second)) return
false;
+ if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second)) return
false;
}
}
return true;
}();
// fast path, if storage check passes, we can return the array directly.
- if (storage_check) return CopyFromAnyStorageAfterCheck(src);
+ if (storage_check) return CopyFromAnyViewAfterCheck(src);
// slow path, we need to create a new map and convert to the target type.
Map<K, V> ret;
for (const auto& kv : *n) {
- auto k = kv.first.as<K>();
- auto v = kv.second.as<V>();
+ auto k = kv.first.try_cast<K>();
+ auto v = kv.second.try_cast<V>();
if (!k.has_value() || !v.has_value()) return std::nullopt;
ret.Set(*std::move(k), *std::move(v));
}
return ret;
} else {
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
}
diff --git a/ffi/include/tvm/ffi/container/tuple.h
b/ffi/include/tvm/ffi/container/tuple.h
index 641d006a94..9fba4b0e06 100644
--- a/ffi/include/tvm/ffi/container/tuple.h
+++ b/ffi/include/tvm/ffi/container/tuple.h
@@ -98,7 +98,7 @@ class Tuple : public ObjectRef {
static_assert(I < sizeof...(Types), "Tuple index out of bounds");
using ReturnType = std::tuple_element_t<I, std::tuple<Types...>>;
const Any* ptr = GetArrayObj()->begin() + I;
- return details::AnyUnsafe::CopyFromAnyStorageAfterCheck<ReturnType>(*ptr);
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<ReturnType>(*ptr);
}
/*!
@@ -179,7 +179,7 @@ inline constexpr bool
use_default_type_traits_v<Tuple<Types...>> = false;
template <typename... Types>
struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types...>> {
- using ObjectRefTypeTraitsBase<Tuple<Types...>>::CopyFromAnyStorageAfterCheck;
+ using ObjectRefTypeTraitsBase<Tuple<Types...>>::CopyFromAnyViewAfterCheck;
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) {
@@ -196,7 +196,7 @@ struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types.
static TVM_FFI_INLINE std::string GetMismatchTypeInfoHelper(const Any* arr) {
if constexpr (!std::is_same_v<T, Any>) {
const Any& any_v = arr[I];
- if (!details::AnyUnsafe::CheckAnyStorage<T>(any_v) &&
!(any_v.as<T>().has_value())) {
+ if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v) &&
!(any_v.try_cast<T>().has_value())) {
// now report the accurate mismatch information
return "Array[index " + std::to_string(I) + ": " +
details::AnyUnsafe::GetMismatchTypeInfo<T>(any_v) + "]";
@@ -209,39 +209,38 @@ struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types.
TVM_FFI_UNREACHABLE();
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) return false;
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
if (n->size() != sizeof...(Types)) return false;
const TVMFFIAny* ffi_any_arr = reinterpret_cast<const
TVMFFIAny*>(n->begin());
- return CheckAnyStorageHelper<0, Types...>(ffi_any_arr);
+ return CheckAnyStrictHelper<0, Types...>(ffi_any_arr);
}
template <size_t I, typename T, typename... Rest>
- static TVM_FFI_INLINE bool CheckAnyStorageHelper(const TVMFFIAny* src_arr) {
+ static TVM_FFI_INLINE bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) {
if constexpr (!std::is_same_v<T, Any>) {
- if (!TypeTraits<T>::CheckAnyStorage(src_arr + I)) {
+ if (!TypeTraits<T>::CheckAnyStrict(src_arr + I)) {
return false;
}
}
if constexpr (sizeof...(Rest) > 0) {
- return CheckAnyStorageHelper<I + 1, Rest...>(src_arr);
+ return CheckAnyStrictHelper<I + 1, Rest...>(src_arr);
}
return true;
}
- static TVM_FFI_INLINE std::optional<Tuple<Types...>> TryConvertFromAnyView(
- const TVMFFIAny* src //
+ static TVM_FFI_INLINE std::optional<Tuple<Types...>>
TryCastFromAnyView(const TVMFFIAny* src //
) {
if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt;
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
if (n->size() != sizeof...(Types)) return std::nullopt;
// fast path, storage is already in the right type
- if (CheckAnyStorage(src)) {
- return CopyFromAnyStorageAfterCheck(src);
+ if (CheckAnyStrict(src)) {
+ return CopyFromAnyViewAfterCheck(src);
}
// slow path, try to convert to each type to match the tuple storage need.
- Array<Any> arr = TypeTraits<Array<Any>>::CopyFromAnyStorageAfterCheck(src);
+ Array<Any> arr = TypeTraits<Array<Any>>::CopyFromAnyViewAfterCheck(src);
Any* ptr = arr.CopyOnWrite()->MutableBegin();
if (TryConvertElements<0, Types...>(ptr)) {
return
Tuple<Types...>(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(arr));
@@ -252,7 +251,7 @@ struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types.
template <size_t I, typename T, typename... Rest>
static TVM_FFI_INLINE bool TryConvertElements(Any* arr) {
if constexpr (!std::is_same_v<T, Any>) {
- if (auto opt_convert = arr[I].as<T>()) {
+ if (auto opt_convert = arr[I].try_cast<T>()) {
arr[I] = *std::move(opt_convert);
} else {
return false;
diff --git a/ffi/include/tvm/ffi/container/variant.h
b/ffi/include/tvm/ffi/container/variant.h
index c2b0688900..373d0aaa70 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -70,7 +70,7 @@ class VariantBase<true> : public ObjectRef {
explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
explicit VariantBase(Any other)
- :
ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck<ObjectRef>(std::move(other)))
{}
+ :
ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(other)))
{}
TVM_FFI_INLINE void SetData(ObjectPtr<Object> other) { data_ =
std::move(other); }
@@ -201,22 +201,22 @@ struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
return TypeTraitsBase::GetMismatchTypeInfo(src);
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
- return (TypeTraits<V>::CheckAnyStorage(src) || ...);
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
+ return (TypeTraits<V>::CheckAnyStrict(src) || ...);
}
- static TVM_FFI_INLINE Variant<V...> CopyFromAnyStorageAfterCheck(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE Variant<V...> CopyFromAnyViewAfterCheck(const
TVMFFIAny* src) {
return Variant<V...>(Any(AnyView::CopyFromTVMFFIAny(*src)));
}
- static TVM_FFI_INLINE Variant<V...> MoveFromAnyStorageAfterCheck(TVMFFIAny*
src) {
+ static TVM_FFI_INLINE Variant<V...> MoveFromAnyAfterCheck(TVMFFIAny* src) {
return
Variant<V...>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
}
- static TVM_FFI_INLINE std::optional<Variant<V...>>
TryConvertFromAnyView(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<Variant<V...>> TryCastFromAnyView(const
TVMFFIAny* src) {
// fast path, storage is already in the right type
- if (CheckAnyStorage(src)) {
- return CopyFromAnyStorageAfterCheck(src);
+ if (CheckAnyStrict(src)) {
+ return CopyFromAnyViewAfterCheck(src);
}
// More expensive path, try to convert to each type, in order of
declaration
return TryVariantTypes<V...>(src);
@@ -224,7 +224,7 @@ struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
template <typename VariantType, typename... Rest>
static TVM_FFI_INLINE std::optional<Variant<V...>> TryVariantTypes(const
TVMFFIAny* src) {
- if (auto opt_convert =
TypeTraits<VariantType>::TryConvertFromAnyView(src)) {
+ if (auto opt_convert = TypeTraits<VariantType>::TryCastFromAnyView(src)) {
return Variant<V...>(*std::move(opt_convert));
}
if constexpr (sizeof...(Rest) > 0) {
diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index a1a6b58afa..954c77fdfa 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -140,20 +140,20 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
result->v_dtype = src;
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIDataType;
}
- static TVM_FFI_INLINE DLDataType CopyFromAnyStorageAfterCheck(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
return src->v_dtype;
}
- static TVM_FFI_INLINE std::optional<DLDataType> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<DLDataType> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIDataType) {
return src->v_dtype;
}
// enable string to dtype auto conversion
- if (auto opt_str = TypeTraits<std::string>::TryConvertFromAnyView(src)) {
+ if (auto opt_str = TypeTraits<std::string>::TryCastFromAnyView(src)) {
return StringToDLDataType(*opt_str);
}
return std::nullopt;
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index 2bf6b960d7..5162df6d83 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -735,17 +735,17 @@ struct TypeTraits<TypedFunction<FType>> : public
TypeTraitsBase {
TypeTraits<Function>::MoveToAny(std::move(src.packed()), result);
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIFunction;
}
- static TVM_FFI_INLINE TypedFunction<FType>
CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) {
- return
TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyStorageAfterCheck(src));
+ static TVM_FFI_INLINE TypedFunction<FType> CopyFromAnyViewAfterCheck(const
TVMFFIAny* src) {
+ return
TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyViewAfterCheck(src));
}
- static TVM_FFI_INLINE std::optional<TypedFunction<FType>>
TryConvertFromAnyView(
+ static TVM_FFI_INLINE std::optional<TypedFunction<FType>> TryCastFromAnyView(
const TVMFFIAny* src) {
- std::optional<Function> opt =
TypeTraits<Function>::TryConvertFromAnyView(src);
+ std::optional<Function> opt =
TypeTraits<Function>::TryCastFromAnyView(src);
if (opt.has_value()) {
return TypedFunction<FType>(*std::move(opt));
} else {
diff --git a/ffi/include/tvm/ffi/function_details.h
b/ffi/include/tvm/ffi/function_details.h
index 3e7f9be140..6391c4ebba 100644
--- a/ffi/include/tvm/ffi/function_details.h
+++ b/ffi/include/tvm/ffi/function_details.h
@@ -137,7 +137,7 @@ class ArgValueWithContext {
} else if constexpr (std::is_same_v<TypeWithoutCR, Any>) {
return Any(args_[arg_index_]);
} else {
- std::optional<TypeWithoutCR> opt = args_[arg_index_].as<TypeWithoutCR>();
+ std::optional<TypeWithoutCR> opt =
args_[arg_index_].try_cast<TypeWithoutCR>();
if (!opt.has_value()) {
TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny();
TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" <<
arg_index_
diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h
index b300436fee..3b939d2df5 100644
--- a/ffi/include/tvm/ffi/rvalue_ref.h
+++ b/ffi/include/tvm/ffi/rvalue_ref.h
@@ -114,8 +114,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
}
}
- static TVM_FFI_INLINE std::optional<RValueRef<TObjRef>>
TryConvertFromAnyView(
- const TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<RValueRef<TObjRef>>
TryCastFromAnyView(const TVMFFIAny* src) {
// first try rvalue conversion
if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) {
ObjectPtr<Object>* rvalue_ref =
reinterpret_cast<ObjectPtr<Object>*>(src->v_ptr);
@@ -123,10 +122,10 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
tmp_any.type_index = rvalue_ref->get()->type_index();
tmp_any.v_obj = reinterpret_cast<TVMFFIObject*>(rvalue_ref->get());
// fast path, storage type matches, direct move the rvalue ref
- if (TypeTraits<TObjRef>::CheckAnyStorage(&tmp_any)) {
+ if (TypeTraits<TObjRef>::CheckAnyStrict(&tmp_any)) {
return RValueRef<TObjRef>(TObjRef(std::move(*rvalue_ref)));
}
- if (std::optional<TObjRef> opt =
TypeTraits<TObjRef>::TryConvertFromAnyView(&tmp_any)) {
+ if (std::optional<TObjRef> opt =
TypeTraits<TObjRef>::TryCastFromAnyView(&tmp_any)) {
// object type does not match up, we need to try to convert the object
// in this case we do not move the original rvalue ref since
conversion creates a copy
return RValueRef<TObjRef>(*std::move(opt));
@@ -134,7 +133,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
return std::nullopt;
}
// try lvalue conversion
- if (std::optional<TObjRef> opt =
TypeTraits<TObjRef>::TryConvertFromAnyView(src)) {
+ if (std::optional<TObjRef> opt =
TypeTraits<TObjRef>::TryCastFromAnyView(src)) {
return RValueRef<TObjRef>(*std::move(opt));
} else {
return std::nullopt;
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index 44491f800b..c3eceff905 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -430,8 +430,8 @@ struct TypeTraits<const char*> : public TypeTraitsBase {
// when we need to move to any, convert to owned object first
ObjectRefTypeTraitsBase<String>::MoveToAny(String(src), result);
}
- // Do not allow const char* in a container, so we do not need CheckAnyStorage
- static TVM_FFI_INLINE std::optional<const char*> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ // Do not allow const char* in a container, so we do not need CheckAnyStrict
+ static TVM_FFI_INLINE std::optional<const char*> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIRawStr) {
return static_cast<const char*>(src->v_c_str);
}
@@ -458,8 +458,7 @@ struct TypeTraits<TVMFFIByteArray*> : public TypeTraitsBase
{
ObjectRefTypeTraitsBase<Bytes>::MoveToAny(Bytes(*src), result);
}
- static TVM_FFI_INLINE std::optional<TVMFFIByteArray*> TryConvertFromAnyView(
- const TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<TVMFFIByteArray*>
TryCastFromAnyView(const TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) {
return static_cast<TVMFFIByteArray*>(src->v_ptr);
}
diff --git a/ffi/include/tvm/ffi/type_traits.h
b/ffi/include/tvm/ffi/type_traits.h
index d350aea82a..02c9a90edc 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -43,25 +43,25 @@ namespace ffi {
*
* - CopyToAnyView: Convert a value T to AnyView
* - MoveToAny: Move a value to Any
- * - CheckAnyStorage: Check if a Any stores a result of MoveToAny of current T.
- * - CopyFromAnyStorageAfterCheck: Copy a value T from Any storage after we
pass CheckAnyStorage.
- * - MoveFromAnyStorageAfterCheck: Move a value T from Any storage after we
pass CheckAnyStorage.
- * - TryConvertFromAnyView: Convert a AnyView to a T, we may apply type
conversion.
- * - GetMismatchTypeInfo: Get the type key of a type when
TryConvertFromAnyView fails.
+ * - CheckAnyStrict: Check if an Any stores a result of CopyToAnyView of
current T.
+ * - CopyFromAnyViewAfterCheck: Copy a value T from Any view after we pass
CheckAnyStrict.
+ * - MoveFromAnyAfterCheck: Move a value T from Any storage after we pass
CheckAnyStrict.
+ * - TryCastFromAnyView: Convert an AnyView to a T, possibly applying type
conversion.
+ * - GetMismatchTypeInfo: Get the type key of a type when TryCastFromAnyView
fails.
* - TypeStr: Get the type key of a type
*
- * It is possible that CheckAnyStorage is false but TryConvertFromAnyView
still works.
+ * It is possible that CheckAnyStrict is false but TryCastFromAnyView still
works.
*
- * For example, when Any x stores int, TypeTraits<float>::CheckAnyStorage(x)
will be false,
- * but TypeTraits<float>::TryConvertFromAnyView(x) will return a corresponding
float value
+ * For example, when Any x stores int, TypeTraits<float>::CheckAnyStrict(x)
will be false,
+ * but TypeTraits<float>::TryCastFromAnyView(x) will return a corresponding
float value
* via type conversion.
*
- * CheckAnyStorage is mainly used in recursive container such as Array<T> to
+ * CheckAnyStrict is mainly used in recursive containers such as Array<T> to
* decide if a new Array needed to be created via recursive conversion,
* or we can use the current container as is when converting to Array<T>.
*
* A container array: Array<T> satisfies the following invariant:
- * - `all(TypeTraits<T>::CheckAnyStorage(x) for x in the array)`.
+ * - `all(TypeTraits<T>::CheckAnyStrict(x) for x in the array)`.
*/
template <typename, typename = void>
struct TypeTraits {
@@ -85,7 +85,7 @@ struct TypeTraitsBase {
static constexpr bool convert_enabled = true;
static constexpr bool storage_enabled = true;
// get mismatched type when result mismatches the trait.
- // this function is called after TryConvertFromAnyView fails
+ // this function is called after TryCastFromAnyView fails
// to get more detailed type information in runtime
// especially when the error involves nested container type
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny*
source) {
@@ -132,17 +132,17 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase
{
result->v_int64 = 0;
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFINone;
}
- static TVM_FFI_INLINE std::nullptr_t CopyFromAnyStorageAfterCheck(const
TVMFFIAny*) {
+ static TVM_FFI_INLINE std::nullptr_t CopyFromAnyViewAfterCheck(const
TVMFFIAny*) {
return nullptr;
}
- static TVM_FFI_INLINE std::nullptr_t
MoveFromAnyStorageAfterCheck(TVMFFIAny*) { return nullptr; }
+ static TVM_FFI_INLINE std::nullptr_t MoveFromAnyAfterCheck(TVMFFIAny*) {
return nullptr; }
- static TVM_FFI_INLINE std::optional<std::nullptr_t>
TryConvertFromAnyView(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<std::nullptr_t> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return nullptr;
}
@@ -179,20 +179,20 @@ struct TypeTraits<StrictBool> : public TypeTraitsBase {
CopyToAnyView(src, result);
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIBool;
}
- static TVM_FFI_INLINE StrictBool CopyFromAnyStorageAfterCheck(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
return static_cast<bool>(src->v_int64);
}
- static TVM_FFI_INLINE StrictBool MoveFromAnyStorageAfterCheck(TVMFFIAny*
src) {
+ static TVM_FFI_INLINE StrictBool MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<StrictBool> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<StrictBool> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIBool) {
return StrictBool(static_cast<bool>(src->v_int64));
}
@@ -214,20 +214,20 @@ struct TypeTraits<bool> : public TypeTraitsBase {
static TVM_FFI_INLINE void MoveToAny(bool src, TVMFFIAny* result) {
CopyToAnyView(src, result); }
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIBool;
}
- static TVM_FFI_INLINE bool CopyFromAnyStorageAfterCheck(const TVMFFIAny*
src) {
+ static TVM_FFI_INLINE bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return static_cast<bool>(src->v_int64);
}
- static TVM_FFI_INLINE bool MoveFromAnyStorageAfterCheck(TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<bool> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<bool> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index ==
TypeIndex::kTVMFFIBool) {
return static_cast<bool>(src->v_int64);
}
@@ -249,21 +249,21 @@ struct TypeTraits<Int,
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) {
CopyToAnyView(src, result); }
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
- // NOTE: CheckAnyStorage is always strict and should be consistent with
MoveToAny
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
+ // NOTE: CheckAnyStrict is always strict and should be consistent with
MoveToAny
return src->type_index == TypeIndex::kTVMFFIInt;
}
- static TVM_FFI_INLINE Int CopyFromAnyStorageAfterCheck(const TVMFFIAny* src)
{
+ static TVM_FFI_INLINE Int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return static_cast<Int>(src->v_int64);
}
- static TVM_FFI_INLINE Int MoveFromAnyStorageAfterCheck(TVMFFIAny* src) {
+ static TVM_FFI_INLINE Int MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<Int> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<Int> TryCastFromAnyView(const TVMFFIAny*
src) {
if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index ==
TypeIndex::kTVMFFIBool) {
return Int(src->v_int64);
}
@@ -286,21 +286,21 @@ struct TypeTraits<Float,
std::enable_if_t<std::is_floating_point_v<Float>>>
static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) {
CopyToAnyView(src, result); }
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
- // NOTE: CheckAnyStorage is always strict and should be consistent with
MoveToAny
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
+ // NOTE: CheckAnyStrict is always strict and should be consistent with
MoveToAny
return src->type_index == TypeIndex::kTVMFFIFloat;
}
- static TVM_FFI_INLINE Float CopyFromAnyStorageAfterCheck(const TVMFFIAny*
src) {
+ static TVM_FFI_INLINE Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return static_cast<Float>(src->v_float64);
}
- static TVM_FFI_INLINE Float MoveFromAnyStorageAfterCheck(TVMFFIAny* src) {
+ static TVM_FFI_INLINE Float MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<Float> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<Float> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIFloat) {
return Float(src->v_float64);
} else if (src->type_index == TypeIndex::kTVMFFIInt ||
@@ -326,21 +326,19 @@ struct TypeTraits<void*> : public TypeTraitsBase {
static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) {
CopyToAnyView(src, result); }
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
- // NOTE: CheckAnyStorage is always strict and should be consistent with
MoveToAny
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
+ // NOTE: CheckAnyStrict is always strict and should be consistent with
MoveToAny
return src->type_index == TypeIndex::kTVMFFIOpaquePtr;
}
- static TVM_FFI_INLINE void* CopyFromAnyStorageAfterCheck(const TVMFFIAny*
src) {
- return src->v_ptr;
- }
+ static TVM_FFI_INLINE void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src)
{ return src->v_ptr; }
- static TVM_FFI_INLINE void* MoveFromAnyStorageAfterCheck(TVMFFIAny* src) {
+ static TVM_FFI_INLINE void* MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<void*> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<void*> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) {
return static_cast<void*>(src->v_ptr);
}
@@ -368,20 +366,20 @@ struct TypeTraits<DLDevice> : public TypeTraitsBase {
result->v_device = src;
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIDevice;
}
- static TVM_FFI_INLINE DLDevice CopyFromAnyStorageAfterCheck(const TVMFFIAny*
src) {
+ static TVM_FFI_INLINE DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
return src->v_device;
}
- static TVM_FFI_INLINE DLDevice MoveFromAnyStorageAfterCheck(TVMFFIAny* src) {
+ static TVM_FFI_INLINE DLDevice MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
- return CopyFromAnyStorageAfterCheck(src);
+ return CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<DLDevice> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<DLDevice> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIDevice) {
return src->v_device;
}
@@ -404,12 +402,20 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase {
result->v_ptr = src;
}
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
+ return src->type_index == TypeIndex::kTVMFFIDLTensorPtr;
+ }
+
+ static TVM_FFI_INLINE DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ return static_cast<DLTensor*>(src->v_ptr);
+ }
+
static TVM_FFI_INLINE void MoveToAny(DLTensor*, TVMFFIAny*) {
TVM_FFI_THROW(RuntimeError)
<< "DLTensor* cannot be held in Any as it does not retain ownership,
use NDArray instead";
}
- static TVM_FFI_INLINE std::optional<DLTensor*> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<DLTensor*> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) {
return static_cast<DLTensor*>(src->v_ptr);
} else if (src->type_index == TypeIndex::kTVMFFINDArray) {
@@ -458,7 +464,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
result->v_obj = obj_ptr;
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) return true;
}
@@ -466,7 +472,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
details::IsObjectInstance<ContainerType>(src->type_index));
}
- static TVM_FFI_INLINE TObjRef CopyFromAnyStorageAfterCheck(const TVMFFIAny*
src) {
+ static TVM_FFI_INLINE TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return TObjRef(ObjectPtr<Object>(nullptr));
@@ -475,7 +481,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
return
TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj));
}
- static TVM_FFI_INLINE TObjRef MoveFromAnyStorageAfterCheck(TVMFFIAny* src) {
+ static TVM_FFI_INLINE TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return TObjRef(ObjectPtr<Object>(nullptr));
@@ -488,7 +494,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
return TObjRef(std::move(obj_ptr));
}
- static TVM_FFI_INLINE std::optional<TObjRef> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<TObjRef> TryCastFromAnyView(const
TVMFFIAny* src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return TObjRef(ObjectPtr<Object>(nullptr));
@@ -525,7 +531,7 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase {
// disable container for FallbackOnlyTraitsBase
static constexpr bool storage_enabled = false;
- static TVM_FFI_INLINE std::optional<T> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<T> TryCastFromAnyView(const TVMFFIAny*
src) {
return TryFallbackTypes<FallbackTypes...>(src);
}
@@ -534,7 +540,7 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase {
static_assert(!std::is_same_v<bool, FallbackType>,
"Using bool as FallbackType can cause bug because int will
be detected as bool, "
"use tvm::ffi::StrictBool instead");
- if (auto opt_fallback =
TypeTraits<FallbackType>::TryConvertFromAnyView(src)) {
+ if (auto opt_fallback = TypeTraits<FallbackType>::TryCastFromAnyView(src))
{
return TypeTraits<T>::ConvertFallbackValue(*std::move(opt_fallback));
}
if constexpr (sizeof...(Rest) > 0) {
@@ -557,11 +563,11 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase {
*/
template <typename ObjectRefType, typename... FallbackTypes>
struct ObjectRefWithFallbackTraitsBase : public
ObjectRefTypeTraitsBase<ObjectRefType> {
- static TVM_FFI_INLINE std::optional<ObjectRefType>
TryConvertFromAnyView(const TVMFFIAny* src) {
- if (auto opt_obj =
ObjectRefTypeTraitsBase<ObjectRefType>::TryConvertFromAnyView(src)) {
+ static TVM_FFI_INLINE std::optional<ObjectRefType> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (auto opt_obj =
ObjectRefTypeTraitsBase<ObjectRefType>::TryCastFromAnyView(src)) {
return opt_obj.value();
}
- // apply fallback types in TryConvertFromAnyView
+ // apply fallback types in TryCastFromAnyView
return TryFallbackTypes<FallbackTypes...>(src);
}
@@ -570,7 +576,7 @@ struct ObjectRefWithFallbackTraitsBase : public
ObjectRefTypeTraitsBase<ObjectRe
static_assert(!std::is_same_v<bool, FallbackType>,
"Using bool as FallbackType can cause bug because int will
be detected as bool, "
"use tvm::ffi::StrictBool instead");
- if (auto opt_fallback =
TypeTraits<FallbackType>::TryConvertFromAnyView(src)) {
+ if (auto opt_fallback = TypeTraits<FallbackType>::TryCastFromAnyView(src))
{
return
TypeTraits<ObjectRefType>::ConvertFallbackValue(*std::move(opt_fallback));
}
if constexpr (sizeof...(Rest) > 0) {
@@ -601,17 +607,17 @@ struct TypeTraits<const TObject*,
std::enable_if_t<std::is_base_of_v<Object, TOb
details::ObjectUnsafe::IncRefObjectHandle(result->v_obj);
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin &&
details::IsObjectInstance<TObject>(src->type_index);
}
- static TVM_FFI_INLINE const TObject* CopyFromAnyStorageAfterCheck(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const
TVMFFIAny* src) {
return details::ObjectUnsafe::RawObjectPtrFromUnowned<TObject>(src->v_obj);
}
- static TVM_FFI_INLINE std::optional<const TObject*>
TryConvertFromAnyView(const TVMFFIAny* src) {
- if (CheckAnyStorage(src)) return CopyFromAnyStorageAfterCheck(src);
+ static TVM_FFI_INLINE std::optional<const TObject*> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src);
return std::nullopt;
}
@@ -639,28 +645,28 @@ struct TypeTraits<Optional<T>> : public TypeTraitsBase {
}
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFINone) return true;
- return TypeTraits<T>::CheckAnyStorage(src);
+ return TypeTraits<T>::CheckAnyStrict(src);
}
- static TVM_FFI_INLINE Optional<T> CopyFromAnyStorageAfterCheck(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE Optional<T> CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return Optional<T>(std::nullopt);
}
- return TypeTraits<T>::CopyFromAnyStorageAfterCheck(src);
+ return TypeTraits<T>::CopyFromAnyViewAfterCheck(src);
}
- static TVM_FFI_INLINE Optional<T> MoveFromAnyStorageAfterCheck(TVMFFIAny*
src) {
+ static TVM_FFI_INLINE Optional<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return Optional<T>(std::nullopt);
}
- return TypeTraits<T>::MoveFromAnyStorageAfterCheck(src);
+ return TypeTraits<T>::MoveFromAnyAfterCheck(src);
}
- static TVM_FFI_INLINE std::optional<Optional<T>> TryConvertFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<Optional<T>> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFINone) return
Optional<T>(std::nullopt);
- if (std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(src)) {
+ if (std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(src)) {
return Optional<T>(*std::move(opt));
} else {
// important to be explicit here
diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc
index d4c1470566..f3c48c8ad5 100644
--- a/ffi/src/ffi/ndarray.cc
+++ b/ffi/src/ffi/ndarray.cc
@@ -32,7 +32,7 @@
TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, An
int64_t* mutable_data;
ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(),
&mutable_data);
for (int i = 0; i < args.size(); ++i) {
- if (auto opt_int = args[i].as<int64_t>()) {
+ if (auto opt_int = args[i].try_cast<int64_t>()) {
mutable_data[i] = *opt_int;
} else {
TVM_FFI_THROW(ValueError) << "Expect shape to take list of int
arguments";
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index d84cc64ae4..3ad81cd118 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -332,6 +332,39 @@ TEST(Any, ObjectRefWithFallbackTraits) {
EXPECT_EQ(v9->value, 0);
}
+TEST(Any, CastVsAs) {
+ AnyView view0 = 1;
+ // as<T> only runs the strict check; it never performs type conversion.
+ auto opt_v0 = view0.as<int64_t>();
+ EXPECT_TRUE(opt_v0.has_value());
+ EXPECT_EQ(opt_v0.value(), 1);
+
+ auto opt_v1 = view0.as<bool>();
+ EXPECT_TRUE(!opt_v1.has_value());
+ auto opt_v2 = view0.as<double>();
+ EXPECT_TRUE(!opt_v2.has_value());
+
+ // try_cast<T> will try to run the type conversion.
+ auto opt_v3 = view0.try_cast<bool>();
+ EXPECT_TRUE(opt_v3.has_value());
+ EXPECT_EQ(opt_v3.value(), 1);
+ auto opt_v4 = view0.try_cast<double>();
+ EXPECT_TRUE(opt_v4.has_value());
+ EXPECT_EQ(opt_v4.value(), 1);
+
+ Any any1 = true;
+ auto opt_v5 = any1.as<bool>();
+ EXPECT_TRUE(opt_v5.has_value());
+ EXPECT_EQ(opt_v5.value(), 1);
+
+ auto opt_v6 = any1.try_cast<int>();
+ EXPECT_TRUE(opt_v6.has_value());
+ EXPECT_EQ(opt_v6.value(), 1);
+
+ auto opt_v7 = any1.try_cast<double>();
+ EXPECT_TRUE(opt_v7.has_value());
+}
+
TEST(Any, ObjectMove) {
Any any1 = TPrimExpr("float32", 3.14);
auto v0 = std::move(any1).cast<TPrimExpr>();
diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc
index e31df8761d..620f729a66 100644
--- a/ffi/tests/cpp/test_dtype.cc
+++ b/ffi/tests/cpp/test_dtype.cc
@@ -114,14 +114,14 @@ TEST(DataType, AnyConversion) {
TEST(DataType, AnyConversionWithString) {
AnyView view0 = "float32";
- Optional<DLDataType> opt_v0 = view0.as<DLDataType>();
+ Optional<DLDataType> opt_v0 = view0.try_cast<DLDataType>();
DLDataType dtype_v0 = opt_v0.value();
EXPECT_EQ(dtype_v0.code, kDLFloat);
EXPECT_EQ(dtype_v0.bits, 32);
EXPECT_EQ(dtype_v0.lanes, 1);
Any any = String("bfloat16x2");
- Optional<DLDataType> opt_v1 = any.as<DLDataType>();
+ Optional<DLDataType> opt_v1 = any.try_cast<DLDataType>();
EXPECT_EQ(opt_v1.value().code, kDLBfloat);
EXPECT_EQ(opt_v1.value().bits, 16);
EXPECT_EQ(opt_v1.value().lanes, 2);
diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc
index 847ed6f955..a74102a953 100644
--- a/ffi/tests/cpp/test_string.cc
+++ b/ffi/tests/cpp/test_string.cc
@@ -275,22 +275,23 @@ TEST(String, Any) {
Any b = view;
EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr);
EXPECT_EQ(b.as<String>().value(), "hello");
- EXPECT_EQ(b.as<std::string>().value(), "hello");
+ EXPECT_TRUE(b.as<String>().has_value());
+ EXPECT_EQ(b.try_cast<std::string>().value(), "hello");
std::string s_world = "world";
view = s_world;
- EXPECT_EQ(view.as<std::string>().value(), "world");
+ EXPECT_EQ(view.try_cast<std::string>().value(), "world");
String s{"hello"};
Any a = s;
EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFIStr);
EXPECT_EQ(a.as<String>().value(), "hello");
- EXPECT_EQ(a.as<std::string>().value(), "hello");
+ EXPECT_EQ(a.try_cast<std::string>().value(), "hello");
Any c = "helloworld";
EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFIStr);
EXPECT_EQ(c.as<String>().value(), "helloworld");
- EXPECT_EQ(c.as<std::string>().value(), "helloworld");
+ EXPECT_EQ(c.try_cast<std::string>().value(), "helloworld");
}
TEST(String, Bytes) {
@@ -312,52 +313,52 @@ TEST(String, BytesAny) {
AnyView view = &arr;
EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIByteArrayPtr);
- EXPECT_EQ(view.as<Bytes>().value().operator std::string(), s);
+ EXPECT_EQ(view.try_cast<Bytes>().value().operator std::string(), s);
Any b = view;
EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIBytes);
- EXPECT_EQ(b.as<Bytes>().value().operator std::string(), s);
- EXPECT_EQ(b.as<std::string>().value(), s);
+ EXPECT_EQ(b.try_cast<Bytes>().value().operator std::string(), s);
+ EXPECT_EQ(b.cast<std::string>(), s);
}
TEST(String, StdString) {
std::string s1 = "test_string";
AnyView view1 = s1;
EXPECT_EQ(view1.type_index(), TypeIndex::kTVMFFIRawStr);
- EXPECT_EQ(view1.as<std::string>().value(), s1);
+ EXPECT_EQ(view1.try_cast<std::string>().value(), s1);
TVMFFIByteArray arr1{s1.data(), static_cast<size_t>(s1.size())};
AnyView view2 = &arr1;
EXPECT_EQ(view2.type_index(), TypeIndex::kTVMFFIByteArrayPtr);
- EXPECT_EQ(view2.as<std::string>().value(), s1);
+ EXPECT_EQ(view2.try_cast<std::string>().value(), s1);
Bytes bytes1 = s1;
AnyView view3 = bytes1;
EXPECT_EQ(view3.type_index(), TypeIndex::kTVMFFIBytes);
- EXPECT_EQ(view3.as<std::string>().value(), s1);
+ EXPECT_EQ(view3.try_cast<std::string>().value(), s1);
String string1 = s1;
AnyView view4 = string1;
EXPECT_EQ(view4.type_index(), TypeIndex::kTVMFFIStr);
- EXPECT_EQ(view4.as<std::string>().value(), s1);
+ EXPECT_EQ(view4.try_cast<std::string>().value(), s1);
// Test with Any
Any any1 = s1;
EXPECT_EQ(any1.type_index(), TypeIndex::kTVMFFIStr);
- EXPECT_EQ(any1.as<std::string>().value(), s1);
+ EXPECT_EQ(any1.try_cast<std::string>().value(), s1);
Any any2 = &arr1;
EXPECT_EQ(any2.type_index(), TypeIndex::kTVMFFIBytes);
- EXPECT_EQ(any2.as<std::string>().value(), s1);
+ EXPECT_EQ(any2.try_cast<std::string>().value(), s1);
Any any3 = bytes1;
EXPECT_EQ(any3.type_index(), TypeIndex::kTVMFFIBytes);
- EXPECT_EQ(any3.as<std::string>().value(), s1);
+ EXPECT_EQ(any3.try_cast<std::string>().value(), s1);
Any any4 = string1;
EXPECT_EQ(any4.type_index(), TypeIndex::kTVMFFIStr);
- EXPECT_EQ(any4.as<std::string>().value(), s1);
+ EXPECT_EQ(any4.try_cast<std::string>().value(), s1);
}
TEST(String, CAPIAccessor) {
diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc
index 17a1129087..451913c992 100644
--- a/ffi/tests/cpp/test_variant.cc
+++ b/ffi/tests/cpp/test_variant.cc
@@ -32,7 +32,6 @@ using namespace tvm::ffi::testing;
TEST(Variant, Basic) {
Variant<int, float> v1 = 1;
EXPECT_EQ(v1.get<int>(), 1);
- EXPECT_EQ(v1.as<float>().value(), 1.0f);
Variant<int, float> v2 = 2.0f;
EXPECT_EQ(v2.get<float>(), 2.0f);
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 510c5325f3..dc8e99fceb 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -604,7 +604,7 @@ inline void SetValue(T* ptr, const ffi::AnyView& val) {
template <typename T>
inline void SetIntValue(T* ptr, const ffi::AnyView& val) {
- if (auto opt_int = val.as<int64_t>()) {
+ if (auto opt_int = val.try_cast<int64_t>()) {
*ptr = static_cast<T>(opt_int.value());
} else {
IntImm expr = val.cast<IntImm>();
@@ -620,16 +620,12 @@ inline void SetValue<DataType>(DataType* ptr, const
ffi::AnyView& val) {
template <>
inline void SetValue<std::string>(std::string* ptr, const ffi::AnyView& val) {
- if (auto opt_str = val.as<std::string>()) {
- *ptr = opt_str.value();
- } else {
- LOG(FATAL) << "Expect str";
- }
+ *ptr = val.cast<std::string>();
}
template <>
inline void SetValue<double>(double* ptr, const ffi::AnyView& val) {
- if (auto opt_double = val.as<double>()) {
+ if (auto opt_double = val.try_cast<double>()) {
*ptr = opt_double.value();
} else {
ObjectRef expr = val.cast<ObjectRef>();
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 0cc9383bac..8562fbaa8f 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -273,10 +273,10 @@ class PassContext : public ObjectRef {
auto type_key = ffi::TypeIndexToTypeKey(tindex);
auto legalization = [=](ffi::Any value) -> ffi::Any {
- if (auto opt_map = value.as<Map<String, ffi::Any>>()) {
+ if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
return reflection->CreateObject(type_key, opt_map.value());
} else {
- auto opt_val = value.as<ValueType>();
+ auto opt_val = value.try_cast<ValueType>();
if (!opt_val.has_value()) {
TVM_FFI_THROW(AttributeError)
<< "Expect config " << key << " to have type " << type_key << ",
but instead get "
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 9418a0c902..b5b6ad19d2 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -375,21 +375,20 @@ struct TypeTraits<runtime::DataType> : public
TypeTraitsBase {
result->v_dtype = src;
}
- static TVM_FFI_INLINE std::optional<runtime::DataType> TryConvertFromAnyView(
- const TVMFFIAny* src) {
- auto opt_dtype = TypeTraits<DLDataType>::TryConvertFromAnyView(src);
+ static TVM_FFI_INLINE std::optional<runtime::DataType>
TryCastFromAnyView(const TVMFFIAny* src) {
+ auto opt_dtype = TypeTraits<DLDataType>::TryCastFromAnyView(src);
if (opt_dtype) {
return runtime::DataType(opt_dtype.value());
}
return std::nullopt;
}
- static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) {
- return TypeTraits<DLDataType>::CheckAnyStorage(src);
+ static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
+ return TypeTraits<DLDataType>::CheckAnyStrict(src);
}
- static TVM_FFI_INLINE runtime::DataType CopyFromAnyStorageAfterCheck(const
TVMFFIAny* src) {
- return
runtime::DataType(TypeTraits<DLDataType>::CopyFromAnyStorageAfterCheck(src));
+ static TVM_FFI_INLINE runtime::DataType CopyFromAnyViewAfterCheck(const
TVMFFIAny* src) {
+ return
runtime::DataType(TypeTraits<DLDataType>::CopyFromAnyViewAfterCheck(src));
}
static TVM_FFI_INLINE std::string TypeStr() { return
ffi::StaticTypeKey::kTVMFFIDataType; }
diff --git a/include/tvm/topi/utils.h b/include/tvm/topi/utils.h
index 368674c0b6..23ac27d134 100644
--- a/include/tvm/topi/utils.h
+++ b/include/tvm/topi/utils.h
@@ -37,7 +37,7 @@ inline Optional<Array<Integer>> ArrayOrInt(AnyView arg) {
if (arg == nullptr) {
return std::nullopt;
}
- if (auto opt_int = arg.as<int>()) {
+ if (auto opt_int = arg.try_cast<int>()) {
Array<Integer> result;
result.push_back(opt_int.value());
return result;
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 10c7676282..8dcb67190a 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -320,7 +320,7 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
});
} else if (name == "bind") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
- if (auto opt_range = args[1].as<Range>()) {
+ if (auto opt_range = args[1].try_cast<Range>()) {
self->Bind(args[0].cast<Var>(), opt_range.value());
} else {
self->Bind(args[0].cast<Var>(), args[1].cast<PrimExpr>());
diff --git a/src/meta_schedule/database/database_utils.cc
b/src/meta_schedule/database/database_utils.cc
index 75253a3680..1f39688272 100644
--- a/src/meta_schedule/database/database_utils.cc
+++ b/src/meta_schedule/database/database_utils.cc
@@ -29,7 +29,7 @@ namespace meta_schedule {
void JSONDumps(Any json_obj, std::ostringstream& os) {
if (json_obj == nullptr) {
os << "null";
- } else if (auto opt_int_imm = json_obj.as<IntImm>()) {
+ } else if (auto opt_int_imm = json_obj.try_cast<IntImm>()) {
IntImm int_imm = *std::move(opt_int_imm);
if (int_imm->dtype == DataType::Bool()) {
if (int_imm->value) {
@@ -40,7 +40,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) {
} else {
os << int_imm->value;
}
- } else if (auto opt_float_imm = json_obj.as<FloatImm>()) {
+ } else if (auto opt_float_imm = json_obj.try_cast<FloatImm>()) {
FloatImm float_imm = *std::move(opt_float_imm);
os << std::setprecision(20) << float_imm->value;
} else if (const auto* str = json_obj.as<ffi::StringObj>()) {
@@ -60,7 +60,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) {
std::vector<std::pair<String, ffi::Any>> key_values;
key_values.reserve(n);
for (const auto& kv : *dict) {
- if (auto key = kv.first.as<String>()) {
+ if (auto key = kv.first.try_cast<String>()) {
key_values.emplace_back(key.value(), kv.second);
} else {
LOG(FATAL) << "TypeError: Only string keys are supported in JSON
dumps, but got: "
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index ad60f477b3..c43c8aac12 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -106,22 +106,22 @@ bool ParseAnnotation(const Block& block,
ParsedAnnotation* parsed) {
for (const auto& ann : block->annotations) {
if (ann.first == attr::meta_schedule_parallel) {
found = true;
- if (auto opt_int_imm = ann.second.as<IntImm>()) {
+ if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->max_parallel_extent = (*opt_int_imm)->value;
}
} else if (ann.first == attr::meta_schedule_vectorize) {
found = true;
- if (auto opt_int_imm = ann.second.as<IntImm>()) {
+ if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->max_vectorize_extent = (*opt_int_imm)->value;
}
} else if (ann.first == attr::meta_schedule_unroll_explicit) {
found = true;
- if (auto opt_int_imm = ann.second.as<IntImm>()) {
+ if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->unroll_explicit = (*opt_int_imm)->value;
}
} else if (ann.first == attr::meta_schedule_unroll_implicit) {
found = true;
- if (auto opt_int_imm = ann.second.as<IntImm>()) {
+ if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->unroll_implicit = (*opt_int_imm)->value;
}
}
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index e689edb425..21483d3b98 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -424,9 +424,9 @@ inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) {
results.reserve(arr->size());
for (Any val : *arr) {
auto float_value = [&]() -> FloatImm {
- if (auto opt_int_imm = val.as<IntImm>()) {
+ if (auto opt_int_imm = val.try_cast<IntImm>()) {
return FloatImm(DataType::Float(32), (*opt_int_imm)->value);
- } else if (auto opt_float_imm = val.as<FloatImm>()) {
+ } else if (auto opt_float_imm = val.try_cast<FloatImm>()) {
return *std::move(opt_float_imm);
} else {
LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: "
<< val.GetTypeKey();
@@ -451,7 +451,7 @@ inline Array<Integer> AsIntArray(const ObjectRef& obj) {
results.reserve(arr->size());
for (Any val : *arr) {
auto int_value = [&]() -> int64_t {
- if (auto opt_int_imm = val.as<IntImm>()) {
+ if (auto opt_int_imm = val.try_cast<IntImm>()) {
return (*opt_int_imm)->value;
} else {
LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " <<
val.GetTypeKey();
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 698483d068..7f74012c0d 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -119,7 +119,7 @@ class NodeIndexer : public AttrVisitor {
} else if (auto opt_map = node.as<const ffi::MapObj*>()) {
const ffi::MapObj* n = opt_map.value();
bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
- return v.first.template as<const ffi::StringObj*>().has_value();
+ return v.first.template as<const ffi::StringObj*>();
});
if (is_str_map) {
for (const auto& kv : *n) {
@@ -280,7 +280,7 @@ class JSONAttrGetter : public AttrVisitor {
} else if (auto opt_map = node.as<const ffi::MapObj*>()) {
const ffi::MapObj* n = opt_map.value();
bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
- return v.first.template as<const ffi::StringObj*>().has_value();
+ return v.first.template as<const ffi::StringObj*>();
});
if (is_str_map) {
for (const auto& kv : *n) {
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index e7ecd0802c..751617914e 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -733,7 +733,7 @@ int TVMCbArgToReturn(TVMValue* value, int* code) {
API_BEGIN();
AnyView arg = LegacyTVMArgValueToAnyView(*value, *code);
Any rv;
- if (auto opt_rv = arg.as<tvm::ffi::RValueRef<tvm::ffi::ObjectRef>>()) {
+ if (auto opt_rv = arg.try_cast<tvm::ffi::RValueRef<tvm::ffi::ObjectRef>>()) {
rv = *std::move(*std::move(opt_rv));
} else {
rv = arg;
diff --git a/src/runtime/contrib/json/json_runtime.h
b/src/runtime/contrib/json/json_runtime.h
index 3f42e109f8..025e85263e 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -124,7 +124,7 @@ class JSONRuntimeBase : public ModuleNode {
// Bind argument tensors to data entries.
this->SetInputOutputBuffers(args);
- if (auto opt_str = rv->as<String>()) {
+ if (auto opt_str = rv->try_cast<String>()) {
String purpose = std::move(opt_str.value());
if ("debug_dump" == purpose) {
*rv = this->DebugDump();
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 399109131a..23aee70ff6 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -123,7 +123,7 @@ void MatchShape(ffi::PackedArgs args, Any* rv) {
} else {
input_shape = args[0].cast<ffi::Shape>();
}
- auto heap = args[1].as<DLTensor*>();
+ auto heap = args[1].try_cast<DLTensor*>();
int64_t* heap_data = heap.has_value() ? static_cast<int64_t*>((*heap)->data)
: nullptr;
int64_t size = args[2].cast<int64_t>();
const int64_t kBeginCode = 3;
@@ -192,7 +192,7 @@
TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue);
*/
void MakeShape(ffi::PackedArgs args, Any* rv) {
// NOTE: heap can be nullptr
- auto heap = args[0].as<DLTensor*>();
+ auto heap = args[0].try_cast<DLTensor*>();
int64_t* heap_data = heap.has_value() ? static_cast<int64_t*>((*heap)->data)
: nullptr;
int64_t size = args[1].cast<int64_t>();
const int64_t kBeginCode = 2;
@@ -235,7 +235,7 @@ void CheckTensorInfo(ffi::PackedArgs args, Any* rv) {
err_ctx = args[3].cast<Optional<String>>();
}
- auto opt_ptr = arg.as<DLTensor*>();
+ auto opt_ptr = arg.try_cast<DLTensor*>();
CHECK(opt_ptr.has_value()) << "TypeError: " << err_ctx.value_or("") << "
expect a Tensor but get "
<< arg.GetTypeKey();
@@ -422,7 +422,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.to_device")
* \return Bool
*/
bool ReadIfCond(AnyView cond) {
- if (auto opt_int = cond.as<bool>()) {
+ if (auto opt_int = cond.try_cast<bool>()) {
return opt_int.value();
}
NDArray arr = cond.cast<tvm::runtime::NDArray>();
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc
b/src/runtime/relax_vm/ndarray_cache_support.cc
index 427af3cdeb..ef60f5b870 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -364,7 +364,7 @@
TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked")
Array<String> names;
names.reserve(args.size());
for (int i = 0; i < args.size(); ++i) {
- if (!args[i].as<String>()) {
+ if (!args[i].try_cast<String>()) {
LOG(FATAL) << "ValueError: Expect string as input, but get " <<
args[i].GetTypeKey()
<< " at " << i;
}
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 67bb19fb9b..dcaa1a3726 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -788,7 +788,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame,
Instruction instr) {
}
int ret_kind = static_cast<int>(VMInstrumentReturnKind::kNoOp);
instrument_.CallPacked(call_args.data(), call_args.size(), &rv);
- if (auto opt_int = rv.as<int64_t>()) {
+ if (auto opt_int = rv.try_cast<int64_t>()) {
ret_kind = opt_int.value();
}
if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc
index a9e897f1a3..d757eb718a 100644
--- a/src/runtime/rpc/rpc_channel.cc
+++ b/src/runtime/rpc/rpc_channel.cc
@@ -41,7 +41,7 @@ size_t CallbackChannel::Send(const void* data, size_t size) {
size_t CallbackChannel::Recv(void* data, size_t size) {
Any ret = frecv_(size);
- auto opt_bytes = ret.as<ffi::Bytes>();
+ auto opt_bytes = ret.try_cast<ffi::Bytes>();
CHECK(opt_bytes.has_value()) << "CallbackChannel::Recv";
ffi::Bytes bytes = std::move(opt_bytes.value());
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index fe2e3d7d97..b4f94b2b89 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -486,7 +486,7 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray
data, DataType dtype,
AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) {
// convert POD value to PrimExpr
if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- node = node.as<PrimExpr>().value();
+ node = node.cast<PrimExpr>();
}
ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>();
n->node = node.cast<ObjectRef>();
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index baf374065d..f35c00b4fb 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -255,7 +255,7 @@ class TestingEventLogger {
};
TVM_REGISTER_GLOBAL("testing.record_event").set_body_packed([](ffi::PackedArgs
args, ffi::Any* rv) {
- if (args.size() != 0 && args[0].as<String>()) {
+ if (args.size() != 0 && args[0].try_cast<String>()) {
TestingEventLogger::ThreadLocal()->Record(args[0].cast<String>());
} else {
TestingEventLogger::ThreadLocal()->Record("X");
diff --git a/src/target/target.cc b/src/target/target.cc
index c90244654a..6885cf0cff 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -111,7 +111,7 @@ static std::vector<String> DeduplicateKeys(const
std::vector<String>& keys) {
template <class T>
static T ObjTypeCheck(const Any& obj, const std::string& expected_type) {
- auto opt = obj.as<T>();
+ auto opt = obj.try_cast<T>();
if (!opt.has_value()) {
TVM_FFI_THROW(TypeError) << "Expects type \"" << expected_type << "\", but
gets \""
<< obj.GetTypeKey() << "\" for object: " << obj;
@@ -426,7 +426,7 @@ Any TargetInternal::ParseType(const Any& obj, const
TargetKindNode::ValueTypeInf
// Parsing target
if (auto opt = obj.as<Target>()) {
return opt.value();
- } else if (auto str = obj.as<String>()) {
+ } else if (auto str = obj.try_cast<String>()) {
return Target(TargetInternal::FromString(str.value()));
} else if (const auto* ptr = obj.as<ffi::MapObj>()) {
for (const auto& kv : *ptr) {
@@ -487,9 +487,9 @@ Any TargetInternal::ParseType(const Any& obj, const
TargetKindNode::ValueTypeInf
std::string TargetInternal::StringifyAtomicType(const Any& obj) {
if (obj.type_index() == ffi::TypeIndex::kTVMFFIBool) {
- return std::to_string(obj.as<bool>().value());
+ return std::to_string(obj.cast<bool>());
} else if (obj.type_index() == ffi::TypeIndex::kTVMFFIInt) {
- return std::to_string(obj.as<int64_t>().value());
+ return std::to_string(obj.cast<int64_t>());
} else if (auto opt_str = obj.as<String>()) {
std::string s = opt_str.value();
auto u = Uninterpret(s);
@@ -761,9 +761,9 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs
args, ffi::Any* rv) {
const auto& arg = args[0];
if (auto opt_target = arg.as<Target>()) {
*rv = Target(opt_target.value());
- } else if (auto opt_str = arg.as<String>()) {
+ } else if (auto opt_str = arg.try_cast<String>()) {
*rv = Target(opt_str.value());
- } else if (auto opt_map = arg.as<Map<String, ffi::Any>>()) {
+ } else if (auto opt_map = arg.try_cast<Map<String, ffi::Any>>()) {
*rv = Target(opt_map.value());
} else {
LOG(FATAL) << "TypeError: Cannot create target with type: " <<
args[0].GetTypeKey();
@@ -850,7 +850,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String,
ffi::Any> config) {
// parse 'kind'
if (config.count(kKind)) {
- if (auto kind = config[kKind].as<String>()) {
+ if (auto kind = config[kKind].try_cast<String>()) {
target->kind = GetTargetKind(kind.value());
ICHECK(!(target->kind->preprocessor != nullptr &&
target->kind->target_parser != nullptr))
<< "Cannot use both set_attrs_preprocessor and set_target_parser";
@@ -875,7 +875,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String,
ffi::Any> config) {
}
// parse "tag"
if (config.count(kTag)) {
- if (auto tag = config[kTag].as<String>()) {
+ if (auto tag = config[kTag].try_cast<String>()) {
target->tag = tag.value();
config.erase(kTag);
} else {
@@ -893,7 +893,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String,
ffi::Any> config) {
// user provided keys
if (const auto* cfg_keys = config[kKeys].as<ffi::ArrayObj>()) {
for (const Any& e : *cfg_keys) {
- if (auto key = e.as<String>()) {
+ if (auto key = e.try_cast<String>()) {
keys.push_back(key.value());
} else {
TVM_FFI_THROW(TypeError) << "Expect 'keys' to be an array of
strings, but it "
@@ -907,7 +907,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String,
ffi::Any> config) {
}
// add device name
if (config.count(kDeviceName)) {
- if (auto device = config.at(kDeviceName).as<String>()) {
+ if (auto device = config.at(kDeviceName).try_cast<String>()) {
keys.push_back(device.value());
}
}
@@ -945,7 +945,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String,
ffi::Any> config) {
// If requested, query attributes from the device. User-specified
// parameters take precedence over queried parameters.
if (attrs.count("from_device")) {
- int device_id = attrs.at("from_device").as<int64_t>().value();
+ int device_id = attrs.at("from_device").cast<int64_t>();
attrs.erase("from_device");
auto device_params = QueryDevice(device_id, target.get());
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index ff84bbd665..6f206b567c 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -145,7 +145,7 @@ void CheckOrSetAttr(Map<String, ffi::Any>* attrs, const
String& name, const Stri
if (iter == attrs->end()) {
attrs->Set(name, value);
} else {
- auto str = (*iter).second.as<String>();
+ auto str = (*iter).second.try_cast<String>();
ICHECK(str && str.value() == value) << "ValueError: Expects \"" << name <<
"\" to be \""
<< value << "\", but gets: " <<
(*iter).second;
}
diff --git a/src/te/operation/create_primfunc.cc
b/src/te/operation/create_primfunc.cc
index 0355c8fd9b..05fd5ae64a 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -299,7 +299,7 @@ Map<String, ffi::Any> GenerateBlockAnnotations(const
te::ComputeOp& compute_op,
CreateFuncInfo* info) {
Map<String, ffi::Any> annotations;
auto mutate_attr = [&info](const ffi::Any& value) -> ffi::Any {
- if (auto tensor_value = value.as<te::Tensor>()) {
+ if (auto tensor_value = value.try_cast<te::Tensor>()) {
return info->tensor2buffers.at(tensor_value.value());
} else {
return value;
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index bbf77b7301..5f85b74c8d 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -27,8 +27,6 @@
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
-#include "utils.h"
-
namespace tvm {
namespace tir {
namespace {
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 0a7817c02a..6376d1edfd 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -27,7 +27,6 @@
#include <tvm/tir/stmt.h>
#include "buffer_common.h"
-#include "utils.h"
namespace tvm {
namespace tir {
@@ -75,9 +74,9 @@ TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed([](Any node, String attr_key, PrimExpr value, Stmt body,
Span span) {
// when node is a POD data type like int or bool, first convert to
primexpr.
if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- return AttrStmt(node.as<PrimExpr>().value(), attr_key, value, body,
span);
+ return AttrStmt(node.cast<PrimExpr>(), attr_key, value, body, span);
}
- return AttrStmt(node.as<ObjectRef>().value(), attr_key, value, body,
span);
+ return AttrStmt(node.cast<ObjectRef>(), attr_key, value, body, span);
});
TVM_REGISTER_NODE_TYPE(AttrStmtNode);
diff --git a/src/tir/ir/utils.cc b/src/tir/ir/utils.cc
deleted file mode 100644
index 65495bd8dd..0000000000
--- a/src/tir/ir/utils.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/tir/ir/utils.cc
- * \brief Utilities for manipulating TIR
- */
-#include "utils.h"
-
-#include <tvm/ir/attrs.h>
-
-namespace tvm {
-namespace tir {
-
-ffi::Any NormalizeAttributeObject(ffi::Any obj) {
- if (obj.type_index() == ffi::TypeIndex::kTVMFFIBool) {
- return Bool(obj.cast<bool>());
- } else if (auto opt_int = obj.as<int>()) {
- return Integer(opt_int.value());
- } else if (auto opt_float = obj.as<double>()) {
- return FloatImm(DataType::Float(32), opt_float.value());
- } else if (auto opt_array = obj.as<Array<ffi::Any>>()) {
- return opt_array.value().Map(NormalizeAttributeObject);
- } else if (auto opt_map = obj.as<Map<ffi::Any, ffi::Any>>()) {
- Map<ffi::Any, ffi::Any> new_map;
- bool is_same = true;
-
- for (const auto& [key, obj] : opt_map.value()) {
- ObjectRef new_obj =
NormalizeAttributeObject(obj.cast<ObjectRef>()).cast<ObjectRef>();
- is_same = is_same && obj.same_as(new_obj);
- new_map.Set(key, new_obj);
- }
-
- if (is_same) {
- return obj;
- } else {
- return new_map;
- }
- } else if (auto dict_attrs = obj.as<DictAttrs::ContainerType>()) {
- auto new_attrs = Downcast<Map<String,
ffi::Any>>(NormalizeAttributeObject(dict_attrs->dict));
- if (new_attrs.same_as(dict_attrs->dict)) {
- return GetRef<DictAttrs>(dict_attrs);
- } else {
- return DictAttrs(new_attrs);
- }
- } else {
- return obj;
- }
-}
-
-} // namespace tir
-} // namespace tvm
diff --git a/src/tir/ir/utils.h b/src/tir/ir/utils.h
deleted file mode 100644
index c19a850a70..0000000000
--- a/src/tir/ir/utils.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tir/ir/utils.h
- * \brief Utilities for manipulating TIR
- */
-#ifndef TVM_TIR_IR_UTILS_H_
-#define TVM_TIR_IR_UTILS_H_
-
-#include <tvm/tir/expr.h>
-
-namespace tvm {
-namespace tir {
-
-/* \brief Normalize an ObjectRef held
- *
- * Where possible, the IR should be normalized contain IR types. For
- * example, holding a `tir::IntImm` instead of a `runtime::Int`. In
- * attributes, this is not always possible, as attributes may refer to
- * non-IR objects.
- *
- * \param obj The attribute object to be normalized
- *
- * \returns The normalized attribute
- */
-ffi::Any NormalizeAttributeObject(ffi::Any obj);
-
-} // namespace tir
-} // namespace tvm
-#endif // TVM_TIR_IR_UTILS_H_
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 838af436a6..9752c052a1 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -1072,9 +1072,9 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace")
// expose basic functions to node namespace
TVM_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args,
ffi::Any* ret) {
- if (auto opt = args[0].as<int64_t>()) {
+ if (auto opt = args[0].try_cast<int64_t>()) {
*ret = tir::make_const(args[1].cast<DataType>(), *opt,
args[2].cast<Span>());
- } else if (auto opt = args[0].as<double>()) {
+ } else if (auto opt = args[0].try_cast<double>()) {
*ret = tir::make_const(args[1].cast<DataType>(), *opt,
args[2].cast<Span>());
} else {
LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or
bool, "
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 30644c5fd1..edaccb51d6 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -913,7 +913,7 @@ void ConcreteScheduleNode::Tensorize(const BlockRV&
block_rv, const String& intr
/******** Schedule: Annotation ********/
Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
- if (auto opt_str = ann_val.as<ffi::String>()) {
+ if (auto opt_str = ann_val.try_cast<ffi::String>()) {
return *std::move(opt_str);
}
@@ -921,10 +921,10 @@ Any
ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
return ann_val;
}
// prefer to return int/float literals for annotations
- if (auto opt_intimm = ann_val.as<IntImm>()) {
+ if (auto opt_intimm = ann_val.try_cast<IntImm>()) {
return (*std::move(opt_intimm))->value;
}
- if (auto opt_floatimm = ann_val.as<FloatImm>()) {
+ if (auto opt_floatimm = ann_val.try_cast<FloatImm>()) {
return (*std::move(opt_floatimm))->value;
}
@@ -956,7 +956,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const
ffi::Any& ann_val) {
auto value = CheckAndGetAnnotationValue(it->second);
if (const StringImmNode* imm = key.as<StringImmNode>()) {
result.Set(imm->value, value);
- } else if (auto opt_str = key.as<ffi::String>()) {
+ } else if (auto opt_str = key.try_cast<ffi::String>()) {
result.Set(opt_str.value(), value);
} else {
LOG(FATAL) << "TypeError: annotation dict key expect to be String or
StringImm";
diff --git a/src/tir/schedule/instruction_traits.h
b/src/tir/schedule/instruction_traits.h
index b22615c4e2..cbd5185ff8 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -420,9 +420,9 @@ inline void PythonAPICall::AsPythonString(const Any& obj,
std::ostream& os) {
os << "None";
} else if (const auto* str = obj.as<ffi::StringObj>()) {
os << str->data;
- } else if (const auto opt_int_imm = obj.as<IntImm>()) {
+ } else if (const auto opt_int_imm = obj.try_cast<IntImm>()) {
os << (*opt_int_imm)->value;
- } else if (const auto opt_float_imm = obj.as<FloatImm>()) {
+ } else if (const auto opt_float_imm = obj.try_cast<FloatImm>()) {
os.precision(17);
os << (*opt_float_imm)->value;
} else if (const auto* array = obj.as<ffi::ArrayObj>()) {
diff --git a/src/tir/schedule/primitive/block_annotate.cc
b/src/tir/schedule/primitive/block_annotate.cc
index 40725871cf..0e2a055d7a 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -173,7 +173,7 @@ class StorageAlignInvalidAnnotationError : public
ScheduleError {
private:
static bool IsValidAnnotation(const Block& block, const Any& anno_value) {
- return anno_value.as<ffi::Array<ffi::Tuple<int, int, int,
int>>>().has_value();
+ return anno_value.try_cast<ffi::Array<ffi::Tuple<int, int, int,
int>>>().has_value();
}
IRModule mod_;
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index f6fa76da78..883bd65ce3 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -78,7 +78,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs,
auto it = rv_map.find(input.as<Object>());
ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't
exist: " << input;
result.push_back(GetRef<ObjectRef>(it->second));
- } else if (auto expr = input.as<PrimExpr>()) { // RV: Expr
+ } else if (auto expr = input.try_cast<PrimExpr>()) { // RV: Expr
result.push_back(Substitute(expr.value(), f_subst_with_rv_map));
} else if (auto index_map = input.as<IndexMap>()) {
result.push_back(Substitute(index_map.value(), f_subst_with_rv_map));
@@ -400,7 +400,7 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule
sch) {
try {
const ffi::ArrayObj* arr = decision_entry.as<ffi::ArrayObj>();
ICHECK(arr && arr->size() == 2);
- auto arr0 = arr->at(0).as<IntImm>();
+ auto arr0 = arr->at(0).try_cast<IntImm>();
ICHECK(arr0);
index = arr0.value()->value;
decision = arr->at(1);
diff --git a/src/tir/transforms/inject_permuted_layout.cc
b/src/tir/transforms/inject_permuted_layout.cc
index a5f29f624c..8a1f4b1ff5 100644
--- a/src/tir/transforms/inject_permuted_layout.cc
+++ b/src/tir/transforms/inject_permuted_layout.cc
@@ -108,7 +108,7 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
return GetRef<String>(node) != "";
} else if (auto* node = annotation.as<IntImmNode>()) {
return node->value != 0;
- } else if (auto opt_val = annotation.as<int64_t>()) {
+ } else if (auto opt_val = annotation.try_cast<int64_t>()) {
return *opt_val != 0;
} else {
LOG(FATAL) << "Invalid permuted layout annotation: " << annotation;
diff --git a/src/tir/transforms/lower_opaque_block.cc
b/src/tir/transforms/lower_opaque_block.cc
index 067e7d1b05..9939ac9dec 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -150,9 +150,9 @@ class OpaqueBlockLower : public StmtExprMutator {
PrimExpr ConvertAttrValue(const String& key, const Any& obj) {
if (obj == nullptr) {
return PrimExpr();
- } else if (auto expr = obj.as<PrimExpr>()) {
+ } else if (auto expr = obj.try_cast<PrimExpr>()) {
return expr.value();
- } else if (auto str = obj.as<String>()) {
+ } else if (auto str = obj.try_cast<String>()) {
return std::move(StringImm(str.value()));
} else {
LOG(FATAL) << "Illegal attribute of key " << key << ", value type " <<
obj.GetTypeKey()
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index cf86242d84..450aded945 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -86,7 +86,7 @@
TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body_packed([](ffi::PackedArgs args
});
TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args,
ffi::Any* rv) {
- if (args[1].as<int>()) {
+ if (args[1].try_cast<int>()) {
*rv = split_n_sections(args[0].cast<te::Tensor>(), args[1].cast<int>(),
args[2].cast<int>());
} else {
*rv = split_indices_array(args[0].cast<te::Tensor>(),
args[1].cast<Array<Integer>>(),