This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch small-str-v1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit d74c4334dd50c9ea4927e3984cfa38f4add2851f Author: tqchen <[email protected]> AuthorDate: Sat Aug 2 12:35:43 2025 -0400 [FFI] Bring up SmallBytes along with SmallStr --- ffi/include/tvm/ffi/any.h | 19 ++- ffi/include/tvm/ffi/c_api.h | 4 +- ffi/include/tvm/ffi/object.h | 1 + ffi/include/tvm/ffi/optional.h | 39 ++++--- ffi/include/tvm/ffi/string.h | 169 +++++++++++++++------------ ffi/src/ffi/extra/structural_equal.cc | 15 ++- ffi/src/ffi/object.cc | 1 + ffi/tests/cpp/test_any.cc | 7 ++ ffi/tests/cpp/test_optional.cc | 12 ++ ffi/tests/cpp/test_string.cc | 23 +++- jvm/native/src/main/native/jni_helper_func.h | 6 +- python/tvm/ffi/cython/base.pxi | 3 +- python/tvm/ffi/cython/dtype.pxi | 2 +- python/tvm/ffi/cython/function.pxi | 11 +- src/node/repr_printer.cc | 1 + src/node/serialization.cc | 15 ++- src/relax/backend/contrib/cublas/codegen.cc | 2 +- src/relax/backend/contrib/cudnn/codegen.cc | 2 +- src/relax/backend/contrib/hipblas/codegen.cc | 2 +- src/relax/backend/contrib/nnapi/codegen.cc | 2 +- src/runtime/rpc/rpc_module.cc | 5 +- src/script/ir_builder/tir/ir.cc | 3 +- src/support/ffi_testing.cc | 2 +- src/tir/ir/stmt.cc | 3 +- src/tir/schedule/concrete_schedule.cc | 3 +- src/tir/schedule/instruction.cc | 3 +- src/tir/schedule/trace.cc | 9 +- web/src/ctypes.ts | 2 + 28 files changed, 239 insertions(+), 127 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 51d68303c2..55eff8802a 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -554,6 +554,10 @@ struct AnyHash { // so heap allocated string and on stack string will have the same hash return details::StableHashCombine(TypeIndex::kTVMFFIStr, details::StableHashSmallStrBytes(&src.data_)); + } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) { + // use byte the same type key as bytes + return details::StableHashCombine(TypeIndex::kTVMFFIBytes, + details::StableHashSmallStrBytes(&src.data_)); } else if (src.data_.type_index == TypeIndex::kTVMFFIStr || src.data_.type_index == TypeIndex::kTVMFFIBytes) { const details::BytesObjBase* src_str = @@ -598,7 +602,8 @@ struct AnyEqual { return false; } else { // type_index mismatch, if index is not string, return false - if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr) { + if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr && + lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) { return false; } // small string and normal string comparison @@ -614,6 +619,18 @@ struct AnyEqual { return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len, rhs_str->size); } + if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) { + const details::BytesObjBase* lhs_bytes = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); + return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size, + rhs.data_.small_str_len); + } + if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) { + const details::BytesObjBase* rhs_bytes = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); + return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len, + rhs_bytes->size); + } return false; } } diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index bb25dfed0d..11080a21f0 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -105,6 +105,8 @@ typedef enum { kTVMFFIObjectRValueRef = 10, /*! \brief Small string on stack */ kTVMFFISmallStr = 11, + /*! \brief Small bytes on stack */ + kTVMFFISmallBytes = 12, /*! \brief Start of statically defined objects. */ kTVMFFIStaticObjectBegin = 64, /*! @@ -917,7 +919,7 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { * \param obj The object handle. * \return The content of the small string in bytearray format. */ -inline TVMFFIByteArray TVMFFISmallStrGetContentByteArray(const TVMFFIAny* value) { +inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) { return TVMFFIByteArray{value->v_bytes, static_cast<size_t>(value->small_str_len)}; } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 74977e0216..4b7b56209a 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -61,6 +61,7 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIArray = "ffi.Array"; static constexpr const char* kTVMFFIMap = "ffi.Map"; static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; + static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; }; /*! diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index 9c9706f9b0..a52f64e483 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -54,7 +54,8 @@ inline constexpr bool use_ptr_based_optional_v = // Specialization for non-ObjectRef types. // simply fallback to std::optional template <typename T> -class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T> && !std::is_same_v<T, String>>> { +class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T> && !std::is_same_v<T, String> && + !std::is_same_v<T, Bytes>>> { public: // default constructors. Optional() = default; @@ -140,39 +141,39 @@ class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T> && !std::is_same }; // Specialization for String type, use nullptr to indicate nullopt -template <> -class Optional<String, void> { +template <typename T> +class Optional<T, std::enable_if_t<std::is_same_v<T, String> || std::is_same_v<T, Bytes>>> { public: // default constructors. Optional() = default; - Optional(const Optional<String>& other) : data_(other.data_) {} - Optional(Optional<String>&& other) : data_(std::move(other.data_)) {} + Optional(const Optional<T>& other) : data_(other.data_) {} + Optional(Optional<T>&& other) : data_(std::move(other.data_)) {} Optional(std::nullopt_t) {} // NOLINT(*) // normal value handling. - Optional(String other) // NOLINT(*) + Optional(T other) // NOLINT(*) : data_(std::move(other)) {} - TVM_FFI_INLINE Optional<String>& operator=(const Optional<String>& other) { + TVM_FFI_INLINE Optional<T>& operator=(const Optional<T>& other) { data_ = other.data_; return *this; } - TVM_FFI_INLINE Optional<String>& operator=(Optional<String>&& other) { + TVM_FFI_INLINE Optional<T>& operator=(Optional<T>&& other) { data_ = std::move(other.data_); return *this; } - TVM_FFI_INLINE Optional<String>& operator=(String other) { + TVM_FFI_INLINE Optional<T>& operator=(T other) { data_ = std::move(other); return *this; } - TVM_FFI_INLINE Optional<String>& operator=(std::nullopt_t) { - String(details::BytesBaseCell(std::nullopt)).swap(data_); + TVM_FFI_INLINE Optional<T>& operator=(std::nullopt_t) { + T(details::BytesBaseCell(std::nullopt)).swap(data_); return *this; } - TVM_FFI_INLINE const String& value() const& { + TVM_FFI_INLINE const T& value() const& { if (data_.data_ == std::nullopt) { TVM_FFI_THROW(RuntimeError) << "Back optional access"; } @@ -186,8 +187,8 @@ class Optional<String, void> { return std::move(data_); } - template <typename U = String> - TVM_FFI_INLINE String value_or(U&& default_value) const { + template <typename U = T> + TVM_FFI_INLINE T value_or(U&& default_value) const { if (data_.data_ == std::nullopt) { return std::forward<U>(default_value); } @@ -198,7 +199,7 @@ class Optional<String, void> { TVM_FFI_INLINE bool has_value() const noexcept { return data_.data_ != std::nullopt; } - TVM_FFI_INLINE bool operator==(const Optional<String>& other) const { + TVM_FFI_INLINE bool operator==(const Optional<T>& other) const { if (data_.data_ == std::nullopt) { return other.data_.data_ == std::nullopt; } @@ -208,7 +209,7 @@ class Optional<String, void> { return data_ == other.data_; } - TVM_FFI_INLINE bool operator!=(const Optional<String>& other) const { return !(*this == other); } + TVM_FFI_INLINE bool operator!=(const Optional<T>& other) const { return !(*this == other); } template <typename U> TVM_FFI_INLINE bool operator==(const U& other) const { @@ -238,17 +239,17 @@ class Optional<String, void> { * \return the xvalue reference to the stored value. * \note only use this function after checking has_value() */ - TVM_FFI_INLINE String&& operator*() && noexcept { return std::move(data_); } + TVM_FFI_INLINE T&& operator*() && noexcept { return std::move(data_); } /*! * \brief Direct access to the value. * \return the const reference to the stored value. * \note only use this function after checking has_value() */ - TVM_FFI_INLINE const String& operator*() const& noexcept { return data_; } + TVM_FFI_INLINE const T& operator*() const& noexcept { return data_; } private: // this is a private initializer - String data_{details::BytesBaseCell(std::nullopt)}; + T data_{details::BytesBaseCell(std::nullopt)}; }; // Specialization for ObjectRef types. diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index a679e68d75..78e0c0dd8f 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -195,16 +195,6 @@ class BytesBaseCell { data_.type_index = large_type_index; } - uint32_t AnyHash() const { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return details::StableHashCombine(data_.type_index, details::StableHashSmallStrBytes(&data_)); - } else { - const TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr(data_.v_obj); - return details::StableHashCombine(data_.type_index, - details::StableHashBytes(bytes->data, bytes->size)); - } - } - /*! * \brief Create a new empty space for a string * \param size The size of the string @@ -218,6 +208,7 @@ class BytesBaseCell { size_t kMaxSmallBytesLen = sizeof(int64_t) - 1; // first zero the content, this is important for exception safety data_.type_index = small_type_index; + data_.zero_padding = 0; if (size <= kMaxSmallBytesLen) { // set up the size accordingly data_.small_str_len = static_cast<uint32_t>(size); @@ -228,6 +219,7 @@ class BytesBaseCell { char* dest_data = reinterpret_cast<char*>(ptr.get()) + sizeof(LargeObj); ptr->data = dest_data; ptr->size = size; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); // now reset the type index to str data_.type_index = large_type_index; @@ -272,29 +264,40 @@ class BytesBaseCell { /*! * \brief Managed reference of byte array. */ -class Bytes : public ObjectRef { +class Bytes { public: + /*! \brief default constructor */ + Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); } /*! * \brief constructor from size * * \param other a char array. */ - Bytes(const char* data, size_t size) // NOLINT(*) - : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {} + Bytes(const char* data, size_t size) { this->InitData(data, size); } /*! * \brief constructor from TVMFFIByteArray * * \param other a char array. */ - Bytes(TVMFFIByteArray bytes) // NOLINT(*) - : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data, bytes.size)) {} + Bytes(TVMFFIByteArray bytes) { // NOLINT(*) + this->InitData(bytes.data, bytes.size); + } /*! * \brief constructor from std::string * * \param other a char array. */ - Bytes(std::string other) // NOLINT(*) - : ObjectRef(make_object<details::BytesObjStdImpl<details::BytesObj>>(std::move(other))) {} + Bytes(const std::string& other) { // NOLINT(*) + this->InitData(other.data(), other.size()); + } + /*! + * \brief constructor from std::string + * + * \param other a char array. + */ + Bytes(std::string&& other) { // NOLINT(*) + data_.InitFromStd<details::BytesObj>(std::move(other), TypeIndex::kTVMFFIBytes); + } /*! * \brief Swap this String with another string * \param other The other string @@ -314,21 +317,19 @@ class Bytes : public ObjectRef { * * \return size_t string length */ - size_t size() const { return get()->size; } + size_t size() const { return data_.size(); } /*! * \brief Return the data pointer * * \return const char* data pointer */ - const char* data() const { return get()->data; } + const char* data() const { return data_.data(); } /*! * \brief Convert String to an std::string object * * \return std::string */ - operator std::string() const { return std::string{get()->data, size()}; } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, details::BytesObj); + operator std::string() const { return std::string{data(), size()}; } /*! * \brief Compare two char sequence @@ -371,40 +372,30 @@ class Bytes : public ObjectRef { private: friend class String; + template <typename, typename> + friend struct TypeTraits; + template <typename, typename> + friend struct Optional; template <typename> friend struct std::hash; - - static uint64_t AnyHash(const Bytes& bytes) { - return details::StableHashCombine(TypeIndex::kTVMFFIBytes, - details::StableHashBytes(bytes.data(), bytes.size())); + // internal backing cell + details::BytesBaseCell data_; + // create a new String from TVMFFIAny, must keep private + explicit Bytes(details::BytesBaseCell data) : data_(data) {} + char* InitSpaceForSize(size_t size) { + return data_.InitSpaceForSize<details::BytesObj>(size, TypeIndex::kTVMFFISmallBytes, + TypeIndex::kTVMFFIBytes); + } + void InitData(const char* data, size_t size) { + char* dest_data = InitSpaceForSize(size); + std::memcpy(dest_data, data, size); + // mainly to be compat with string + dest_data[size] = '\0'; } }; /*! - * \brief Reference to string objects. - * - * \code - * - * // Example to create runtime String reference object from std::string - * std::string s = "hello world"; - * - * // You can create the reference from existing std::string - * String ref{std::move(s)}; - * - * // You can rebind the reference to another string. - * ref = std::string{"hello world2"}; - * - * // You can use the reference as hash map key - * std::unordered_map<String, int32_t> m; - * m[ref] = 1; - * - * // You can compare the reference object with other string objects - * assert(ref == "hello world", true); - * - * // You can convert the reference to std::string again - * string s2 = (string)ref; - * - * \endcode + * \brief String container class. */ class String { public: @@ -417,10 +408,10 @@ class String { */ String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); } // constructors from Any - String(const String& other) = default; - String(String&& other) = default; - String& operator=(const String& other) = default; - String& operator=(String&& other) = default; + String(const String& other) = default; // NOLINT(*) + String(String&& other) = default; // NOLINT(*) + String& operator=(const String& other) = default; // NOLINT(*) + String& operator=(String&& other) = default; // NOLINT(*) /*! * \brief Swap this String with another string @@ -602,7 +593,8 @@ class String { operator std::string() const { return std::string{data(), size()}; } private: - friend struct TypeTraits<String>; + template <typename, typename> + friend struct TypeTraits; template <typename, typename> friend class Optional; template <typename> @@ -620,7 +612,6 @@ class String { return data_.InitSpaceForSize<details::StringObj>(size, TypeIndex::kTVMFFISmallStr, TypeIndex::kTVMFFIStr); } - // create a new TVMFFIAny from the data and size void InitData(const char* data, size_t size) { char* dest_data = InitSpaceForSize(size); std::memcpy(dest_data, data, size); @@ -656,12 +647,6 @@ class String { #endif return ret; } - /*! - * \brief Hash the string same as AnyHash - * \param str The string to hash - * \return The hash value - */ - static uint64_t AnyHash(const String& str) { return str.data_.AnyHash(); } // Overload + operator friend String operator+(const String& lhs, const String& rhs); friend String operator+(const String& lhs, const std::string& rhs); @@ -675,6 +660,50 @@ TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { return std::string_view(str.data, str.size); } +template <> +inline constexpr bool use_default_type_traits_v<Bytes> = false; + +// specialize to enable implicit conversion from TVMFFIByteArray* +template <> +struct TypeTraits<Bytes> : public TypeTraitsBase { + // bytes can be union type of small bytes and object, so keep it as any + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; + + TVM_FFI_INLINE static void CopyToAnyView(const Bytes& src, TVMFFIAny* result) { + *result = src.data_.CopyToTVMFFIAny(); + } + + TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny* result) { + src.data_.MoveToAny(result); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFISmallBytes || + src->type_index == TypeIndex::kTVMFFIBytes; + } + + TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); + } + + TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny* src) { + return Bytes(details::BytesBaseCell::MoveFromAny(src)); + } + + TVM_FFI_INLINE static std::optional<Bytes> TryCastFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { + return Bytes(*static_cast<TVMFFIByteArray*>(src->v_ptr)); + } + if (src->type_index == TypeIndex::kTVMFFISmallBytes || + src->type_index == TypeIndex::kTVMFFIBytes) { + return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; } +}; + template <> inline constexpr bool use_default_type_traits_v<String> = false; @@ -777,7 +806,7 @@ struct TypeTraits<TVMFFIByteArray*> : public TypeTraitsBase { } TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray* src, TVMFFIAny* result) { - ObjectRefTypeTraitsBase<Bytes>::MoveToAny(Bytes(*src), result); + TypeTraits<Bytes>::MoveToAny(Bytes(*src), result); } TVM_FFI_INLINE static std::optional<TVMFFIByteArray*> TryCastFromAnyView(const TVMFFIAny* src) { @@ -790,16 +819,6 @@ struct TypeTraits<TVMFFIByteArray*> : public TypeTraitsBase { TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIByteArrayPtr; } }; -template <> -inline constexpr bool use_default_type_traits_v<Bytes> = false; - -// specialize to enable implicit conversion from TVMFFIByteArray* -template <> -struct TypeTraits<Bytes> : public ObjectRefWithFallbackTraitsBase<Bytes, TVMFFIByteArray*> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBytes; - TVM_FFI_INLINE static Bytes ConvertFallbackValue(TVMFFIByteArray* src) { return Bytes(*src); } -}; - template <> inline constexpr bool use_default_type_traits_v<std::string> = false; @@ -973,14 +992,14 @@ namespace std { template <> struct hash<::tvm::ffi::Bytes> { std::size_t operator()(const ::tvm::ffi::Bytes& bytes) const { - return ::tvm::ffi::Bytes::AnyHash(bytes); + return std::hash<std::string_view>()(std::string_view(bytes.data(), bytes.size())); } }; template <> struct hash<::tvm::ffi::String> { std::size_t operator()(const ::tvm::ffi::String& str) const { - return ::tvm::ffi::String::AnyHash(str); + return std::hash<std::string_view>()(std::string_view(str.data(), str.size())); } }; } // namespace std diff --git a/ffi/src/ffi/extra/structural_equal.cc b/ffi/src/ffi/extra/structural_equal.cc index 90cc50ac5d..97ebbf4072 100644 --- a/ffi/src/ffi/extra/structural_equal.cc +++ b/ffi/src/ffi/extra/structural_equal.cc @@ -48,7 +48,8 @@ class StructEqualHandler { const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs); if (lhs_data->type_index != rhs_data->type_index) { // type_index mismatch, if index is not string, return false - if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index != kTVMFFISmallStr) { + if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index != kTVMFFISmallStr && + lhs_data->type_index != kTVMFFISmallBytes && lhs_data->type_index != kTVMFFIBytes) { return false; } // small string and normal string comparison @@ -64,6 +65,18 @@ class StructEqualHandler { return Bytes::memequal(lhs_data->v_bytes, rhs_str->data, lhs_data->small_str_len, rhs_str->size); } + if (lhs_data->type_index == kTVMFFIBytes && rhs_data->type_index == kTVMFFISmallBytes) { + const details::BytesObjBase* lhs_bytes = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); + return Bytes::memequal(lhs_bytes->data, rhs_data->v_bytes, lhs_bytes->size, + rhs_data->small_str_len); + } + if (lhs_data->type_index == kTVMFFISmallBytes && rhs_data->type_index == kTVMFFIBytes) { + const details::BytesObjBase* rhs_bytes = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); + return Bytes::memequal(lhs_data->v_bytes, rhs_bytes->data, lhs_data->small_str_len, + rhs_bytes->size); + } return false; } diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 20ad356f60..9948ceda6b 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -338,6 +338,7 @@ class TypeTable { ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, TypeIndex::kTVMFFIObjectRValueRef); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallStr, TypeIndex::kTVMFFISmallStr); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallBytes, TypeIndex::kTVMFFISmallBytes); // no need to reserve for object types as they will be registered } diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index 1f393c42ab..d1f56e1a93 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -403,6 +403,13 @@ TEST(Any, AnyEqualHash) { EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); EXPECT_TRUE(AnyEqual()(a, b)); EXPECT_EQ(AnyHash()(a), AnyHash()(b)); + + Any c = Bytes("a1", 2); + Any d = Bytes(std::string("a1")); + EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFISmallBytes); + EXPECT_EQ(d.type_index(), TypeIndex::kTVMFFIBytes); + EXPECT_TRUE(AnyEqual()(c, d)); + EXPECT_EQ(AnyHash()(c), AnyHash()(d)); } } // namespace diff --git a/ffi/tests/cpp/test_optional.cc b/ffi/tests/cpp/test_optional.cc index 7236863c12..eb114df8a3 100644 --- a/ffi/tests/cpp/test_optional.cc +++ b/ffi/tests/cpp/test_optional.cc @@ -187,4 +187,16 @@ TEST(Optional, String) { EXPECT_TRUE(opt_str != std::nullopt); static_assert(sizeof(Optional<String>) == sizeof(String)); } + +TEST(Optional, Bytes) { + Optional<Bytes> opt_bytes; + EXPECT_TRUE(!opt_bytes.has_value()); + EXPECT_EQ(opt_bytes.value_or(std::string("default")), "default"); + + opt_bytes = std::string("hello"); + EXPECT_TRUE(opt_bytes.has_value()); + EXPECT_EQ(opt_bytes.value().operator std::string(), "hello"); + EXPECT_TRUE(opt_bytes != std::nullopt); + static_assert(sizeof(Optional<Bytes>) == sizeof(Bytes)); +} } // namespace diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index 54105d29ea..364f2f6540 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -318,6 +318,10 @@ TEST(String, Any) { } TEST(String, Bytes) { + Bytes b0; + EXPECT_EQ(b0.size(), 0); + EXPECT_EQ(b0.operator std::string(), ""); + // explicitly test zero element std::string s = {'\0', 'a', 'b', 'c'}; Bytes b = s; @@ -339,10 +343,17 @@ TEST(String, BytesAny) { EXPECT_EQ(view.try_cast<Bytes>().value().operator std::string(), s); Any b = view; - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIBytes); + EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallBytes); EXPECT_EQ(b.try_cast<Bytes>().value().operator std::string(), s); EXPECT_EQ(b.cast<std::string>(), s); + + std::string s2 = "hello long long long string"; + s2[0] = '\0'; + Any b2 = Bytes(s2); + EXPECT_EQ(b2.type_index(), TypeIndex::kTVMFFIBytes); + EXPECT_EQ(b2.try_cast<std::string>().value(), s2); + EXPECT_EQ(b2.cast<std::string>(), s2); } TEST(String, StdString) { @@ -407,4 +418,14 @@ TEST(String, BytesHash) { EXPECT_EQ(hash1, hash2); } +TEST(String, StdHash) { + String s1 = "a"; + String s2(std::string("a")); + EXPECT_EQ(std::hash<String>()(s1), std::hash<String>()(s2)); + + Bytes s3("a", 1); + Bytes s4(std::string("a")); + EXPECT_EQ(std::hash<Bytes>()(s3), std::hash<Bytes>()(s4)); +} + } // namespace diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index ab043028d3..5db3e279cf 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -224,12 +224,16 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { return newNDArray(env, reinterpret_cast<jlong>(value.v_obj), false); } case TypeIndex::kTVMFFISmallStr: { - TVMFFIByteArray arr = TVMFFISmallStrGetContentByteArray(&value); + TVMFFIByteArray arr = TVMFFISmallBytesGetContentByteArray(&value); return newTVMValueString(env, &arr); } case TypeIndex::kTVMFFIStr: { return newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); } + case TypeIndex::kTVMFFISmallBytes: { + TVMFFIByteArray arr = TVMFFISmallBytesGetContentByteArray(&value); + return newTVMValueBytes(env, &arr); + } case TypeIndex::kTVMFFIBytes: { jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); TVMFFIObjectFree(value.v_obj); diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 70db76207d..00b76e68f7 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -41,6 +41,7 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIByteArrayPtr = 9 kTVMFFIObjectRValueRef = 10 kTVMFFISmallStr = 11 + kTVMFFISmallBytes = 12 kTVMFFIStaticObjectBegin = 64 kTVMFFIObject = 64 kTVMFFIStr = 65 @@ -197,7 +198,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src, DLManagedTensorVersioned** out) nogil const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil - TVMFFIByteArray TVMFFISmallStrGetContentByteArray(const TVMFFIAny* value) nogil + TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi index b98c3e1107..279b17f8c8 100644 --- a/python/tvm/ffi/cython/dtype.pxi +++ b/python/tvm/ffi/cython/dtype.pxi @@ -98,7 +98,7 @@ cdef class DataType: CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &temp_any)) if temp_any.type_index == kTVMFFISmallStr: - bytes = TVMFFISmallStrGetContentByteArray(&temp_any) + bytes = TVMFFISmallBytesGetContentByteArray(&temp_any) res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) return res diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index e8e6987dd9..cbff3fecf1 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -26,10 +26,17 @@ except ImportError: cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" cdef TVMFFIByteArray bytes - bytes = TVMFFISmallStrGetContentByteArray(&result) + bytes = TVMFFISmallBytesGetContentByteArray(&result) return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) +cdef inline object make_ret_small_bytes(TVMFFIAny result): + """convert small bytes to return value.""" + cdef TVMFFIByteArray bytes + bytes = TVMFFISmallBytesGetContentByteArray(&result) + return PyBytes_FromStringAndSize(bytes.data, bytes.size) + + cdef inline object make_ret(TVMFFIAny result): """convert result to return value.""" # TODO: Implement @@ -50,6 +57,8 @@ cdef inline object make_ret(TVMFFIAny result): return result.v_float64 elif type_index == kTVMFFISmallStr: return make_ret_small_str(result) + elif type_index == kTVMFFISmallBytes: + return make_ret_small_bytes(result) elif type_index == kTVMFFIOpaquePtr: return ctypes_handle(result.v_ptr) elif type_index == kTVMFFIDataType: diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index ec62a3c95a..d3b62b5e87 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -84,6 +84,7 @@ void ReprPrinter::Print(const ffi::Any& node) { stream << '"' << support::StrEscape(str.data(), str.size()) << '"'; break; } + case ffi::TypeIndex::kTVMFFISmallBytes: case ffi::TypeIndex::kTVMFFIBytes: { ffi::Bytes bytes = node.cast<ffi::Bytes>(); stream << "b\"" << support::StrEscape(bytes.data(), bytes.size()) << '"'; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 3d0175bcfa..0c3ca959a3 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -109,7 +109,8 @@ class NodeIndexer { } } else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || node.type_index() == ffi::TypeIndex::kTVMFFISmallStr || - node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { + node.type_index() == ffi::TypeIndex::kTVMFFIBytes || + node.type_index() == ffi::TypeIndex::kTVMFFISmallBytes) { // skip content index for string and bytes } else if (auto opt_object = node.as<const Object*>()) { Object* n = const_cast<Object*>(opt_object.value()); @@ -254,6 +255,9 @@ class JSONAttrGetter { if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) { node_->type_key = ffi::StaticTypeKey::kTVMFFIStr; } + if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) { + node_->type_key = ffi::StaticTypeKey::kTVMFFIBytes; + } // populates the fields. node_->attrs.clear(); node_->data.clear(); @@ -398,7 +402,8 @@ class FieldDependencyFinder { } if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || node.type_index() == ffi::TypeIndex::kTVMFFISmallStr || - node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { + node.type_index() == ffi::TypeIndex::kTVMFFIBytes || + node.type_index() == ffi::TypeIndex::kTVMFFISmallBytes) { // skip indexing content of string and bytes return; } @@ -561,7 +566,8 @@ class JSONAttrSetter { } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) { return Any(String(jnode->repr_bytes)); - } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { + } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes || + jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) { return Any(Bytes(jnode->repr_bytes)); } else { return ObjectRef(reflection->CreateInitObject(jnode->type_key, jnode->repr_bytes)); @@ -594,7 +600,8 @@ class JSONAttrSetter { *node = result; } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr || - jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { + jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes || + jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) { // skip set attrs for string and bytes } else if (auto opt_object = node->as<const Object*>()) { Object* n = const_cast<Object*>(opt_object.value()); diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 41a4cb766a..3f132b024a 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -53,7 +53,7 @@ class CublasJSONSerializer : public JSONSerializer { ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr<String>(attr::kComposite); - ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index 358f2d6604..b529c6f796 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -52,7 +52,7 @@ class cuDNNJSONSerializer : public JSONSerializer { ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr<String>(attr::kComposite); - ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index d14d7aed57..761221c88b 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -52,7 +52,7 @@ class HipblasJSONSerializer : public JSONSerializer { ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr<String>(attr::kComposite); - ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index ded7340b6f..c62523f539 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -201,7 +201,7 @@ class NNAPIJSONSerializer : public JSONSerializer { ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr<String>(attr::kComposite); - ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 42450e3322..a693c671f3 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -96,7 +96,7 @@ class RPCWrappedFunc : public Object { } else if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr) { // we cannot cast here, since we need to make sure the space is alive const TVMFFIAny* any_view_ptr = reinterpret_cast<const TVMFFIAny*>(&args.data()[i]); - TVMFFIByteArray bytes = TVMFFISmallStrGetContentByteArray(any_view_ptr); + TVMFFIByteArray bytes = TVMFFISmallBytesGetContentByteArray(any_view_ptr); packed_args[i] = bytes.data; continue; } @@ -322,7 +322,8 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) nd_handle); } else if (type_index == ffi::TypeIndex::kTVMFFIBytes || type_index == ffi::TypeIndex::kTVMFFIStr || - type_index == ffi::TypeIndex::kTVMFFISmallStr) { + type_index == ffi::TypeIndex::kTVMFFISmallStr || + type_index == ffi::TypeIndex::kTVMFFISmallBytes) { ICHECK_EQ(args.size(), 2); *rv = args[1]; } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 78bccb829c..33a687f54b 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -520,8 +520,7 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) { // convert POD value to PrimExpr - if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - node.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { node = node.cast<PrimExpr>(); } ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>(); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 737f27c7e9..a1b1272cde 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -212,7 +212,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } return arr; }) - .def("testing.AcceptsMapOfPrimExpr", [](Map<ObjectRef, PrimExpr> map) -> ObjectRef { + .def("testing.AcceptsMapOfPrimExpr", [](Map<Any, PrimExpr> map) -> ObjectRef { for (const auto& kv : map) { ObjectRef value = kv.second; CHECK(value->IsInstance<PrimExprNode>()) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 5a2b95844b..56fab07605 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -100,8 +100,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to // primexpr. - if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - node.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { return AttrStmt(node.cast<PrimExpr>(), attr_key, value, body, span); } return AttrStmt(node, attr_key, value, body, span); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index db175c77f2..6f7e682d6c 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -916,8 +916,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { if (auto opt_str = ann_val.try_cast<ffi::String>()) { return *std::move(opt_str); } - if (ann_val.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - ann_val.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + if (ann_val.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { return ann_val; } // prefer to return int/float literals for annotations diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index fdc0dd41c4..2f327354c9 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -74,8 +74,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) inputs.push_back(String('"' + (*opt_str).operator std::string() + '"')); } else if (obj.as<BlockRVNode>() || obj.as<LoopRVNode>()) { inputs.push_back(String("_")); - } else if (obj.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - obj.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { inputs.push_back(obj); } else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) { inputs.push_back(obj); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index b1fb7881a6..61f24f980f 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -71,8 +71,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs, }; for (const Any& input : inputs) { - if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type result.push_back(input); } else if (auto expr = input.as<ffi::String>()) { @@ -114,8 +113,7 @@ Array<Any> TranslateInputRVs( // string => "content" if (auto opt_str = input.as<ffi::String>()) { results.push_back(String('"' + (*opt_str).operator std::string() + '"')); - } else if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + } else if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type and not string results.push_back(input); } else if (input.as<BlockRVNode>() || // RV: block @@ -161,8 +159,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs, Array<Any> results; results.reserve(inputs.size()); for (const Any& input : inputs) { - if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && - input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type results.push_back(input); continue; diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index 50e82e445c..41d848a228 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -74,6 +74,8 @@ export const enum TypeIndex { kTVMFFIObjectRValueRef = 10, /*! \brief Small string on stack */ kTVMFFISmallStr = 11, + /*! \brief Small bytes on stack */ + kTVMFFISmallBytes = 12, /*! \brief Start of statically defined objects. */ kTVMFFIStaticObjectBegin = 64, /*!
