This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 695ac175b919240816c99783d2dff567839a5536 Author: tqchen <[email protected]> AuthorDate: Sun Dec 29 07:38:59 2024 +0800 [FFI] bool, device, tensor, raw str --- ffi/include/tvm/ffi/c_api.h | 2 +- ffi/include/tvm/ffi/function_details.h | 7 +- ffi/include/tvm/ffi/reflection.h | 24 ++-- ffi/include/tvm/ffi/string.h | 26 +++++ ffi/include/tvm/ffi/type_traits.h | 196 +++++++++++++++++++++++++++++++-- ffi/src/ffi/object.cc | 2 +- ffi/tests/cpp/test_any.cc | 129 ++++++++++++++++++++++ ffi/tests/cpp/test_function.cc | 4 + ffi/tests/cpp/test_reflection.cc | 6 +- 9 files changed, 365 insertions(+), 31 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 18c44e40da..5a626d03c0 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -358,7 +358,7 @@ TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInf * \brief Register type method information for rutnime reflection. * \param type_index The type index * \param info The method info to be registered. 
- * \return 0 when success, nonzero when failure happens + * \return 0 when success, nonzero when failure happens */ TVM_FFI_DLL int TVMFFIRegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info); diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index 0f6701e509..90d842d000 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -161,7 +161,7 @@ class MovableArgValueWithContext { template <typename Type> TVM_FFI_INLINE operator Type() { using TypeWithoutCR = std::remove_const_t<std::remove_reference_t<Type>>; - std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]); + std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]); if (opt.has_value()) { return std::move(*opt); } @@ -210,9 +210,8 @@ struct unpack_call_dispatcher<R, 0, index, F> { template <int index, typename F> struct unpack_call_dispatcher<void, 0, index, F> { template <typename... Args> - TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature , - const F& , int32_t , const AnyView* , Any* , - Args&&... unpacked_args) { + TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, int32_t, + const AnyView*, Any*, Args&&...
unpacked_args) { f(std::forward<Args>(unpacked_args)...); } }; diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h index 51573d1bc2..3e6a1ecc37 100644 --- a/ffi/include/tvm/ffi/reflection.h +++ b/ffi/include/tvm/ffi/reflection.h @@ -52,7 +52,7 @@ struct Type2FieldStaticTypeIndex<T, std::enable_if_t<TypeTraits<T>::enabled>> { * \returns The byteoffset */ template <typename Class, typename T> -inline int64_t GetFieldByteOffset(T Class::*field_ptr) { +inline int64_t GetFieldByteOffset(T Class::* field_ptr) { return reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr)); } @@ -61,13 +61,13 @@ class ReflectionDef { explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {} template <typename Class, typename T> - ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) { + ReflectionDef& def_readonly(const char* name, T Class::* field_ptr) { RegisterField(name, field_ptr, true); return *this; } template <typename Class, typename T> - ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) { + ReflectionDef& def_readwrite(const char* name, T Class::* field_ptr) { RegisterField(name, field_ptr, false); return *this; } @@ -76,7 +76,7 @@ class ReflectionDef { private: template <typename Class, typename T> - void RegisterField(const char* name, T Class::*field_ptr, bool readonly) { + void RegisterField(const char* name, T Class::* field_ptr, bool readonly) { TVMFFIFieldInfo info; info.name = name; info.field_static_type_index = Type2FieldStaticTypeIndex<T>::value; @@ -126,23 +126,19 @@ inline const TVMFFIFieldInfo* GetReflectionFieldInfo(const char* type_key, const */ class ReflectionFieldGetter { public: - explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) { - } + explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} - Any operator()(const Object* obj_ptr) const { + Any operator()(const Object* obj_ptr) const { Any
result; const void* addr = reinterpret_cast<const char*>(obj_ptr) + field_info_->byte_offset; - TVM_FFI_CHECK_SAFE_CALL(field_info_->getter(const_cast<void*>(addr), reinterpret_cast<TVMFFIAny*>(&result))); + TVM_FFI_CHECK_SAFE_CALL( + field_info_->getter(const_cast<void*>(addr), reinterpret_cast<TVMFFIAny*>(&result))); return result; } - Any operator()(const ObjectPtr<Object>& obj_ptr) const { - return operator()(obj_ptr.get()); - } + Any operator()(const ObjectPtr<Object>& obj_ptr) const { return operator()(obj_ptr.get()); } - Any operator()(const ObjectRef& obj) const { - return operator()(obj.get()); - } + Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); } private: const TVMFFIFieldInfo* field_info_; diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 7cbed2f7e8..debd361e7d 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -28,6 +28,7 @@ #include <tvm/ffi/error.h> #include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> +#include <tvm/ffi/type_traits.h> #include <cstddef> #include <cstring> @@ -305,6 +306,30 @@ class String : public ObjectRef { friend struct AnyEqual; }; +template <> +inline constexpr bool use_default_type_traits_v<String> = false; + +// specialize to enable implicit conversion from const char* +template <> +struct TypeTraits<String> : public ObjectRefTypeTraitsBase<String> { + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) return true; + return ObjectRefTypeTraitsBase<String>::CheckAnyView(src); + } + + static TVM_FFI_INLINE String CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) { + return String(src->v_c_str); + } + return ObjectRefTypeTraitsBase<String>::CopyFromAnyViewAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional<String> TryCopyFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) return 
String(src->v_c_str); + return ObjectRefTypeTraitsBase<String>::TryCopyFromAnyView(src); + } +}; + inline String operator+(const String& lhs, const String& rhs) { size_t lhs_size = lhs.size(); size_t rhs_size = rhs.size(); @@ -421,6 +446,7 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, s return 0; } } + } // namespace ffi } // namespace tvm diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index cf88eb7319..69b3738be0 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -44,6 +44,8 @@ inline std::string TypeIndex2TypeKey(int32_t type_index) { switch (type_index) { case TypeIndex::kTVMFFINone: return "None"; + case TypeIndex::kTVMFFIBool: + return "bool"; case TypeIndex::kTVMFFIInt: return "int"; case TypeIndex::kTVMFFIFloat: @@ -154,14 +156,14 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { CopyToAnyView(src, result); } static TVM_FFI_INLINE std::optional<Int> TryCopyFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { return std::make_optional<Int>(src->v_int64); } return std::nullopt; } static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIInt; + return src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool; } static TVM_FFI_INLINE int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { @@ -171,6 +173,36 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT static TVM_FFI_INLINE std::string TypeStr() { return "int"; } }; +// Bool type, allow implicit casting from int +template <> +struct TypeTraits<bool> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; + + static 
TVM_FFI_INLINE void CopyToAnyView(const bool& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIBool; + result->v_int64 = static_cast<int64_t>(src); + } + + static TVM_FFI_INLINE void MoveToAny(bool src, TVMFFIAny* result) { CopyToAnyView(src, result); } + + static TVM_FFI_INLINE std::optional<bool> TryCopyFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return std::make_optional<bool>(static_cast<bool>(src->v_int64)); + } + return std::nullopt; + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool; + } + + static TVM_FFI_INLINE bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return static_cast<bool>(src->v_int64); + } + + static TVM_FFI_INLINE std::string TypeStr() { return "bool"; } +}; + // Float POD values template <typename Float> struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> @@ -187,14 +219,16 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> static TVM_FFI_INLINE std::optional<Float> TryCopyFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIFloat) { return std::make_optional<Float>(src->v_float64); - } else if (src->type_index == TypeIndex::kTVMFFIInt) { + } else if (src->type_index == TypeIndex::kTVMFFIInt || + src->type_index == TypeIndex::kTVMFFIBool) { return std::make_optional<Float>(src->v_int64); } return std::nullopt; } static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIFloat || src->type_index == TypeIndex::kTVMFFIInt; + return src->type_index == TypeIndex::kTVMFFIFloat || src->type_index == TypeIndex::kTVMFFIInt || + src->type_index == TypeIndex::kTVMFFIBool; } static TVM_FFI_INLINE Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { @@ -244,11 +278,154 @@ struct TypeTraits<void*> : 
public TypeTraitsBase { static TVM_FFI_INLINE std::string TypeStr() { return "void*"; } }; +// DataType +template <> +struct TypeTraits<DLDataType> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; + + static TVM_FFI_INLINE void CopyToAnyView(const DLDataType& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDataType; + result->v_dtype = src; + } + + static TVM_FFI_INLINE void MoveToAny(DLDataType src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDataType; + result->v_dtype = src; + } + + static TVM_FFI_INLINE std::optional<DLDataType> TryCopyFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIDataType) { + return src->v_dtype; + } + return std::nullopt; + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIDataType; + } + + static TVM_FFI_INLINE DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return src->v_dtype; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "DataType"; } +}; + +// Device +template <> +struct TypeTraits<DLDevice> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice; + + static TVM_FFI_INLINE void CopyToAnyView(const DLDevice& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDevice; + result->v_device = src; + } + + static TVM_FFI_INLINE void MoveToAny(DLDevice src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDevice; + result->v_device = src; + } + + static TVM_FFI_INLINE std::optional<DLDevice> TryCopyFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIDevice) { + return src->v_device; + } + return std::nullopt; + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIDevice; + } + + static TVM_FFI_INLINE DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + 
return src->v_device; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "Device"; } +}; + +// DLTensor*, requirement: not nullable, do not retain ownership +template <> +struct TypeTraits<DLTensor*> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; + + static TVM_FFI_INLINE void CopyToAnyView(DLTensor* src, TVMFFIAny* result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIDLTensorPtr; + result->v_ptr = src; + } + + static TVM_FFI_INLINE void MoveToAny(DLTensor* src, TVMFFIAny* result) { + TVM_FFI_THROW(RuntimeError) + << "DLTensor* cannot be held in Any as it does not retain ownership, use NDArray instead"; + } + + static TVM_FFI_INLINE std::optional<DLTensor*> TryCopyFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { + return static_cast<DLTensor*>(src->v_ptr); + } + return std::nullopt; + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; + } + + static TVM_FFI_INLINE DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return static_cast<DLTensor*>(src->v_ptr); + } + + static TVM_FFI_INLINE std::string TypeStr() { return "DLTensor*"; } +}; + +// const char*, requirement: not nullable, do not retain ownership +template <> +struct TypeTraits<const char*> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr; + + static TVM_FFI_INLINE void CopyToAnyView(const char* src, TVMFFIAny* result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIRawStr; + result->v_c_str = src; + } + + static TVM_FFI_INLINE void MoveToAny(const char* src, TVMFFIAny* result) { + TVM_FFI_THROW(RuntimeError) + << "const char* cannot be held in Any as it does not retain ownership, use NDArray instead"; + } + + static TVM_FFI_INLINE std::optional<const char*> TryCopyFromAnyView(const TVMFFIAny* 
src) { + if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { + return static_cast<const char*>(src->v_c_str); + } + return std::nullopt; + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIRawStr; + } + + static TVM_FFI_INLINE const char* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return static_cast<const char*>(src->v_ptr); + } + + static TVM_FFI_INLINE std::string TypeStr() { return "const char*"; } +}; + +template <int N> +struct TypeTraits<char[N]> : public TypeTraitsBase { + // NOTE: only enable implicit conversion into AnyView + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr; + + static TVM_FFI_INLINE void CopyToAnyView(const char src[N], TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIRawStr; + result->v_c_str = src; + } +}; + // Traits for ObjectRef template <typename TObjRef> -struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef> && - use_default_type_traits_v<TObjRef>>> - : public TypeTraitsBase { +struct ObjectRefTypeTraitsBase : public TypeTraitsBase { static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; using ContainerType = typename TObjRef::ContainerType; @@ -292,6 +469,11 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef static TVM_FFI_INLINE std::string TypeStr() { return ContainerType::_type_key; } }; +template <typename TObjRef> +struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef> && + use_default_type_traits_v<TObjRef>>> + : public ObjectRefTypeTraitsBase<TObjRef> {}; + // Traits for ObjectPtr template <typename T> struct TypeTraits<ObjectPtr<T>> : public TypeTraitsBase { diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index a0fe306e78..1cf37ac448 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -172,7 +172,7 @@ class TypeTable { return it->second; } - Entry* GetTypeEntry(int32_t 
type_index) { + Entry* GetTypeEntry(int32_t type_index) { Entry* entry = nullptr; if (type_index >= 0 && static_cast<size_t>(type_index) < type_table_.size()) { entry = type_table_[type_index].get(); diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index 2cb3e8b95b..1fd312ecd2 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -60,6 +60,37 @@ TEST(Any, Int) { EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2); } +TEST(Any, bool) { + AnyView view0; + std::optional<bool> opt_v0 = view0.TryAs<bool>(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] bool v0 = view0; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `bool`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + AnyView view1 = true; + EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); + EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); + + int32_t int_v1 = view1; + EXPECT_EQ(int_v1, 1); + + bool v1 = false; + view0 = v1; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); + EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 0); +} + TEST(Any, Float) { AnyView view0; EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); @@ -94,6 +125,104 @@ TEST(Any, Float) { EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2); } +TEST(Any, DataType) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + std::optional<DLDataType> opt_v0 = view0.TryAs<DLDataType>(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] DLDataType v0 = view0; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `DataType`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + DLDataType dtype{kDLFloat, 32, 1}; 
+ + AnyView view1_dtype = dtype; + DLDataType dtype_v1 = view1_dtype; + EXPECT_EQ(dtype_v1.code, kDLFloat); + EXPECT_EQ(dtype_v1.bits, 32); + EXPECT_EQ(dtype_v1.lanes, 1); + + Any view2 = DLDataType{kDLInt, 16, 2}; + TVMFFIAny ffi_v2; + view2.MoveToTVMFFIAny(&ffi_v2); + EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDataType); + EXPECT_EQ(ffi_v2.v_dtype.code, kDLInt); + EXPECT_EQ(ffi_v2.v_dtype.bits, 16); + EXPECT_EQ(ffi_v2.v_dtype.lanes, 2); +} + +TEST(Any, Device) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + std::optional<DLDevice> opt_v0 = view0.TryAs<DLDevice>(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] DLDevice v0 = view0; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `Device`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + DLDevice device{kDLCUDA, 1}; + + AnyView view1_device = device; + DLDevice dtype_v1 = view1_device; + EXPECT_EQ(dtype_v1.device_type, kDLCUDA); + EXPECT_EQ(dtype_v1.device_id, 1); + + Any view2 = DLDevice{kDLCPU, 0}; + TVMFFIAny ffi_v2; + view2.MoveToTVMFFIAny(&ffi_v2); + EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDevice); + EXPECT_EQ(ffi_v2.v_device.device_type, kDLCPU); + EXPECT_EQ(ffi_v2.v_device.device_id, 0); +} + +TEST(Any, DLTensor) { + AnyView view0; + + std::optional<DLTensor*> opt_v0 = view0.TryAs<DLTensor*>(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] DLTensor* v0 = view0; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `DLTensor*`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + DLTensor dltensor; + + AnyView view1_dl = &dltensor; + DLTensor* dl_v1 = view1_dl; + EXPECT_EQ(dl_v1, &dltensor); +} + TEST(Any, Object) { 
AnyView view0; EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc index 367e613665..ff098d70c9 100644 --- a/ffi/tests/cpp/test_function.cc +++ b/ffi/tests/cpp/test_function.cc @@ -112,6 +112,10 @@ TEST(Func, FromUnpacked) { } }, ::tvm::ffi::Error); + + Function fconcact = + Function::FromUnpacked([](const String& a, const String& b) -> String { return a + b; }); + EXPECT_EQ(fconcact("abc", "def").operator String(), "abcdef"); } TEST(Func, Global) { diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index c7901a4ca3..9de0015009 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -20,6 +20,7 @@ #include <gtest/gtest.h> #include <tvm/ffi/object.h> #include <tvm/ffi/reflection.h> + #include "./testing_object.h" namespace { @@ -38,12 +39,9 @@ TEST(Reflection, GetFieldByteOffset) { EXPECT_EQ(details::GetFieldByteOffset(&A::y), 12); } - TEST(Reflection, FieldGetter) { ObjectRef a = TInt(10); - details::ReflectionFieldGetter getter( - details::GetReflectionFieldInfo("test.Int", "value") - ); + details::ReflectionFieldGetter getter(details::GetReflectionFieldInfo("test.Int", "value")); EXPECT_EQ(getter(a).operator int(), 10); } } // namespace
