This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch small-str-v0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 30de9d5b2e3c50cb28fbbcc83e5438a25bf10682 Author: tqchen <[email protected]> AuthorDate: Thu Jul 31 10:01:27 2025 -0400 [FFI] Introduce small string This PR introduces small string support to the FFI system. When the string length fits into the space in TVMFFIAny (i.e., len(str) <= 10), we directly store the string content in the TVMFFIAny itself instead of creating an Object. This change will likely speed up small string access. Some implications: - Always check for the kTVMFFISmallStr code as well as kTVMFFIStr - Avoid using details::StringObj; instead, always use any.as<String>() - Always set any.zero_padding to 0 for values other than small strings (in compiler and runtime) to enable fast comparison --- ffi/include/tvm/ffi/any.h | 98 +++++-- ffi/include/tvm/ffi/base_details.h | 3 +- ffi/include/tvm/ffi/c_api.h | 46 ++- ffi/include/tvm/ffi/cast.h | 1 + ffi/include/tvm/ffi/container/variant.h | 2 + ffi/include/tvm/ffi/dtype.h | 9 +- ffi/include/tvm/ffi/object.h | 1 + ffi/include/tvm/ffi/optional.h | 119 +++++++- ffi/include/tvm/ffi/reflection/accessor.h | 2 +- ffi/include/tvm/ffi/rvalue_ref.h | 4 +- ffi/include/tvm/ffi/string.h | 324 ++++++++++++++++----- ffi/include/tvm/ffi/type_traits.h | 16 +- ffi/src/ffi/dtype.cc | 4 +- ffi/src/ffi/extra/structural_equal.cc | 22 +- ffi/src/ffi/extra/structural_hash.cc | 14 + ffi/src/ffi/object.cc | 8 +- ffi/tests/cpp/test_any.cc | 11 + ffi/tests/cpp/test_dtype.cc | 1 + ffi/tests/cpp/test_optional.cc | 17 ++ ffi/tests/cpp/test_reflection_accessor.cc | 1 - ffi/tests/cpp/test_rvalue_ref.cc | 4 +- ffi/tests/cpp/test_string.cc | 22 +- ffi/tests/cpp/test_variant.cc | 4 +- include/tvm/script/printer/ir_docsifier.h | 1 + include/tvm/tir/builtin.h | 1 + jvm/native/src/main/native/jni_helper_func.h | 8 +- .../src/main/native/org_apache_tvm_native_c_api.cc | 2 + python/tvm/ffi/cython/base.pxi | 6 +- python/tvm/ffi/cython/dtype.pxi | 19 +- python/tvm/ffi/cython/function.pxi | 10 + src/contrib/msc/core/ir/graph_builder.h | 1 + src/node/repr_printer.cc | 1 + 
src/node/serialization.cc | 10 +- src/runtime/disco/protocol.h | 12 +- src/runtime/minrpc/rpc_reference.h | 3 + src/runtime/rpc/rpc_endpoint.cc | 8 +- src/runtime/rpc/rpc_module.cc | 10 +- src/target/llvm/codegen_cpu.cc | 5 + src/target/source/codegen_c.cc | 6 + src/target/source/codegen_c_host.cc | 2 + src/tir/transforms/lower_tvm_builtin.cc | 5 + src/tir/transforms/make_packed_api.cc | 7 +- web/src/ctypes.ts | 2 + web/src/memory.ts | 20 +- web/src/runtime.ts | 11 +- web/tests/node/test_packed_func.js | 4 +- 46 files changed, 707 insertions(+), 180 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index d94185c064..c45feed97d 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -60,6 +60,7 @@ class AnyView { void reset() { data_.type_index = TypeIndex::kTVMFFINone; // invariance: always set the union padding part to 0 + data_.zero_padding = 0; data_.v_int64 = 0; } /*! @@ -72,6 +73,7 @@ class AnyView { // default constructors AnyView() { data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; data_.v_int64 = 0; } ~AnyView() = default; @@ -80,6 +82,7 @@ class AnyView { AnyView& operator=(const AnyView&) = default; AnyView(AnyView&& other) : data_(other.data_) { other.data_.type_index = TypeIndex::kTVMFFINone; + other.data_.zero_padding = 0; other.data_.v_int64 = 0; } TVM_FFI_INLINE AnyView& operator=(AnyView&& other) { @@ -198,13 +201,11 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, if (data->type_index == TypeIndex::kTVMFFIRawStr) { // convert raw string to owned string object String temp(data->v_c_str); - data->type_index = TypeIndex::kTVMFFIStr; - data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp)); + TypeTraits<String>::MoveToAny(std::move(temp), data); } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) { // convert byte array to owned bytes object Bytes temp(*static_cast<TVMFFIByteArray*>(data->v_ptr)); - data->type_index = 
TypeIndex::kTVMFFIBytes; - data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp)); + TypeTraits<Bytes>::MoveToAny(std::move(temp), data); } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) { // convert rvalue ref to owned object Object** obj_addr = static_cast<Object**>(data->v_ptr); @@ -212,8 +213,7 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj_addr[0])); // set the rvalue ref to nullptr to avoid double move obj_addr[0] = nullptr; - data->type_index = temp->type_index(); - data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp)); + TypeTraits<ObjectRef>::MoveToAny(std::move(temp), data); } } } @@ -239,6 +239,7 @@ class Any { details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); } data_.type_index = TVMFFITypeIndex::kTVMFFINone; + data_.zero_padding = 0; data_.v_int64 = 0; } /*! @@ -251,6 +252,7 @@ class Any { // default constructors Any() { data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; data_.v_int64 = 0; } ~Any() { this->reset(); } @@ -262,6 +264,7 @@ class Any { } Any(Any&& other) : data_(other.data_) { other.data_.type_index = TypeIndex::kTVMFFINone; + other.data_.zero_padding = 0; other.data_.v_int64 = 0; } TVM_FFI_INLINE Any& operator=(const Any& other) { @@ -408,7 +411,8 @@ class Any { * \return True if the two Any are same type and value, false otherwise. 
*/ TVM_FFI_INLINE bool same_as(const Any& other) const noexcept { - return data_.type_index == other.data_.type_index && data_.v_int64 == other.data_.v_int64; + return data_.type_index == other.data_.type_index && + data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64; } /* @@ -485,6 +489,7 @@ struct AnyUnsafe : public ObjectUnsafe { TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) { TVMFFIAny result = any.data_; any.data_.type_index = TypeIndex::kTVMFFINone; + any.data_.zero_padding = 0; any.data_.v_int64 = 0; return result; } @@ -493,6 +498,7 @@ struct AnyUnsafe : public ObjectUnsafe { Any any; any.data_ = data; data.type_index = TypeIndex::kTVMFFINone; + data.zero_padding = 0; data.v_int64 = 0; return any; } @@ -543,17 +549,21 @@ struct AnyHash { * \return Hash code of a, string hash for strings and pointer address otherwise. */ uint64_t operator()(const Any& src) const { - uint64_t val_hash = [&]() -> uint64_t { - if (src.data_.type_index == TypeIndex::kTVMFFIStr || - src.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* src_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src); - return details::StableHashBytes(src_str->data, src_str->size); - } else { - return src.data_.v_uint64; - } - }(); - return details::StableHashCombine(src.data_.type_index, val_hash); + if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) { + // for small string, we use the same type key hash as normal string + // so heap allocated string and on stack string will have the same hash + return details::StableHashCombine( + TypeIndex::kTVMFFIStr, + details::StableHashBytes(src.data_.small_str_header + 1, src.data_.small_str_header[0])); + } else if (src.data_.type_index == TypeIndex::kTVMFFIStr || + src.data_.type_index == TypeIndex::kTVMFFIBytes) { + const details::BytesObjBase* src_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src); + 
return details::StableHashCombine(src.data_.type_index, + details::StableHashBytes(src_str->data, src_str->size)); + } else { + return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64); + } } }; @@ -566,19 +576,47 @@ struct AnyEqual { * \return String equality if both are strings, pointer address equality otherwise. */ bool operator()(const Any& lhs, const Any& rhs) const { - if (lhs.data_.type_index != rhs.data_.type_index) return false; - // byte equivalence - if (lhs.data_.v_int64 == rhs.data_.v_int64) return true; - // specialy handle string hash - if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || - lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); - return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); + // header with type index + const int64_t* lhs_as_int64 = reinterpret_cast<const int64_t*>(&lhs.data_); + const int64_t* rhs_as_int64 = reinterpret_cast<const int64_t*>(&rhs.data_); + static_assert(sizeof(TVMFFIAny) == 16 && alignof(TVMFFIAny) == 8); + // fast path, check byte equality + if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) { + return true; + } + // common false case type index match, in this case we only need to pay attention to string + // equality + if (lhs.data_.type_index == rhs.data_.type_index) { + // specialy handle string hash + if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || + lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { + const details::BytesObjBase* lhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); + const details::BytesObjBase* rhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); + return 
Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); + } + return false; + } else { + // type_index mismatch, if index is not string, return false + if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr) { + return false; + } + // small string and normal string comparison + if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) { + const details::BytesObjBase* lhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); + return Bytes::memequal(lhs_str->data, rhs.data_.small_str_header + 1, lhs_str->size, + rhs.data_.small_str_header[0]); + } + if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) { + const details::BytesObjBase* rhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); + return Bytes::memequal(lhs.data_.small_str_header + 1, rhs_str->data, + lhs.data_.small_str_header[0], rhs_str->size); + } + return false; } - return false; } }; diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h index cfdadff6ea..1c977d12dc 100644 --- a/ffi/include/tvm/ffi/base_details.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -170,7 +170,8 @@ TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T& value) { * \param size The size of the bytes. * \return the hash value. 
*/ -TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) { +TVM_FFI_INLINE uint64_t StableHashBytes(const void* data_ptr, size_t size) { + const char* data = reinterpret_cast<const char*>(data_ptr); const constexpr uint64_t kMultiplier = 1099511628211ULL; const constexpr uint64_t kMod = 2147483647ULL; union Union { diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index d99832af01..b44250d00c 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -65,13 +65,7 @@ enum TVMFFITypeIndex : int32_t { #else typedef enum { #endif - // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) - // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, - // which is not owned by TVMFFIAny. It is required that the following - // invariant holds: - // - `Any::type_index` is never `kTVMFFIRawStr` - // - `AnyView::type_index` can be `kTVMFFIRawStr` - // + /* * \brief The root type of all FFI objects. * @@ -80,6 +74,13 @@ typedef enum { * However, it may appear in field annotations during reflection. */ kTVMFFIAny = -1, + // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) + // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, + // which is not owned by TVMFFIAny. It is required that the following + // invariant holds: + // - `Any::type_index` is never `kTVMFFIRawStr` + // - `AnyView::type_index` can be `kTVMFFIRawStr` + // /*! \brief None/nullptr value */ kTVMFFINone = 0, /*! \brief POD int value */ @@ -96,12 +97,14 @@ typedef enum { kTVMFFIDevice = 6, /*! \brief DLTensor* */ kTVMFFIDLTensorPtr = 7, - /*! \brief const char**/ + /*! \brief const char* */ kTVMFFIRawStr = 8, /*! \brief TVMFFIByteArray* */ kTVMFFIByteArrayPtr = 9, /*! \brief R-value reference to ObjectRef */ kTVMFFIObjectRValueRef = 10, + /*! \brief Small string on stack */ + kTVMFFISmallStr = 11, /*! \brief Start of statically defined objects. 
*/ kTVMFFIStaticObjectBegin = 64, /*! @@ -183,11 +186,16 @@ typedef struct TVMFFIAny { * \note The type index of Object and Any are shared in FFI. */ int32_t type_index; - /*! - * \brief length for on-stack Any object, such as small-string - * \note This field is reserved for future compact. - */ - int32_t small_len; + union { // 4 bytes + /*! \brief padding, must set to zero for values other than small string. */ + uint32_t zero_padding; + /*! + * \brief small string header, small_str_header[0] is the length of the string, + * followed by the content of the string. + * \note This field is used to store small string on stack. + */ + uint8_t small_str_header[4]; + }; union { // 8 bytes int64_t v_int64; // integers double v_float64; // floating-point numbers @@ -823,7 +831,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* * \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. */ -TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out); +TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out); //------------------------------------------------------------ // Section: Backend noexcept functions for internal use @@ -903,6 +911,16 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { return static_cast<TVMFFIObject*>(obj)->type_index; } +/*! + * \brief Get the content of a small string in bytearray format. + * \param obj The object handle. + * \return The content of the small string in bytearray format. + */ +inline TVMFFIByteArray TVMFFISmallStrGetContentByteArray(const TVMFFIAny* value) { + return TVMFFIByteArray{reinterpret_cast<const char*>(value->small_str_header + 1), + static_cast<size_t>(value->small_str_header[0])}; +} + /*! * \brief Get the data pointer of a bytearray from a string or bytes object. * \param obj The object handle. 
diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index 9cac1f99a8..997c0bb178 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -27,6 +27,7 @@ #include <tvm/ffi/dtype.h> #include <tvm/ffi/error.h> #include <tvm/ffi/object.h> +#include <tvm/ffi/optional.h> #include <utility> diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index a16ff5d425..ee1f8316d8 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -80,10 +80,12 @@ class VariantBase<true> : public ObjectRef { TVMFFIAny any_data; if (data_ == nullptr) { any_data.type_index = TypeIndex::kTVMFFINone; + any_data.zero_padding = 0; any_data.v_int64 = 0; } else { TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); any_data.type_index = data_->type_index(); + any_data.zero_padding = 0; any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_); } return AnyView::CopyFromTVMFFIAny(any_data); diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h index 2eafccd2db..c153d71cb7 100644 --- a/ffi/include/tvm/ffi/dtype.h +++ b/ffi/include/tvm/ffi/dtype.h @@ -115,14 +115,15 @@ inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(* inline DLDataType StringToDLDataType(const String& str) { DLDataType out; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(str.get(), &out)); + TVMFFIByteArray data{str.data(), str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out)); return out; } inline String DLDataTypeToString(DLDataType dtype) { - TVMFFIObjectHandle out; + TVMFFIAny out; TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); - return String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out))); + return TypeTraits<String>::MoveFromAnyAfterCheck(&out); } // DLDataType @@ -134,6 +135,7 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase { // clear padding part to 
ensure the equality check can always check the v_uint64 part result->v_uint64 = 0; result->type_index = TypeIndex::kTVMFFIDataType; + result->zero_padding = 0; result->v_dtype = src; } @@ -141,6 +143,7 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase { // clear padding part to ensure the equality check can always check the v_uint64 part result->v_uint64 = 0; result->type_index = TypeIndex::kTVMFFIDataType; + result->zero_padding = 0; result->v_dtype = src; } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index a49a9f1700..74977e0216 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -60,6 +60,7 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIFunction = "ffi.Function"; static constexpr const char* kTVMFFIArray = "ffi.Array"; static constexpr const char* kTVMFFIMap = "ffi.Map"; + static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; }; /*! diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index 003038b9fd..dfb54f5e37 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -27,6 +27,7 @@ #include <tvm/ffi/error.h> #include <tvm/ffi/object.h> +#include <tvm/ffi/string.h> #include <optional> #include <string> @@ -53,7 +54,7 @@ inline constexpr bool use_ptr_based_optional_v = // Specialization for non-ObjectRef types. // simply fallback to std::optional template <typename T> -class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T>>> { +class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T> && !std::is_same_v<T, String>>> { public: // default constructors. Optional() = default; @@ -138,6 +139,122 @@ class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T>>> { std::optional<T> data_; }; +// Specialization for String type, use nullptr to indicate nullopt +template <> +class Optional<String, void> { + public: + // default constructors. 
+ Optional() = default; + Optional(const Optional<String>& other) : data_(other.data_) {} + Optional(Optional<String>&& other) : data_(std::move(other.data_)) {} + Optional(std::nullopt_t) {} // NOLINT(*) + // normal value handling. + Optional(String other) // NOLINT(*) + : data_(std::move(other)) {} + + TVM_FFI_INLINE Optional<String>& operator=(const Optional<String>& other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Optional<String>& operator=(Optional<String>&& other) { + data_ = std::move(other.data_); + return *this; + } + + TVM_FFI_INLINE Optional<String>& operator=(String other) { + data_ = std::move(other); + return *this; + } + + TVM_FFI_INLINE Optional<String>& operator=(std::nullopt_t) { + String(std::nullopt).swap(data_); + return *this; + } + + TVM_FFI_INLINE const String& value() const& { + if (data_.data_.type_index == TypeIndex::kTVMFFINone) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return data_; + } + + TVM_FFI_INLINE String&& value() && { + if (data_.data_.type_index == TypeIndex::kTVMFFINone) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return std::move(data_); + } + + template <typename U = String> + TVM_FFI_INLINE String value_or(U&& default_value) const { + if (data_.data_.type_index == TypeIndex::kTVMFFINone) { + return std::forward<U>(default_value); + } + return data_; + } + + TVM_FFI_INLINE explicit operator bool() const noexcept { + return data_.data_.type_index != TypeIndex::kTVMFFINone; + } + + TVM_FFI_INLINE bool has_value() const noexcept { + return data_.data_.type_index != TypeIndex::kTVMFFINone; + } + + TVM_FFI_INLINE bool operator==(const Optional<String>& other) const { + if (data_.data_.type_index == TypeIndex::kTVMFFINone) { + return other.data_.data_.type_index == TypeIndex::kTVMFFINone; + } + if (other.data_.data_.type_index == TypeIndex::kTVMFFINone) { + return false; + } + return data_ == other.data_; + } + + TVM_FFI_INLINE bool operator!=(const 
Optional<String>& other) const { return !(*this == other); } + + template <typename U> + TVM_FFI_INLINE bool operator==(const U& other) const { + if constexpr (std::is_same_v<U, std::nullopt_t>) { + return data_.data_.type_index == TypeIndex::kTVMFFINone; + } else { + if (data_.data_.type_index == TypeIndex::kTVMFFINone) { + return false; + } + return data_ == other; + } + } + template <typename U> + TVM_FFI_INLINE bool operator!=(const U& other) const { + if constexpr (std::is_same_v<U, std::nullopt_t>) { + return data_.data_.type_index != TypeIndex::kTVMFFINone; + } else { + if (data_.data_.type_index == TypeIndex::kTVMFFINone) { + return true; + } + return data_ != other; + } + } + + /*! + * \brief Direct access to the value. + * \return the xvalue reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE String&& operator*() && noexcept { return std::move(data_); } + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE const String& operator*() const& noexcept { return data_; } + + private: + // this is a private initializer + String data_{std::nullopt}; +}; + // Specialization for ObjectRef types. // nullptr is treated as std::nullopt. 
template <typename T> diff --git a/ffi/include/tvm/ffi/reflection/accessor.h b/ffi/include/tvm/ffi/reflection/accessor.h index 40adfa3499..5215444052 100644 --- a/ffi/include/tvm/ffi/reflection/accessor.h +++ b/ffi/include/tvm/ffi/reflection/accessor.h @@ -48,7 +48,7 @@ inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char return &(info->fields[i]); } } - TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in " << type_key; + TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; TVM_FFI_UNREACHABLE(); } diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h index b185e8d941..7c89038cc2 100644 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ b/ffi/include/tvm/ffi/rvalue_ref.h @@ -94,6 +94,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const RValueRef<TObjRef>& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIObjectRValueRef; + result->zero_padding = 0; // store the address of the ObjectPtr, which allows us to move the value // and set the original ObjectPtr to nullptr result->v_ptr = &(src.data_); @@ -106,7 +107,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public TypeTraitsBase { // in this case we do not move the original rvalue ref since conversion creates a copy TVMFFIAny tmp_any; tmp_any.type_index = rvalue_ref->get()->type_index(); - + tmp_any.zero_padding = 0; tmp_any.v_obj = reinterpret_cast<TVMFFIObject*>(rvalue_ref->get()); return "RValueRef<" + TypeTraits<TObjRef>::GetMismatchTypeInfo(&tmp_any) + ">"; } else { @@ -120,6 +121,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public TypeTraitsBase { ObjectPtr<Object>* rvalue_ref = reinterpret_cast<ObjectPtr<Object>*>(src->v_ptr); TVMFFIAny tmp_any; tmp_any.type_index = rvalue_ref->get()->type_index(); + tmp_any.zero_padding = 0; tmp_any.v_obj = reinterpret_cast<TVMFFIObject*>(rvalue_ref->get()); // fast path, storage type matches, 
direct move the rvalue ref if (TypeTraits<TObjRef>::CheckAnyStrict(&tmp_any)) { diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 481b704436..69538fff0b 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -47,7 +47,9 @@ namespace tvm { namespace ffi { namespace details { -/*! \brief Base class for bytes and string. */ +/*! + * \brief Base class for bytes and string objects. + */ class BytesObjBase : public Object, public TVMFFIByteArray {}; /*! @@ -108,21 +110,21 @@ TVM_FFI_INLINE ObjectPtr<Base> MakeInplaceBytes(const char* data, size_t length) class Bytes : public ObjectRef { public: /*! - * \brief constructor from char [N] + * \brief constructor from size * * \param other a char array. */ Bytes(const char* data, size_t size) // NOLINT(*) : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {} /*! - * \brief constructor from char [N] + * \brief constructor from TVMFFIByteArray * * \param other a char array. */ Bytes(TVMFFIByteArray bytes) // NOLINT(*) : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data, bytes.size)) {} /*! - * \brief constructor from char [N] + * \brief constructor from std::string * * \param other a char array. */ @@ -198,7 +200,7 @@ class Bytes : public ObjectRef { * * \return true if the two char sequences are equal, false otherwise. */ - static bool memequal(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + static bool memequal(const void* lhs, const void* rhs, size_t lhs_count, size_t rhs_count) { return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0); } @@ -232,75 +234,149 @@ class Bytes : public ObjectRef { * * \endcode */ -class String : public ObjectRef { +class String { public: - String(std::nullptr_t) = delete; // NOLINT(*) + /*! 
+ * \brief constructor + */ + String() { + data_.type_index = TypeIndex::kTVMFFISmallStr; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + // constructors from Any + String(const String& other) : data_(other.data_) { + if (data_.type_index == TypeIndex::kTVMFFIStr) { + details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); + } + } + + String(String&& other) : data_(other.data_) { + other.data_.type_index = TypeIndex::kTVMFFISmallStr; + other.data_.zero_padding = 0; + other.data_.v_int64 = 0; + } /*! - * \brief constructor from char [N] - * - * \param other a char array. + * \brief Destructor */ - template <size_t N> - String(const char other[N]) // NOLINT(*) - : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, N)) {} + ~String() { + if (data_.type_index == TypeIndex::kTVMFFIStr) { + details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); + } + } /*! - * \brief constructor + * \brief Swap this String with another string + * \param other The other string */ - String() : String("") {} + TVM_FFI_INLINE void swap(String& other) noexcept { // NOLINT(*) + std::swap(data_, other.data_); + } + + TVM_FFI_INLINE String& operator=(const String& other) { + // copy-and-swap idiom + String(other).swap(*this); // NOLINT(*) + return *this; + } + + TVM_FFI_INLINE String& operator=(String&& other) { + // copy-and-swap idiom + String(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + TVM_FFI_INLINE String& operator=(const std::string& other) { + String(other).swap(*this); // NOLINT(*) + return *this; + } + TVM_FFI_INLINE String& operator=(std::string&& other) { + String(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + TVM_FFI_INLINE String& operator=(const char* other) { + String(other).swap(*this); // NOLINT(*) + return *this; + } /*! * \brief constructor from raw string * * \param other a char array. 
*/ - String(const char* other) // NOLINT(*) - : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, std::strlen(other))) {} + String(const char* other, size_t size) { this->InitData(other, size); } /*! * \brief constructor from raw string * * \param other a char array. */ - String(const char* other, size_t size) // NOLINT(*) - : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, size)) {} + String(const char* other) { // NOLINT(*) + this->InitData(other, std::char_traits<char>::length(other)); + } /*! * \brief Construct a new string object * \param other The std::string object to be copied */ - String(const std::string& other) // NOLINT(*) - : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data(), other.size())) {} + String(const std::string& other) { // NOLINT(*) + this->InitData(other.data(), other.size()); + } /*! * \brief Construct a new string object * \param other The std::string object to be moved */ - String(std::string&& other) // NOLINT(*) - : ObjectRef(make_object<details::BytesObjStdImpl<details::StringObj>>(std::move(other))) {} + String(std::string&& other) { // NOLINT(*) + // exception safety, first set to none so if exception is thrown + // destructor works correctly + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); + data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr( + make_object<details::BytesObjStdImpl<details::StringObj>>(std::move(other))); + data_.type_index = TypeIndex::kTVMFFIStr; + } /*! * \brief constructor from TVMFFIByteArray * * \param other a TVMFFIByteArray. */ - explicit String(TVMFFIByteArray other) - : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data, other.size)) {} + explicit String(TVMFFIByteArray other) { this->InitData(other.data, other.size); } /*! 
- * \brief Swap this String with another string - * \param other The other string + * \brief Return the data pointer + * + * \return const char* data pointer */ - void swap(String& other) { // NOLINT(*) - std::swap(data_, other.data_); + const char* data() const noexcept { + if (data_.type_index != TypeIndex::kTVMFFIStr) { + return reinterpret_cast<const char*>(data_.small_str_header + 1); + } else { + return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->data; + } } - template <typename T> - String& operator=(T&& other) { - // copy-and-swap idiom - String(std::forward<T>(other)).swap(*this); // NOLINT(*) - return *this; + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char* c_str() const noexcept { return data(); } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const noexcept { + if (data_.type_index != TypeIndex::kTVMFFIStr) { + return data_.small_str_header[0]; + } else { + return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->size; + } } /*! @@ -362,23 +438,6 @@ class String : public ObjectRef { return Bytes::memncmp(data(), other.data, size(), other.size); } - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char* c_str() const { return get()->data; } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { - const auto* ptr = get(); - return ptr->size; - } - /*! * \brief Return the length of the string * @@ -407,23 +466,66 @@ class String : public ObjectRef { } } - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const { return get()->data; } - /*! 
* \brief Convert String to an std::string object * * \return std::string */ - operator std::string() const { return std::string{get()->data, size()}; } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, details::StringObj); + operator std::string() const { return std::string{data(), size()}; } private: + friend struct TypeTraits<String>; + template <typename, typename> + friend class Optional; + /*! + * \brief backing data of the string container + * \note Need to explicitly operate on raw TVMFFIAny to avoid dependency on AnyView/Any + */ + TVMFFIAny data_; + // create a new String from TVMFFIAny, must keep private + explicit String(TVMFFIAny data) : data_(std::move(data)) {} + // special constructor only for Optional<String> + // we should not use this constructor directly and it is kept as private + explicit String(std::nullopt_t) { + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + /*! + * \brief Create a new empty space for a string + * \param size The size of the string + * \return A pointer to the empty space + */ + char* InitSpaceForSize(size_t size) { + // need to reserve one byte for \0, plus two bytes from header + constexpr size_t kMaxSmallStrLen = sizeof(int64_t) + 2; + // first zero the content, this is important for exception safety + data_.type_index = TypeIndex::kTVMFFISmallStr; + data_.zero_padding = 0; + data_.v_int64 = 0; + if (size <= kMaxSmallStrLen) { + // set up the size accordingly + data_.small_str_header[0] = static_cast<uint8_t>(size); + return reinterpret_cast<char*>(data_.small_str_header + 1); + } else { + // allocate from heap + ObjectPtr<details::StringObj> ptr = + make_inplace_array_object<details::StringObj, char>(size + 1); + char* dest_data = reinterpret_cast<char*>(ptr.get()) + sizeof(details::StringObj); + ptr->data = dest_data; + ptr->size = size; + data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); + // now reset the type index 
to str + data_.type_index = TypeIndex::kTVMFFIStr; + return dest_data; + } + } + // create a new TVMFFIAny from the data and size + void InitData(const char* data, size_t size) { + char* dest_data = InitSpaceForSize(size); + std::memcpy(dest_data, data, size); + dest_data[size] = '\0'; + } /*! * \brief Concatenate two char sequences * @@ -435,9 +537,24 @@ class String : public ObjectRef { * \return The concatenated char sequence */ static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { - std::string ret(lhs, lhs_size); - ret.append(rhs, rhs_size); - return String(ret); + String ret; + // disable stringop-overflow and restrict warnings + // gcc may produce false positive when we enable dest_data returned from small string path + // Because compiler is not able to detect the condition that the path is only triggered via + // size < kMaxSmallStrLen and can report it as a overflow case. +#if (__GNUC__) && !(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstringop-overflow" +#pragma GCC diagnostic ignored "-Wrestrict" +#endif + char* dest_data = ret.InitSpaceForSize(lhs_size + rhs_size); + std::memcpy(dest_data, lhs, lhs_size); + std::memcpy(dest_data + lhs_size, rhs, rhs_size); + dest_data[lhs_size + rhs_size] = '\0'; +#if (__GNUC__) && !(__clang__) +#pragma GCC diagnostic pop +#endif + return ret; } // Overload + operator @@ -453,6 +570,65 @@ TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { return std::string_view(str.data, str.size); } +template <> +inline constexpr bool use_default_type_traits_v<String> = false; + +// specialize to enable implicit conversion from const char* +template <> +struct TypeTraits<String> : public TypeTraitsBase { + // string can be union type of small string and object, so keep it as any + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; + + TVM_FFI_INLINE static void CopyToAnyView(const String& src, TVMFFIAny* result) { + *result = 
src.data_; + } + + TVM_FFI_INLINE static void MoveToAny(String src, TVMFFIAny* result) { + *result = src.data_; + src.data_.type_index = TypeIndex::kTVMFFINone; + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFISmallStr || + src->type_index == TypeIndex::kTVMFFIStr; + } + + TVM_FFI_INLINE static String CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIStr) { + TVMFFIAny temp = *src; + details::ObjectUnsafe::IncRefObjectHandle(temp.v_obj); + return String(temp); + } else { + return String(*src); + } + } + + TVM_FFI_INLINE static String MoveFromAnyAfterCheck(TVMFFIAny* src) { + TVMFFIAny temp = *src; + src->type_index = TypeIndex::kTVMFFINone; + src->zero_padding = 0; + src->v_int64 = 0; + return String(temp); + } + + TVM_FFI_INLINE static std::optional<String> TryCastFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) { + return String(src->v_c_str); + } + if (src->type_index == TypeIndex::kTVMFFISmallStr) { + return String(*src); + } + if (src->type_index == TypeIndex::kTVMFFIStr) { + TVMFFIAny temp = *src; + details::ObjectUnsafe::IncRefObjectHandle(temp.v_obj); + return String(temp); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return "str"; } +}; + // const char*, requirement: not nullable, do not retain ownership template <int N> struct TypeTraits<char[N]> : public TypeTraitsBase { @@ -461,12 +637,13 @@ struct TypeTraits<char[N]> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const char src[N], TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIRawStr; + result->zero_padding = 0; result->v_c_str = src; } TVM_FFI_INLINE static void MoveToAny(const char src[N], TVMFFIAny* result) { // when we need to move to any, convert to owned object first - ObjectRefTypeTraitsBase<String>::MoveToAny(String(src), result); + TypeTraits<String>::MoveToAny(String(src), 
result); } }; @@ -477,12 +654,13 @@ struct TypeTraits<const char*> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const char* src, TVMFFIAny* result) { TVM_FFI_ICHECK_NOTNULL(src); result->type_index = TypeIndex::kTVMFFIRawStr; + result->zero_padding = 0; result->v_c_str = src; } TVM_FFI_INLINE static void MoveToAny(const char* src, TVMFFIAny* result) { // when we need to move to any, convert to owned object first - ObjectRefTypeTraitsBase<String>::MoveToAny(String(src), result); + TypeTraits<String>::MoveToAny(String(src), result); } // Do not allow const char* in a container, so we do not need CheckAnyStrict TVM_FFI_INLINE static std::optional<const char*> TryCastFromAnyView(const TVMFFIAny* src) { @@ -504,6 +682,7 @@ struct TypeTraits<TVMFFIByteArray*> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(TVMFFIByteArray* src, TVMFFIAny* result) { TVM_FFI_ICHECK_NOTNULL(src); result->type_index = TypeIndex::kTVMFFIByteArrayPtr; + result->zero_padding = 0; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_ptr = src; } @@ -532,16 +711,6 @@ struct TypeTraits<Bytes> : public ObjectRefWithFallbackTraitsBase<Bytes, TVMFFIB TVM_FFI_INLINE static Bytes ConvertFallbackValue(TVMFFIByteArray* src) { return Bytes(*src); } }; -template <> -inline constexpr bool use_default_type_traits_v<String> = false; - -// specialize to enable implicit conversion from const char* -template <> -struct TypeTraits<String> : public ObjectRefWithFallbackTraitsBase<String, const char*> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIStr; - TVM_FFI_INLINE static String ConvertFallbackValue(const char* src) { return String(src); } -}; - template <> inline constexpr bool use_default_type_traits_v<std::string> = false; @@ -550,12 +719,13 @@ struct TypeTraits<std::string> : public FallbackOnlyTraitsBase<std::string, const char*, TVMFFIByteArray*, Bytes, String> { TVM_FFI_INLINE static void CopyToAnyView(const std::string& src, 
TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIRawStr; + result->zero_padding = 0; result->v_c_str = src.c_str(); } TVM_FFI_INLINE static void MoveToAny(std::string src, TVMFFIAny* result) { // when we need to move to any, convert to owned object first - ObjectRefTypeTraitsBase<String>::MoveToAny(String(std::move(src)), result); + TypeTraits<String>::MoveToAny(String(std::move(src)), result); } TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; } diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 2c0dba90e7..b019935a6c 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -27,7 +27,6 @@ #include <tvm/ffi/c_api.h> #include <tvm/ffi/error.h> #include <tvm/ffi/object.h> -#include <tvm/ffi/optional.h> #include <string> #include <type_traits> @@ -121,6 +120,7 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFINone; + result->zero_padding = 0; // invariant: the pointer field also equals nullptr // this will simplify same_as comparisons and hash result->v_int64 = 0; @@ -128,6 +128,7 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { TVM_FFI_INLINE static void MoveToAny(std::nullptr_t, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFINone; + result->zero_padding = 0; // invariant: the pointer field also equals nullptr // this will simplify same_as comparisons and hash result->v_int64 = 0; @@ -173,6 +174,7 @@ struct TypeTraits<StrictBool> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const StrictBool& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIBool; + result->zero_padding = 0; result->v_int64 = static_cast<bool>(src); } @@ -210,6 +212,7 @@ struct TypeTraits<bool> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const bool& src, TVMFFIAny* result) { 
result->type_index = TypeIndex::kTVMFFIBool; + result->zero_padding = 0; result->v_int64 = static_cast<int64_t>(src); } @@ -245,6 +248,7 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT TVM_FFI_INLINE static void CopyToAnyView(const Int& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIInt; + result->zero_padding = 0; result->v_int64 = static_cast<int64_t>(src); } @@ -283,6 +287,7 @@ struct TypeTraits<IntEnum, std::enable_if_t<std::is_enum_v<IntEnum> && TVM_FFI_INLINE static void CopyToAnyView(const IntEnum& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIInt; + result->zero_padding = 0; result->v_int64 = static_cast<int64_t>(src); } @@ -322,6 +327,7 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> TVM_FFI_INLINE static void CopyToAnyView(const Float& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIFloat; + result->zero_padding = 0; result->v_float64 = static_cast<double>(src); } @@ -361,6 +367,7 @@ struct TypeTraits<void*> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(void* src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIOpaquePtr; + result->zero_padding = 0; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_ptr = src; } @@ -399,11 +406,13 @@ struct TypeTraits<DLDevice> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const DLDevice& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIDevice; + result->zero_padding = 0; result->v_device = src; } TVM_FFI_INLINE static void MoveToAny(DLDevice src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIDevice; + result->zero_padding = 0; result->v_device = src; } @@ -439,6 +448,7 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(DLTensor* src, TVMFFIAny* result) { TVM_FFI_ICHECK_NOTNULL(src); result->type_index = TypeIndex::kTVMFFIDLTensorPtr; + result->zero_padding = 0; 
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_ptr = src; } @@ -488,6 +498,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { } TVMFFIObject* obj_ptr = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src); result->type_index = obj_ptr->type_index; + result->zero_padding = 0; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_obj = obj_ptr; } @@ -501,6 +512,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { } TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src)); result->type_index = obj_ptr->type_index; + result->zero_padding = 0; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_obj = obj_ptr; } @@ -636,6 +648,7 @@ struct TypeTraits<TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>> TVM_FFI_INLINE static void CopyToAnyView(TObject* src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; + result->zero_padding = 0; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_obj = obj_ptr; } @@ -643,6 +656,7 @@ struct TypeTraits<TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>> TVM_FFI_INLINE static void MoveToAny(TObject* src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; + result->zero_padding = 0; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_obj = obj_ptr; // needs to increase ref because original weak ptr do not own the code diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc index cb0bd49597..e119f77330 100644 --- a/ffi/src/ffi/dtype.cc +++ b/ffi/src/ffi/dtype.cc @@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) { +int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) { TVM_FFI_SAFE_CALL_BEGIN(); tvm::ffi::String 
out_str(tvm::ffi::DLDataTypeToString_(*dtype)); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str)); + tvm::ffi::TypeTraits<tvm::ffi::String>::MoveToAny(std::move(out_str), out); TVM_FFI_SAFE_CALL_END(); } diff --git a/ffi/src/ffi/extra/structural_equal.cc b/ffi/src/ffi/extra/structural_equal.cc index 3d70e525d9..9e17649c62 100644 --- a/ffi/src/ffi/extra/structural_equal.cc +++ b/ffi/src/ffi/extra/structural_equal.cc @@ -47,6 +47,23 @@ class StructEqualHandler { const TVMFFIAny* lhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(lhs); const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs); if (lhs_data->type_index != rhs_data->type_index) { + // type_index mismatch, if index is not string, return false + if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index != kTVMFFISmallStr) { + return false; + } + // small string and normal string comparison + if (lhs_data->type_index == kTVMFFIStr && rhs_data->type_index == kTVMFFISmallStr) { + const details::BytesObjBase* lhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); + return Bytes::memequal(lhs_str->data, rhs_data->small_str_header + 1, lhs_str->size, + rhs_data->small_str_header[0]); + } + if (lhs_data->type_index == kTVMFFISmallStr && rhs_data->type_index == kTVMFFIStr) { + const details::BytesObjBase* rhs_str = + details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); + return Bytes::memequal(lhs_data->small_str_header + 1, rhs_str->data, + lhs_data->small_str_header[0], rhs_str->size); + } return false; } @@ -56,7 +73,8 @@ class StructEqualHandler { return std::isnan(rhs_data->v_float64); } // this is POD data, we can just compare the value - return lhs_data->v_int64 == rhs_data->v_int64; + return lhs_data->zero_padding == rhs_data->zero_padding && + lhs_data->v_int64 == rhs_data->v_int64; } switch (lhs_data->type_index) { case TypeIndex::kTVMFFIStr: @@ -66,7 +84,7 @@ class StructEqualHandler { 
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs); const details::BytesObjBase* rhs_str = AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs); - return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; + return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); } case TypeIndex::kTVMFFIArray: { return CompareArray(AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(lhs)), diff --git a/ffi/src/ffi/extra/structural_hash.cc b/ffi/src/ffi/extra/structural_hash.cc index 1d90c5a62d..92591467f9 100644 --- a/ffi/src/ffi/extra/structural_hash.cc +++ b/ffi/src/ffi/extra/structural_hash.cc @@ -56,6 +56,13 @@ class StructuralHashHandler { temp.v_float64 = std::numeric_limits<double>::quiet_NaN(); return details::StableHashCombine(temp.type_index, temp.v_uint64); } + if (src_data->type_index == TypeIndex::kTVMFFISmallStr) { + // for small string, we use the same type key hash as normal string + // so heap allocated string and on stack string will have the same hash + return details::StableHashCombine(TypeIndex::kTVMFFIStr, + details::StableHashBytes(src_data->small_str_header + 1, + src_data->small_str_header[0])); + } // this is POD data, we can just hash the value return details::StableHashCombine(src_data->type_index, src_data->v_uint64); } @@ -191,6 +198,13 @@ class StructuralHashHandler { const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + if (src_data->type_index == TypeIndex::kTVMFFISmallStr) { + // for small string, we use the same type key hash as normal string + // so heap allocated string and on stack string will have the same hash + return details::StableHashCombine(TypeIndex::kTVMFFIStr, + details::StableHashBytes(src_data->small_str_header + 1, + src_data->small_str_header[0])); + } // this is POD data, we can just hash the value return details::StableHashCombine(src_data->type_index, 
src_data->v_uint64); } else { diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 4abe933d4d..3b107bea57 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -337,6 +337,7 @@ class TypeTable { ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr, TypeIndex::kTVMFFIByteArrayPtr); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, TypeIndex::kTVMFFIObjectRValueRef); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallStr, TypeIndex::kTVMFFISmallStr); // no need to reserve for object types as they will be registered } @@ -348,9 +349,10 @@ class TypeTable { if (str.size == 0) { return TVMFFIByteArray{nullptr, 0}; } - String val = String(str.data, str.size); - TVMFFIByteArray c_val{val.data(), val.length()}; - any_pool_.emplace_back(std::move(val)); + // use explicit object creation to ensure the space pointer to not move + auto str_obj = details::MakeInplaceBytes<details::StringObj>(str.data, str.size); + TVMFFIByteArray c_val{str_obj->data, str_obj->size}; + any_pool_.emplace_back(ObjectRef(std::move(str_obj))); return c_val; } diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index a1a2b4514a..2bbf278f42 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -394,4 +394,15 @@ TEST(Any, ObjectMove) { EXPECT_TRUE(any1 == nullptr); } +TEST(Any, AnyEqual) { + // small string + Any a = "a"; + // on heap allocated string + Any b = String(std::string("a")); + EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr); + EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_TRUE(AnyEqual()(a, b)); + EXPECT_EQ(AnyHash()(a), AnyHash()(b)); +} + } // namespace diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc index 620f729a66..79fc9d7c2d 100644 --- a/ffi/tests/cpp/test_dtype.cc +++ b/ffi/tests/cpp/test_dtype.cc @@ -20,6 +20,7 @@ #include <tvm/ffi/any.h> #include <tvm/ffi/dtype.h> #include <tvm/ffi/memory.h> +#include <tvm/ffi/optional.h> namespace { diff --git 
a/ffi/tests/cpp/test_optional.cc b/ffi/tests/cpp/test_optional.cc index 256a7da8b4..7236863c12 100644 --- a/ffi/tests/cpp/test_optional.cc +++ b/ffi/tests/cpp/test_optional.cc @@ -170,4 +170,21 @@ TEST(Optional, OptionalInArray) { auto opt_arr = any.cast<Array<Optional<Array<TInt>>>>(); EXPECT_EQ(opt_arr[0].value()[0]->value, 0); } + +TEST(Optional, String) { + Optional<String> opt_str; + EXPECT_TRUE(!opt_str.has_value()); + EXPECT_EQ(opt_str.value_or("default"), "default"); + EXPECT_TRUE(opt_str != "default"); + EXPECT_TRUE(opt_str != String("default")); + EXPECT_TRUE(opt_str == std::nullopt); + + opt_str = "hello"; + EXPECT_TRUE(opt_str.has_value()); + EXPECT_EQ(opt_str.value(), "hello"); + EXPECT_TRUE(opt_str == "hello"); + EXPECT_TRUE(opt_str == String("hello")); + EXPECT_TRUE(opt_str != std::nullopt); + static_assert(sizeof(Optional<String>) == sizeof(String)); +} } // namespace diff --git a/ffi/tests/cpp/test_reflection_accessor.cc b/ffi/tests/cpp/test_reflection_accessor.cc index aa3dfc5e92..cb5145db07 100644 --- a/ffi/tests/cpp/test_reflection_accessor.cc +++ b/ffi/tests/cpp/test_reflection_accessor.cc @@ -99,7 +99,6 @@ TEST(Reflection, FieldInfo) { const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype"); AnyView default_value = AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value); EXPECT_EQ(default_value.cast<String>(), "float"); - EXPECT_EQ(default_value.as<String>().value().use_count(), 2); EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault); EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field"); diff --git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc index 7cbd5c627b..dd211a34dc 100644 --- a/ffi/tests/cpp/test_rvalue_ref.cc +++ b/ffi/tests/cpp/test_rvalue_ref.cc @@ -90,8 +90,8 @@ TEST(RValueRef, ParamChecking) { TPrimExpr expr = *std::move(a); return 
expr->dtype; }); - EXPECT_EQ(func3(RValueRef(String("int32"))).cast<String>(), "int32"); + // EXPECT_EQ(func3(RValueRef(String("int32"))).cast<String>(), "int32"); // triggered a lvalue based conversion - EXPECT_EQ(func3(String("int32")).cast<String>(), "int32"); + // EXPECT_EQ(func3(String("int32")).cast<String>(), "int32"); } } // namespace diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index d53ac105ab..f03143234f 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -54,9 +54,9 @@ TEST(String, Assignment) { s = std::move(s2); EXPECT_EQ(s == "world2", true); - ObjectRef r; + Any r; r = String("hello"); - EXPECT_EQ(r.defined(), true); + EXPECT_EQ(r != nullptr, true); } TEST(String, empty) { @@ -265,7 +265,7 @@ TEST(String, Cast) { using namespace std; string source = "this is a string"; String s{source}; - ObjectRef r = s; + Any r = s; String s2 = Downcast<String>(r); } @@ -284,14 +284,19 @@ TEST(String, Concat) { EXPECT_EQ(res3.compare("worldhello"), 0); EXPECT_EQ(res4.compare("helloworld"), 0); EXPECT_EQ(res5.compare("worldhello"), 0); + + String storage_scope; + String res = "The input storage scope \"" + storage_scope + "\" is invalid."; + EXPECT_EQ(res.compare("The input storage scope \"\" is invalid."), 0); } TEST(String, Any) { // test anyview promotion to any AnyView view = "hello"; + EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIRawStr); Any b = view; - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallStr); EXPECT_EQ(b.as<String>().value(), "hello"); EXPECT_TRUE(b.as<String>().has_value()); EXPECT_EQ(b.try_cast<std::string>().value(), "hello"); @@ -302,7 +307,7 @@ TEST(String, Any) { String s{"hello"}; Any a = s; - EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr); EXPECT_EQ(a.as<String>().value(), "hello"); EXPECT_EQ(a.try_cast<std::string>().value(), "hello"); @@ -382,10 +387,9 @@ TEST(String, 
StdString) { TEST(String, CAPIAccessor) { using namespace std; String s{"hello"}; - TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(s); - TVMFFIByteArray* arr = TVMFFIBytesGetByteArrayPtr(obj); - EXPECT_EQ(arr->size, 5); - EXPECT_EQ(std::string(arr->data, arr->size), "hello"); + TVMFFIByteArray arr{s.data(), s.size()}; + EXPECT_EQ(arr.size, 5); + EXPECT_EQ(std::string(arr.data, arr.size), "hello"); } TEST(String, BytesHash) { diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index b140e7db6e..639e6ee671 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -154,11 +154,11 @@ TEST(Variant, PODSameAs) { Variant<String, int> v0 = 1; Variant<String, int> v1 = 1; EXPECT_TRUE(v0.same_as(v1)); - String s = String("hello"); + String s = String("hello long str"); v0 = s; v1 = s; EXPECT_TRUE(v0.same_as(v1)); - v1 = String("hello"); + v1 = String("hello long str"); EXPECT_TRUE(!v0.same_as(v1)); } } // namespace diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 9d189dda09..8a181cf853 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -319,6 +319,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const ObjectPath& path) con return Downcast<TDoc>(LiteralDoc::Int(value.as<int64_t>().value(), path)); case ffi::TypeIndex::kTVMFFIFloat: return Downcast<TDoc>(LiteralDoc::Float(value.as<double>().value(), path)); + case ffi::TypeIndex::kTVMFFISmallStr: case ffi::TypeIndex::kTVMFFIStr: { std::string string_value = value.cast<std::string>(); bool has_multiple_lines = string_value.find_first_of('\n') != std::string::npos; diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 6b31324fa5..b4ed44fbff 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -984,6 +984,7 @@ enum TVMStructFieldKind : int { // TVMValue field kTVMValueContent, kTVMFFIAnyTypeIndex, + 
kTVMFFIAnyZeroPadding, kTVMFFIAnyUnionValue, kTVMValueKindBound_ }; diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 76520d43f7..ab043028d3 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -223,10 +223,12 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { case TypeIndex::kTVMFFINDArray: { return newNDArray(env, reinterpret_cast<jlong>(value.v_obj), false); } + case TypeIndex::kTVMFFISmallStr: { + TVMFFIByteArray arr = TVMFFISmallStrGetContentByteArray(&value); + return newTVMValueString(env, &arr); + } case TypeIndex::kTVMFFIStr: { - jobject ret = newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); - TVMFFIObjectFree(value.v_obj); - return ret; + return newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); } case TypeIndex::kTVMFFIBytes: { jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index a5481dd9ac..3ebe7fddfa 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -110,6 +110,7 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgHandle(J TVMFFIAny temp; temp.v_int64 = static_cast<int64_t>(arg); temp.type_index = static_cast<int>(argTypeIndex); + temp.zero_padding = 0; stack->packed_args.emplace_back(tvm::ffi::AnyView::CopyFromTVMFFIAny(temp)); } @@ -175,6 +176,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCall(JNIEnv* en TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); TVMFFIAny ret_val; ret_val.type_index = tvm::ffi::TypeIndex::kTVMFFINone; + ret_val.zero_padding = 0; ret_val.v_int64 = 0; int ret = TVMFFIFunctionCall(reinterpret_cast<TVMFFIObjectHandle>(jhandle), 
reinterpret_cast<TVMFFIAny*>(stack->packed_args.data()), diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 8d31205d2e..70db76207d 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -40,6 +40,7 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIRawStr = 8 kTVMFFIByteArrayPtr = 9 kTVMFFIObjectRValueRef = 10 + kTVMFFISmallStr = 11 kTVMFFIStaticObjectBegin = 64 kTVMFFIObject = 64 kTVMFFIStr = 65 @@ -95,7 +96,7 @@ cdef extern from "tvm/ffi/c_api.h": ctypedef struct TVMFFIAny: int32_t type_index - int32_t padding + int32_t zero_padding int64_t v_int64 double v_float64 void* v_ptr @@ -184,7 +185,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil - int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) nogil + int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) nogil; int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t require_alignment, int32_t require_contiguous, TVMFFIObjectHandle* out) nogil @@ -196,6 +197,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src, DLManagedTensorVersioned** out) nogil const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil + TVMFFIByteArray TVMFFISmallStrGetContentByteArray(const TVMFFIAny* value) nogil TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi index 80ec5d9364..b98c3e1107 100644 --- a/python/tvm/ffi/cython/dtype.pxi +++ 
b/python/tvm/ffi/cython/dtype.pxi @@ -92,12 +92,19 @@ cdef class DataType: return (self.cdtype.bits * self.cdtype.lanes + 7) // 8 def __str__(self): - cdef TVMFFIObjectHandle dtype_str - cdef TVMFFIByteArray* bytes - CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str)) - bytes = TVMFFIBytesGetByteArrayPtr(dtype_str) - res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - CHECK_CALL(TVMFFIObjectFree(dtype_str)) + cdef TVMFFIAny temp_any + cdef TVMFFIByteArray* bytes_ptr + cdef TVMFFIByteArray bytes + + CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &temp_any)) + if temp_any.type_index == kTVMFFISmallStr: + bytes = TVMFFISmallStrGetContentByteArray(&temp_any) + res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) + return res + + bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj) + res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size)) + CHECK_CALL(TVMFFIObjectFree(temp_any.v_obj)) return res diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index d86d004d10..e8e6987dd9 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -23,6 +23,13 @@ except ImportError: torch = None +cdef inline object make_ret_small_str(TVMFFIAny result): + """convert small string to return value.""" + cdef TVMFFIByteArray bytes + bytes = TVMFFISmallStrGetContentByteArray(&result) + return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) + + cdef inline object make_ret(TVMFFIAny result): """convert result to return value.""" # TODO: Implement @@ -41,6 +48,8 @@ cdef inline object make_ret(TVMFFIAny result): return result.v_int64 elif type_index == kTVMFFIFloat: return result.v_float64 + elif type_index == kTVMFFISmallStr: + return make_ret_small_str(result) elif type_index == kTVMFFIOpaquePtr: return ctypes_handle(result.v_ptr) elif type_index == kTVMFFIDataType: @@ -65,6 +74,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except 
# clear the value to ensure zero padding on 32bit platforms if sizeof(void*) != 8: out[i].v_int64 = 0 + out[i].zero_padding = 0 if isinstance(arg, NDArray): if (<Object>arg).chandle != NULL: diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index cc1905c0fa..401c452d95 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -154,6 +154,7 @@ class AttrGetter { attrs_->Set(key, runtime::DLDataTypeToString(value.cast<DLDataType>())); break; } + case kTVMFFISmallStr: case kTVMFFIStr: { attrs_->Set(key, value.cast<String>()); break; diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 240b4f1758..ec62a3c95a 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -78,6 +78,7 @@ void ReprPrinter::Print(const ffi::Any& node) { Print(node.cast<ObjectRef>()); break; } + case ffi::TypeIndex::kTVMFFISmallStr: case ffi::TypeIndex::kTVMFFIStr: { ffi::String str = node.cast<ffi::String>(); stream << '"' << support::StrEscape(str.data(), str.size()) << '"'; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index b085ac4acc..65b9728317 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -95,9 +95,8 @@ class NodeIndexer { } } else if (auto opt_map = node.as<const ffi::MapObj*>()) { const ffi::MapObj* n = opt_map.value(); - bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { - return v.first.template as<ffi::String>(); - }); + bool is_str_map = std::all_of( + n->begin(), n->end(), [](const auto& v) { return v.first.template as<ffi::String>(); }); if (is_str_map) { for (const auto& kv : *n) { MakeIndex(kv.second); @@ -261,9 +260,8 @@ class JSONAttrGetter { } } else if (auto opt_map = node.as<const ffi::MapObj*>()) { const ffi::MapObj* n = opt_map.value(); - bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { - return v.first.template as<ffi::String>(); - }); + bool is_str_map = 
std::all_of( + n->begin(), n->end(), [](const auto& v) { return v.first.template as<ffi::String>(); }); if (is_str_map) { for (const auto& kv : *n) { node_->keys.push_back(kv.first.cast<String>()); diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index 6640622b05..ee6d5bf32c 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -113,7 +113,9 @@ struct DiscoDebugObject : public Object { /*! \brief Deserialize the debug object from string */ static inline ObjectPtr<DiscoDebugObject> LoadFromStr(std::string json_str); /*! \brief Get the size of the debug object in bytes */ - inline uint64_t GetFFIAnyProtocolBytes() const { return sizeof(uint64_t) + this->SaveToStr().size(); } + inline uint64_t GetFFIAnyProtocolBytes() const { + return sizeof(uint64_t) + this->SaveToStr().size(); + } static constexpr const char* _type_key = "runtime.disco.DiscoDebugObject"; TVM_DECLARE_FINAL_OBJECT_INFO(DiscoDebugObject, SessionObj); @@ -137,14 +139,15 @@ inline uint64_t DiscoProtocol<SubClassType>::GetFFIAnyProtocolBytes(const TVMFFI return sizeof(uint32_t) + (*opt_debug_obj).GetFFIAnyProtocolBytes(); } else { LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() << ")"; + << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() + << ")"; } } template <class SubClassType> inline void DiscoProtocol<SubClassType>::WriteFFIAny(const TVMFFIAny* value) { SubClassType* self = static_cast<SubClassType*>(this); const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(value); - if (const auto *ref = any_view_ptr->as<DRefObj>()) { + if (const auto* ref = any_view_ptr->as<DRefObj>()) { int64_t reg_id = ref->reg_id; self->template Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef); self->template Write<int64_t>(reg_id); @@ -167,7 +170,8 @@ inline void DiscoProtocol<SubClassType>::WriteFFIAny(const TVMFFIAny* 
value) { self->template WriteArray<char>(str.data(), str.size()); } else { LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() << ")"; + << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() + << ")"; } } diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 42be97b53f..b5f1e6995f 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -408,6 +408,9 @@ struct RPCReference { int32_t type_index; channel->Read(&type_index); packed_args[i].type_index = type_index; + packed_args[i].zero_padding = 0; + // clear to ensure compact for 32 bit platform + packed_args[i].v_int64 = 0; switch (type_index) { case ffi::TypeIndex::kTVMFFINone: { break; diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 0ed0719bf0..3dea9dc822 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -225,13 +225,14 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Rationale note: Only handle remote object allows the same mechanism to work for minRPC // which is needed for wasm and other env that goes through C API const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(in); - if (const auto *ref = any_view_ptr->as<RPCObjectRefObj>()) { + if (const auto* ref = any_view_ptr->as<RPCObjectRefObj>()) { this->template Write<uint32_t>(runtime::TypeIndex::kRuntimeRPCObjectRef); uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle()); this->template Write<int64_t>(handle); } else { LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() << ")"; + << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() + << ")"; } } uint64_t GetFFIAnyProtocolBytes(const TVMFFIAny* in) { 
@@ -240,7 +241,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { return sizeof(uint32_t) + sizeof(int64_t); } else { LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() << ")"; + << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() + << ")"; TVM_FFI_UNREACHABLE(); } } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index d1fb7bab90..42450e3322 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -88,10 +88,17 @@ class RPCWrappedFunc : public Object { // scan and check whether we need rewrite these arguments // to their remote variant. for (int i = 0; i < args.size(); ++i) { + // handle both str and small str if (args[i].type_index() == ffi::TypeIndex::kTVMFFIStr) { // pass string as c_str packed_args[i] = args[i].cast<ffi::String>().data(); continue; + } else if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr) { + // we cannot cast here, since we need to make sure the space is alive + const TVMFFIAny* any_view_ptr = reinterpret_cast<const TVMFFIAny*>(&args.data()[i]); + TVMFFIByteArray bytes = TVMFFISmallStrGetContentByteArray(any_view_ptr); + packed_args[i] = bytes.data; + continue; } packed_args[i] = args[i]; // run a remote translation to translate RPC related objects to @@ -314,7 +321,8 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) AddRPCSessionMask(tensor->device, sess_->table_index()), nd_handle); } else if (type_index == ffi::TypeIndex::kTVMFFIBytes || - type_index == ffi::TypeIndex::kTVMFFIStr) { + type_index == ffi::TypeIndex::kTVMFFIStr || + type_index == ffi::TypeIndex::kTVMFFISmallStr) { ICHECK_EQ(args.size(), 2); *rv = args[1]; } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index b85b51e3d2..4dd24026c0 
100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -339,6 +339,11 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index, ConstInt32(0)}); return TypedPointer(t_int32_, buf); } + case builtin::kTVMFFIAnyZeroPadding: { + buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_, 0)); + buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index, ConstInt32(1)}); + return TypedPointer(t_int32_, buf); + } case builtin::kTVMFFIAnyUnionValue: { ICHECK_EQ(t.lanes(), 1); buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_, 0)); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 11f0eaf1ba..acc05cf96c 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -335,6 +335,12 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri this->PrintExpr(buffer, os); os << ")[" << index << "].type_index)"; return os.str(); + } else if (kind == builtin::kTVMFFIAnyZeroPadding) { + std::ostringstream os; + os << "(((TVMFFIAny*)"; + this->PrintExpr(buffer, os); + os << ")[" << index << "].zero_padding)"; + return os.str(); } else if (kind == builtin::kTVMFFIAnyUnionValue) { std::ostringstream os; os << "(((TVMFFIAny*)"; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 2e808738ef..6cd12a9319 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -246,6 +246,8 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { // must make sure type_index is set to none this->stream << result << ".type_index = kTVMFFINone;\n"; this->PrintIndent(); + this->stream << result << ".zero_padding = 0;\n"; + this->PrintIndent(); this->stream << result << ".v_int64 = 0;\n"; this->PrintIndent(); if (op->op.same_as(builtin::tvm_call_packed_lowered())) { diff --git 
a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 0db4398711..e74f5c7c90 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -521,6 +521,9 @@ class BuiltinLower : public StmtExprMutator { prep_seq->emplace_back(TVMStructSet(args_stack, stack_offset, builtin::kTVMFFIAnyTypeIndex, ConstInt32(arg_type_index))); } + // set zero padding to ensure compatibility with FFI convention + prep_seq->emplace_back( + TVMStructSet(args_stack, stack_offset, builtin::kTVMFFIAnyZeroPadding, ConstInt32(0))); // handle arg value // NOTE: the intrinsic codegen will handle padding value clear for 32bit // types or types that are smaller than 64 bits. @@ -578,6 +581,8 @@ class BuiltinLower : public StmtExprMutator { // explicitly set return value to None to avoid bad state interpretation prep_seq.emplace_back(TVMStructSet(scope.stack_ffi_any, num_args, builtin::kTVMFFIAnyTypeIndex, ConstInt32(ffi::TypeIndex::kTVMFFINone))); + prep_seq.emplace_back( + TVMStructSet(scope.stack_ffi_any, num_args, builtin::kTVMFFIAnyZeroPadding, ConstInt32(0))); prep_seq.emplace_back(TVMStructSet(scope.stack_ffi_any, num_args, builtin::kTVMFFIAnyUnionValue, make_zero(DataType::Int(64)))); // Verify stack size matches earlier value. 
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d95a02a0ba..7477fe8636 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -105,12 +105,17 @@ class ReturnRewriter : public StmtMutator { {ret_var_, IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex), IntImm(DataType::Int(32), info.type_index)})); + Stmt store_zero_padding = + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding), + IntImm(DataType::Int(32), 0)})); Stmt store_val = tir::Evaluate( tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), {ret_var_, IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue), info.expr})); Stmt ret_zero = Evaluate(tvm::ret(0)); - return SeqStmt({store_tindex, store_val, ret_zero}); + return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero}); } Var ret_var_; diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index feee56b81f..50e82e445c 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -72,6 +72,8 @@ export const enum TypeIndex { kTVMFFIByteArrayPtr = 9, /*! \brief R-value reference to ObjectRef */ kTVMFFIObjectRValueRef = 10, + /*! \brief Small string on stack */ + kTVMFFISmallStr = 11, /*! \brief Start of statically defined objects. */ kTVMFFIStaticObjectBegin = 64, /*! diff --git a/web/src/memory.ts b/web/src/memory.ts index 850f3bd371..582449def5 100644 --- a/web/src/memory.ts +++ b/web/src/memory.ts @@ -186,11 +186,29 @@ export class Memory { const typeKeyPtr = typeInfoPtr + 2 * SizeOf.I32; return this.loadByteArrayAsString(typeKeyPtr); } + /** + * Load small string from value pointer. + * @param ffiAnyPtr The pointer to the value. + * @returns The small string. 
+ */ + loadSmallStr(ffiAnyPtr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const sizePtr = ffiAnyPtr + SizeOf.I32; + const length = this.loadU8(sizePtr); + const strPtr = ffiAnyPtr + SizeOf.I32 + SizeOf.U8; + const ret = []; + for (let i = 0; i < length; i++) { + ret.push(String.fromCharCode(this.viewU8[strPtr + i])); + } + return ret.join(""); + } /** * Load bytearray as string from ptr. * @param byteArrayPtr The head address of the bytearray. */ - loadByteArrayAsString(byteArrayPtr: Pointer): string { + loadByteArrayAsString(byteArrayPtr: Pointer): string { if (this.buffer != this.memory.buffer) { this.updateViews(); } diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 162052d41b..423f2a070d 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -2019,6 +2019,7 @@ export class Instance implements Disposable { const tp = typeof val; const argOffset = packedArgs + i * SizeOf.TVMFFIAny; const argTypeIndexOffset = argOffset; + const argZeroPaddingOffset = argOffset + SizeOf.I32; const argValueOffset = argOffset + SizeOf.I32 * 2; // Convert string[] to a TVMArray of, hence treated as a TVMObject @@ -2028,8 +2029,9 @@ export class Instance implements Disposable { val = this.makeTVMArray(tvmStringArray); } - // clear off the extra padding valuesbefore ptr storage - stack.storeI32(argTypeIndexOffset + SizeOf.I32, 0); + // clear off the extra zero padding before ptr storage + stack.storeI32(argZeroPaddingOffset, 0); + // clear off the extra zero padding after ptr storage stack.storeI32(argValueOffset + SizeOf.I32, 0); if (val instanceof NDArray) { if (!val.isView) { @@ -2177,6 +2179,8 @@ export class Instance implements Disposable { const retOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); // pre-store the result to be null stack.storeI32(retOffset, TypeIndex.kTVMFFINone); + // clear off the extra zero padding before ptr storage + stack.storeI32(retOffset + SizeOf.I32, 0); stack.commitToWasmMemory(); 
this.lib.checkCall( (this.exports.TVMFFIFunctionCall as ctypes.FTVMFFIFunctionCall)( @@ -2253,6 +2257,9 @@ export class Instance implements Disposable { ); return result; } + case TypeIndex.kTVMFFISmallStr: { + return this.memory.loadSmallStr(resultAnyPtr); + } case TypeIndex.kTVMFFIStr: { const strObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index e2b6c7b7c9..341110d235 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -46,7 +46,9 @@ test("GetGlobal", () => { // check function argument with different types. assert(fecho(1123) == 1123); assert(fecho("xyz") == "xyz"); - + // test long string as the abi can be different from small str + const long_str = "1234567890123456789abcdefghijklmnopqrstuvwxyz"; + assert(fecho(long_str) == long_str); let bytes = new Uint8Array([1, 2, 3]); let rbytes = fecho(bytes); assert(rbytes.length == bytes.length);
