This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fe36bb9062 [FFI] Introduce small string/bytes (#18185)
fe36bb9062 is described below
commit fe36bb9062136d9b29c98fd0be8d9ecc729ac5b2
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Aug 4 08:52:12 2025 -0400
[FFI] Introduce small string/bytes (#18185)
---
ffi/include/tvm/ffi/any.h | 114 +++--
ffi/include/tvm/ffi/base_details.h | 17 +-
ffi/include/tvm/ffi/c_api.h | 48 +-
ffi/include/tvm/ffi/cast.h | 1 +
ffi/include/tvm/ffi/container/variant.h | 2 +
ffi/include/tvm/ffi/dtype.h | 9 +-
ffi/include/tvm/ffi/object.h | 2 +
ffi/include/tvm/ffi/optional.h | 116 ++++-
ffi/include/tvm/ffi/reflection/accessor.h | 2 +-
ffi/include/tvm/ffi/rvalue_ref.h | 4 +-
ffi/include/tvm/ffi/string.h | 567 +++++++++++++++------
ffi/include/tvm/ffi/type_traits.h | 16 +-
ffi/src/ffi/dtype.cc | 4 +-
ffi/src/ffi/extra/structural_equal.cc | 35 +-
ffi/src/ffi/extra/structural_hash.cc | 13 +
ffi/src/ffi/object.cc | 26 +-
ffi/tests/cpp/test_any.cc | 18 +
ffi/tests/cpp/test_dtype.cc | 1 +
ffi/tests/cpp/test_optional.cc | 29 ++
ffi/tests/cpp/test_reflection_accessor.cc | 1 -
ffi/tests/cpp/test_rvalue_ref.cc | 4 +-
ffi/tests/cpp/test_string.cc | 51 +-
ffi/tests/cpp/test_variant.cc | 4 +-
include/tvm/relax/exec_builder.h | 2 +-
include/tvm/relax/transform.h | 4 +-
include/tvm/script/ir_builder/tir/frame.h | 2 +-
include/tvm/script/printer/ir_docsifier.h | 1 +
include/tvm/tir/builtin.h | 1 +
jvm/native/src/main/native/jni_helper_func.h | 12 +-
.../src/main/native/org_apache_tvm_native_c_api.cc | 2 +
python/tvm/ffi/cython/base.pxi | 7 +-
python/tvm/ffi/cython/dtype.pxi | 19 +-
python/tvm/ffi/cython/function.pxi | 19 +
src/contrib/msc/core/ir/graph_builder.h | 1 +
src/contrib/msc/core/printer/cpp_printer.cc | 7 +-
src/contrib/msc/core/printer/python_printer.cc | 2 +-
src/meta_schedule/mutator/mutate_tile_size.cc | 5 +-
src/node/repr_printer.cc | 2 +
src/node/serialization.cc | 57 ++-
src/relax/backend/contrib/clml/codegen.cc | 6 +-
.../backend/contrib/codegen_json/codegen_json.h | 7 +-
src/relax/backend/contrib/cublas/codegen.cc | 2 +-
src/relax/backend/contrib/cudnn/codegen.cc | 2 +-
src/relax/backend/contrib/cutlass/codegen.cc | 2 +-
src/relax/backend/contrib/dnnl/codegen.cc | 2 +-
src/relax/backend/contrib/hipblas/codegen.cc | 2 +-
src/relax/backend/contrib/nnapi/codegen.cc | 2 +-
src/relax/backend/vm/exec_builder.cc | 23 +-
src/relax/transform/bind_params.cc | 8 +-
src/relax/transform/bind_symbolic_vars.cc | 13 +-
src/runtime/minrpc/rpc_reference.h | 3 +
src/runtime/profiling.cc | 4 +-
src/runtime/rpc/rpc_module.cc | 11 +-
src/script/ir_builder/tir/ir.cc | 7 +-
.../printer/doc_printer/python_doc_printer.cc | 2 +-
src/support/ffi_testing.cc | 2 +-
src/support/utils.h | 11 +-
src/target/llvm/codegen_cpu.cc | 5 +
src/target/source/codegen_c.cc | 6 +
src/target/source/codegen_c_host.cc | 2 +
src/tir/ir/stmt.cc | 4 +-
src/tir/schedule/concrete_schedule.cc | 3 +-
src/tir/schedule/instruction.cc | 2 +-
src/tir/schedule/trace.cc | 14 +-
src/tir/transforms/lower_tvm_builtin.cc | 5 +
src/tir/transforms/make_packed_api.cc | 7 +-
.../test_tir_transform_lower_tvm_builtin.py | 12 +-
.../test_tir_transform_make_packed_api.py | 19 +-
web/src/ctypes.ts | 4 +
web/src/memory.ts | 53 +-
web/src/runtime.ts | 14 +-
web/tests/node/test_packed_func.js | 14 +-
72 files changed, 1138 insertions(+), 362 deletions(-)
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index d94185c064..55eff8802a 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -60,6 +60,7 @@ class AnyView {
void reset() {
data_.type_index = TypeIndex::kTVMFFINone;
// invariance: always set the union padding part to 0
+ data_.zero_padding = 0;
data_.v_int64 = 0;
}
/*!
@@ -72,6 +73,7 @@ class AnyView {
// default constructors
AnyView() {
data_.type_index = TypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
data_.v_int64 = 0;
}
~AnyView() = default;
@@ -80,6 +82,7 @@ class AnyView {
AnyView& operator=(const AnyView&) = default;
AnyView(AnyView&& other) : data_(other.data_) {
other.data_.type_index = TypeIndex::kTVMFFINone;
+ other.data_.zero_padding = 0;
other.data_.v_int64 = 0;
}
TVM_FFI_INLINE AnyView& operator=(AnyView&& other) {
@@ -198,13 +201,11 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny*
data,
if (data->type_index == TypeIndex::kTVMFFIRawStr) {
// convert raw string to owned string object
String temp(data->v_c_str);
- data->type_index = TypeIndex::kTVMFFIStr;
- data->v_obj =
details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
+ TypeTraits<String>::MoveToAny(std::move(temp), data);
} else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) {
// convert byte array to owned bytes object
Bytes temp(*static_cast<TVMFFIByteArray*>(data->v_ptr));
- data->type_index = TypeIndex::kTVMFFIBytes;
- data->v_obj =
details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
+ TypeTraits<Bytes>::MoveToAny(std::move(temp), data);
} else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) {
// convert rvalue ref to owned object
Object** obj_addr = static_cast<Object**>(data->v_ptr);
@@ -212,8 +213,7 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny*
data,
ObjectRef
temp(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj_addr[0]));
// set the rvalue ref to nullptr to avoid double move
obj_addr[0] = nullptr;
- data->type_index = temp->type_index();
- data->v_obj =
details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
+ TypeTraits<ObjectRef>::MoveToAny(std::move(temp), data);
}
}
}
@@ -239,6 +239,7 @@ class Any {
details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj);
}
data_.type_index = TVMFFITypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
data_.v_int64 = 0;
}
/*!
@@ -251,6 +252,7 @@ class Any {
// default constructors
Any() {
data_.type_index = TypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
data_.v_int64 = 0;
}
~Any() { this->reset(); }
@@ -262,6 +264,7 @@ class Any {
}
Any(Any&& other) : data_(other.data_) {
other.data_.type_index = TypeIndex::kTVMFFINone;
+ other.data_.zero_padding = 0;
other.data_.v_int64 = 0;
}
TVM_FFI_INLINE Any& operator=(const Any& other) {
@@ -408,7 +411,8 @@ class Any {
* \return True if the two Any are same type and value, false otherwise.
*/
TVM_FFI_INLINE bool same_as(const Any& other) const noexcept {
- return data_.type_index == other.data_.type_index && data_.v_int64 ==
other.data_.v_int64;
+ return data_.type_index == other.data_.type_index &&
+ data_.zero_padding == other.data_.zero_padding && data_.v_int64 ==
other.data_.v_int64;
}
/*
@@ -485,6 +489,7 @@ struct AnyUnsafe : public ObjectUnsafe {
TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) {
TVMFFIAny result = any.data_;
any.data_.type_index = TypeIndex::kTVMFFINone;
+ any.data_.zero_padding = 0;
any.data_.v_int64 = 0;
return result;
}
@@ -493,6 +498,7 @@ struct AnyUnsafe : public ObjectUnsafe {
Any any;
any.data_ = data;
data.type_index = TypeIndex::kTVMFFINone;
+ data.zero_padding = 0;
data.v_int64 = 0;
return any;
}
@@ -543,17 +549,24 @@ struct AnyHash {
* \return Hash code of a, string hash for strings and pointer address
otherwise.
*/
uint64_t operator()(const Any& src) const {
- uint64_t val_hash = [&]() -> uint64_t {
- if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
- src.data_.type_index == TypeIndex::kTVMFFIBytes) {
- const details::BytesObjBase* src_str =
- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(src);
- return details::StableHashBytes(src_str->data, src_str->size);
- } else {
- return src.data_.v_uint64;
- }
- }();
- return details::StableHashCombine(src.data_.type_index, val_hash);
+ if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) {
+ // for small string, we use the same type key hash as normal string
+ // so heap allocated string and on stack string will have the same hash
+ return details::StableHashCombine(TypeIndex::kTVMFFIStr,
+
details::StableHashSmallStrBytes(&src.data_));
+ } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) {
+ // use byte the same type key as bytes
+ return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
+
details::StableHashSmallStrBytes(&src.data_));
+ } else if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
+ src.data_.type_index == TypeIndex::kTVMFFIBytes) {
+ const details::BytesObjBase* src_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(src);
+ return details::StableHashCombine(src.data_.type_index,
+
details::StableHashBytes(src_str->data, src_str->size));
+ } else {
+ return details::StableHashCombine(src.data_.type_index,
src.data_.v_uint64);
+ }
}
};
@@ -566,19 +579,60 @@ struct AnyEqual {
* \return String equality if both are strings, pointer address equality
otherwise.
*/
bool operator()(const Any& lhs, const Any& rhs) const {
- if (lhs.data_.type_index != rhs.data_.type_index) return false;
- // byte equivalence
- if (lhs.data_.v_int64 == rhs.data_.v_int64) return true;
- // specialy handle string hash
- if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
- lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
- const details::BytesObjBase* lhs_str =
- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
- const details::BytesObjBase* rhs_str =
- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
- return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size);
+ // header with type index
+ const int64_t* lhs_as_int64 = reinterpret_cast<const int64_t*>(&lhs.data_);
+ const int64_t* rhs_as_int64 = reinterpret_cast<const int64_t*>(&rhs.data_);
+ static_assert(sizeof(TVMFFIAny) == 16);
+ // fast path, check byte equality
+ if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] ==
rhs_as_int64[1]) {
+ return true;
+ }
+ // common false case type index match, in this case we only need to pay
attention to string
+ // equality
+ if (lhs.data_.type_index == rhs.data_.type_index) {
+ // specialy handle string hash
+ if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
+ lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
+ const details::BytesObjBase* lhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ const details::BytesObjBase* rhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size);
+ }
+ return false;
+ } else {
+ // type_index mismatch, if index is not string, return false
+ if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index !=
kTVMFFISmallStr &&
+ lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index !=
kTVMFFIBytes) {
+ return false;
+ }
+ // small string and normal string comparison
+ if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index ==
kTVMFFISmallStr) {
+ const details::BytesObjBase* lhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size,
+ rhs.data_.small_str_len);
+ }
+ if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index ==
kTVMFFIStr) {
+ const details::BytesObjBase* rhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data,
lhs.data_.small_str_len,
+ rhs_str->size);
+ }
+ if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index ==
kTVMFFISmallBytes) {
+ const details::BytesObjBase* lhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes,
lhs_bytes->size,
+ rhs.data_.small_str_len);
+ }
+ if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index ==
kTVMFFIBytes) {
+ const details::BytesObjBase* rhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data,
lhs.data_.small_str_len,
+ rhs_bytes->size);
+ }
+ return false;
}
- return false;
}
};
diff --git a/ffi/include/tvm/ffi/base_details.h
b/ffi/include/tvm/ffi/base_details.h
index cfdadff6ea..7c96b091d7 100644
--- a/ffi/include/tvm/ffi/base_details.h
+++ b/ffi/include/tvm/ffi/base_details.h
@@ -170,7 +170,8 @@ TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key,
const T& value) {
* \param size The size of the bytes.
* \return the hash value.
*/
-TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
+TVM_FFI_INLINE uint64_t StableHashBytes(const void* data_ptr, size_t size) {
+ const char* data = reinterpret_cast<const char*>(data_ptr);
const constexpr uint64_t kMultiplier = 1099511628211ULL;
const constexpr uint64_t kMod = 2147483647ULL;
union Union {
@@ -250,6 +251,20 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data,
size_t size) {
return result;
}
+/*!
+ * \brief Same as StableHashBytes, but for small string data.
+ * \param data The data pointer
+ * \return the hash value.
+ */
+TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny* data) {
+ if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
+ // fast path, no endian swap, simply hash as uint64_t
+ const constexpr uint64_t kMod = 2147483647ULL;
+ return data->v_uint64 % kMod;
+ }
+ return StableHashBytes(reinterpret_cast<const void*>(data),
sizeof(data->v_uint64));
+}
+
} // namespace details
} // namespace ffi
} // namespace tvm
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index d99832af01..11080a21f0 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -65,13 +65,7 @@ enum TVMFFITypeIndex : int32_t {
#else
typedef enum {
#endif
- // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
- // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
- // which is not owned by TVMFFIAny. It is required that the following
- // invariant holds:
- // - `Any::type_index` is never `kTVMFFIRawStr`
- // - `AnyView::type_index` can be `kTVMFFIRawStr`
- //
+
/*
* \brief The root type of all FFI objects.
*
@@ -80,6 +74,13 @@ typedef enum {
* However, it may appear in field annotations during reflection.
*/
kTVMFFIAny = -1,
+ // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
+ // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
+ // which is not owned by TVMFFIAny. It is required that the following
+ // invariant holds:
+ // - `Any::type_index` is never `kTVMFFIRawStr`
+ // - `AnyView::type_index` can be `kTVMFFIRawStr`
+ //
/*! \brief None/nullptr value */
kTVMFFINone = 0,
/*! \brief POD int value */
@@ -96,12 +97,16 @@ typedef enum {
kTVMFFIDevice = 6,
/*! \brief DLTensor* */
kTVMFFIDLTensorPtr = 7,
- /*! \brief const char**/
+ /*! \brief const char* */
kTVMFFIRawStr = 8,
/*! \brief TVMFFIByteArray* */
kTVMFFIByteArrayPtr = 9,
/*! \brief R-value reference to ObjectRef */
kTVMFFIObjectRValueRef = 10,
+ /*! \brief Small string on stack */
+ kTVMFFISmallStr = 11,
+ /*! \brief Small bytes on stack */
+ kTVMFFISmallBytes = 12,
/*! \brief Start of statically defined objects. */
kTVMFFIStaticObjectBegin = 64,
/*!
@@ -183,11 +188,17 @@ typedef struct TVMFFIAny {
* \note The type index of Object and Any are shared in FFI.
*/
int32_t type_index;
- /*!
- * \brief length for on-stack Any object, such as small-string
- * \note This field is reserved for future compact.
- */
- int32_t small_len;
+ union { // 4 bytes
+ /*! \brief padding, must set to zero for values other than small string. */
+ uint32_t zero_padding;
+ /*!
+ * \brief Length of small string, with a max value of 7.
+ *
+ * We keep small str to start at next 4 bytes to ensure alignment
+ * when accessing the small str content.
+ */
+ uint32_t small_str_len;
+ };
union { // 8 bytes
int64_t v_int64; // integers
double v_float64; // floating-point numbers
@@ -823,7 +834,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const
TVMFFIByteArray* str, DLDataType*
* \note The input dtype is a pointer to the DLDataType to avoid ABI
compatibility issues.
*/
-TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype,
TVMFFIObjectHandle* out);
+TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny*
out);
//------------------------------------------------------------
// Section: Backend noexcept functions for internal use
@@ -903,6 +914,15 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle
obj) {
return static_cast<TVMFFIObject*>(obj)->type_index;
}
+/*!
+ * \brief Get the content of a small string in bytearray format.
+ * \param obj The object handle.
+ * \return The content of the small string in bytearray format.
+ */
+inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny*
value) {
+ return TVMFFIByteArray{value->v_bytes,
static_cast<size_t>(value->small_str_len)};
+}
+
/*!
* \brief Get the data pointer of a bytearray from a string or bytes object.
* \param obj The object handle.
diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h
index 9cac1f99a8..997c0bb178 100644
--- a/ffi/include/tvm/ffi/cast.h
+++ b/ffi/include/tvm/ffi/cast.h
@@ -27,6 +27,7 @@
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/object.h>
+#include <tvm/ffi/optional.h>
#include <utility>
diff --git a/ffi/include/tvm/ffi/container/variant.h
b/ffi/include/tvm/ffi/container/variant.h
index a16ff5d425..ee1f8316d8 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -80,10 +80,12 @@ class VariantBase<true> : public ObjectRef {
TVMFFIAny any_data;
if (data_ == nullptr) {
any_data.type_index = TypeIndex::kTVMFFINone;
+ any_data.zero_padding = 0;
any_data.v_int64 = 0;
} else {
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.type_index = data_->type_index();
+ any_data.zero_padding = 0;
any_data.v_obj =
details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
}
return AnyView::CopyFromTVMFFIAny(any_data);
diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index 2eafccd2db..c153d71cb7 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -115,14 +115,15 @@ inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode
type_code) { // NOLINT(*
inline DLDataType StringToDLDataType(const String& str) {
DLDataType out;
- TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(str.get(), &out));
+ TVMFFIByteArray data{str.data(), str.size()};
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out));
return out;
}
inline String DLDataTypeToString(DLDataType dtype) {
- TVMFFIObjectHandle out;
+ TVMFFIAny out;
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
- return
String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out)));
+ return TypeTraits<String>::MoveFromAnyAfterCheck(&out);
}
// DLDataType
@@ -134,6 +135,7 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
// clear padding part to ensure the equality check can always check the
v_uint64 part
result->v_uint64 = 0;
result->type_index = TypeIndex::kTVMFFIDataType;
+ result->zero_padding = 0;
result->v_dtype = src;
}
@@ -141,6 +143,7 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
// clear padding part to ensure the equality check can always check the
v_uint64 part
result->v_uint64 = 0;
result->type_index = TypeIndex::kTVMFFIDataType;
+ result->zero_padding = 0;
result->v_dtype = src;
}
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index a49a9f1700..4b7b56209a 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -60,6 +60,8 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIFunction = "ffi.Function";
static constexpr const char* kTVMFFIArray = "ffi.Array";
static constexpr const char* kTVMFFIMap = "ffi.Map";
+ static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr";
+ static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes";
};
/*!
diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h
index 003038b9fd..a52f64e483 100644
--- a/ffi/include/tvm/ffi/optional.h
+++ b/ffi/include/tvm/ffi/optional.h
@@ -27,6 +27,7 @@
#include <tvm/ffi/error.h>
#include <tvm/ffi/object.h>
+#include <tvm/ffi/string.h>
#include <optional>
#include <string>
@@ -53,7 +54,8 @@ inline constexpr bool use_ptr_based_optional_v =
// Specialization for non-ObjectRef types.
// simply fallback to std::optional
template <typename T>
-class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T>>> {
+class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T> &&
!std::is_same_v<T, String> &&
+ !std::is_same_v<T, Bytes>>> {
public:
// default constructors.
Optional() = default;
@@ -138,6 +140,118 @@ class Optional<T,
std::enable_if_t<!use_ptr_based_optional_v<T>>> {
std::optional<T> data_;
};
+// Specialization for String type, use nullptr to indicate nullopt
+template <typename T>
+class Optional<T, std::enable_if_t<std::is_same_v<T, String> ||
std::is_same_v<T, Bytes>>> {
+ public:
+ // default constructors.
+ Optional() = default;
+ Optional(const Optional<T>& other) : data_(other.data_) {}
+ Optional(Optional<T>&& other) : data_(std::move(other.data_)) {}
+ Optional(std::nullopt_t) {} // NOLINT(*)
+ // normal value handling.
+ Optional(T other) // NOLINT(*)
+ : data_(std::move(other)) {}
+
+ TVM_FFI_INLINE Optional<T>& operator=(const Optional<T>& other) {
+ data_ = other.data_;
+ return *this;
+ }
+
+ TVM_FFI_INLINE Optional<T>& operator=(Optional<T>&& other) {
+ data_ = std::move(other.data_);
+ return *this;
+ }
+
+ TVM_FFI_INLINE Optional<T>& operator=(T other) {
+ data_ = std::move(other);
+ return *this;
+ }
+
+ TVM_FFI_INLINE Optional<T>& operator=(std::nullopt_t) {
+ T(details::BytesBaseCell(std::nullopt)).swap(data_);
+ return *this;
+ }
+
+ TVM_FFI_INLINE const T& value() const& {
+ if (data_.data_ == std::nullopt) {
+ TVM_FFI_THROW(RuntimeError) << "Back optional access";
+ }
+ return data_;
+ }
+
+ TVM_FFI_INLINE String&& value() && {
+ if (data_.data_ == std::nullopt) {
+ TVM_FFI_THROW(RuntimeError) << "Back optional access";
+ }
+ return std::move(data_);
+ }
+
+ template <typename U = T>
+ TVM_FFI_INLINE T value_or(U&& default_value) const {
+ if (data_.data_ == std::nullopt) {
+ return std::forward<U>(default_value);
+ }
+ return data_;
+ }
+
+ TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.data_
!= std::nullopt; }
+
+ TVM_FFI_INLINE bool has_value() const noexcept { return data_.data_ !=
std::nullopt; }
+
+ TVM_FFI_INLINE bool operator==(const Optional<T>& other) const {
+ if (data_.data_ == std::nullopt) {
+ return other.data_.data_ == std::nullopt;
+ }
+ if (other.data_.data_ == std::nullopt) {
+ return false;
+ }
+ return data_ == other.data_;
+ }
+
+ TVM_FFI_INLINE bool operator!=(const Optional<T>& other) const { return
!(*this == other); }
+
+ template <typename U>
+ TVM_FFI_INLINE bool operator==(const U& other) const {
+ if constexpr (std::is_same_v<U, std::nullopt_t>) {
+ return data_.data_ == std::nullopt;
+ } else {
+ if (data_.data_ == std::nullopt) {
+ return false;
+ }
+ return data_ == other;
+ }
+ }
+ template <typename U>
+ TVM_FFI_INLINE bool operator!=(const U& other) const {
+ if constexpr (std::is_same_v<U, std::nullopt_t>) {
+ return data_.data_ != std::nullopt;
+ } else {
+ if (data_.data_ == std::nullopt) {
+ return true;
+ }
+ return data_ != other;
+ }
+ }
+
+ /*!
+ * \brief Direct access to the value.
+ * \return the xvalue reference to the stored value.
+ * \note only use this function after checking has_value()
+ */
+ TVM_FFI_INLINE T&& operator*() && noexcept { return std::move(data_); }
+ /*!
+ * \brief Direct access to the value.
+ * \return the const reference to the stored value.
+ * \note only use this function after checking has_value()
+ */
+ TVM_FFI_INLINE const T& operator*() const& noexcept { return data_; }
+
+ private:
+ // this is a private initializer
+ T data_{details::BytesBaseCell(std::nullopt)};
+};
+
// Specialization for ObjectRef types.
// nullptr is treated as std::nullopt.
template <typename T>
diff --git a/ffi/include/tvm/ffi/reflection/accessor.h
b/ffi/include/tvm/ffi/reflection/accessor.h
index 40adfa3499..5215444052 100644
--- a/ffi/include/tvm/ffi/reflection/accessor.h
+++ b/ffi/include/tvm/ffi/reflection/accessor.h
@@ -48,7 +48,7 @@ inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view
type_key, const char
return &(info->fields[i]);
}
}
- TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in "
<< type_key;
+ TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in
" << type_key;
TVM_FFI_UNREACHABLE();
}
diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h
index b185e8d941..7c89038cc2 100644
--- a/ffi/include/tvm/ffi/rvalue_ref.h
+++ b/ffi/include/tvm/ffi/rvalue_ref.h
@@ -94,6 +94,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public TypeTraitsBase
{
TVM_FFI_INLINE static void CopyToAnyView(const RValueRef<TObjRef>& src,
TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFIObjectRValueRef;
+ result->zero_padding = 0;
// store the address of the ObjectPtr, which allows us to move the value
// and set the original ObjectPtr to nullptr
result->v_ptr = &(src.data_);
@@ -106,7 +107,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
// in this case we do not move the original rvalue ref since conversion
creates a copy
TVMFFIAny tmp_any;
tmp_any.type_index = rvalue_ref->get()->type_index();
-
+ tmp_any.zero_padding = 0;
tmp_any.v_obj = reinterpret_cast<TVMFFIObject*>(rvalue_ref->get());
return "RValueRef<" + TypeTraits<TObjRef>::GetMismatchTypeInfo(&tmp_any)
+ ">";
} else {
@@ -120,6 +121,7 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
ObjectPtr<Object>* rvalue_ref =
reinterpret_cast<ObjectPtr<Object>*>(src->v_ptr);
TVMFFIAny tmp_any;
tmp_any.type_index = rvalue_ref->get()->type_index();
+ tmp_any.zero_padding = 0;
tmp_any.v_obj = reinterpret_cast<TVMFFIObject*>(rvalue_ref->get());
// fast path, storage type matches, direct move the rvalue ref
if (TypeTraits<TObjRef>::CheckAnyStrict(&tmp_any)) {
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index 481b704436..fe84b61547 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -47,7 +47,9 @@
namespace tvm {
namespace ffi {
namespace details {
-/*! \brief Base class for bytes and string. */
+/*!
+ * \brief Base class for bytes and string objects.
+ */
class BytesObjBase : public Object, public TVMFFIByteArray {};
/*!
@@ -87,47 +89,201 @@ class BytesObjStdImpl : public Base {
std::string data_;
};
-// inplace string allocation
-template <typename Base>
-TVM_FFI_INLINE ObjectPtr<Base> MakeInplaceBytes(const char* data, size_t
length) {
- ObjectPtr<Base> p = make_inplace_array_object<Base, char>(length + 1);
- static_assert(alignof(Base) % alignof(char) == 0);
- static_assert(sizeof(Base) % alignof(char) == 0);
- char* dest_data = reinterpret_cast<char*>(p.get()) + sizeof(Base);
- p->data = dest_data;
- p->size = length;
- std::memcpy(dest_data, data, length);
- dest_data[length] = '\0';
- return p;
-}
+/*!
+ * \brief Helper cell class that can be used to back small string
+ * \note Do not use directly, use String or Bytes instead
+ */
+class BytesBaseCell {
+ public:
+ BytesBaseCell() {
+ // initialize to none
+ data_.type_index = TypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
+ data_.v_int64 = 0;
+ }
+
+ explicit BytesBaseCell(std::nullopt_t) {
+ data_.type_index = TypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
+ data_.v_int64 = 0;
+ }
+
+ BytesBaseCell(const BytesBaseCell& other) : data_(other.data_) { //
NOLINT(*)
+ if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj);
+ }
+ }
+
+ BytesBaseCell(BytesBaseCell&& other) : data_(other.data_) { // NOLINT(*)
+ other.data_.type_index = TypeIndex::kTVMFFINone;
+ }
+
+ BytesBaseCell& operator=(const BytesBaseCell& other) {
+ BytesBaseCell(other).swap(*this); // NOLINT(*)
+ return *this;
+ }
+
+ BytesBaseCell& operator=(BytesBaseCell&& other) {
+ BytesBaseCell(std::move(other)).swap(*this); // NOLINT(*)
+ return *this;
+ }
+
+ ~BytesBaseCell() {
+ if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj);
+ }
+ }
+
+ /*!
+ * \brief Check if the cell is null
+ * \return true if the cell is null, false otherwise
+ */
+ bool operator==(std::nullopt_t) const { return data_.type_index ==
TypeIndex::kTVMFFINone; }
+
+ /*!
+ * \brief Check if the cell is not null
+ * \return true if the cell is not null, false otherwise
+ */
+ bool operator!=(std::nullopt_t) const { return data_.type_index !=
TypeIndex::kTVMFFINone; }
+
+ /*!
+ * \brief Swap this String with another string
+ * \param other The other string
+ */
+ void swap(BytesBaseCell& other) { // NOLINT(*)
+ std::swap(data_, other.data_);
+ }
+
+ const char* data() const noexcept {
+ if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ return data_.v_bytes;
+ } else {
+ return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->data;
+ }
+ }
+
+ size_t size() const noexcept {
+ if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ return data_.small_str_len;
+ } else {
+ return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->size;
+ }
+ }
+
+ template <typename LargeObj>
+ void InitFromStd(std::string&& other, int32_t large_type_index) {
+ // needs to be reset to none first for exception safety
+ data_.type_index = TypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_);
+ ObjectPtr<LargeObj> ptr =
make_object<BytesObjStdImpl<LargeObj>>(std::move(other));
+ data_.v_obj =
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr));
+ data_.type_index = large_type_index;
+ }
+
+ /*!
+ * \brief Create a new empty space for a string
+ * \param size The size of the string
+ * \param small_type_index The type index for the small string
+ * \param large_type_index The type index for the large string
+ * \note always reserve one byte for \0 compactibility
+ * \return A pointer to the empty space
+ */
+ template <typename LargeObj>
+ char* InitSpaceForSize(size_t size, int32_t small_type_index, int32_t
large_type_index) {
+ size_t kMaxSmallBytesLen = sizeof(int64_t) - 1;
+ // first zero the content, this is important for exception safety
+ data_.type_index = small_type_index;
+ data_.zero_padding = 0;
+ if (size <= kMaxSmallBytesLen) {
+ // set up the size accordingly
+ data_.small_str_len = static_cast<uint32_t>(size);
+ return data_.v_bytes;
+ } else {
+ // allocate from heap
+ ObjectPtr<LargeObj> ptr = make_inplace_array_object<LargeObj, char>(size
+ 1);
+ char* dest_data = reinterpret_cast<char*>(ptr.get()) + sizeof(LargeObj);
+ ptr->data = dest_data;
+ ptr->size = size;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_);
+ data_.v_obj =
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr));
+ // now reset the type index to str
+ data_.type_index = large_type_index;
+ return dest_data;
+ }
+ }
+
+ void InitTypeIndex(int32_t type_index) { data_.type_index = type_index; }
+
+ void MoveToAny(TVMFFIAny* result) {
+ *result = data_;
+ data_.type_index = TypeIndex::kTVMFFINone;
+ data_.zero_padding = 0;
+ data_.v_int64 = 0;
+ }
+
+ TVMFFIAny CopyToTVMFFIAny() const { return data_; }
+
+ static BytesBaseCell CopyFromAnyView(const TVMFFIAny* src) {
+ BytesBaseCell result(*src);
+ if (result.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ details::ObjectUnsafe::IncRefObjectHandle(result.data_.v_obj);
+ }
+ return result;
+ }
+
+ static BytesBaseCell MoveFromAny(TVMFFIAny* src) {
+ BytesBaseCell result(*src);
+ src->type_index = TypeIndex::kTVMFFINone;
+ src->zero_padding = 0;
+ src->v_int64 = 0;
+ return result;
+ }
+
+ private:
+ explicit BytesBaseCell(TVMFFIAny data) : data_(data) {}
+ /*! \brief internal backing data */
+ TVMFFIAny data_;
+};
} // namespace details
/*!
* \brief Managed reference of byte array.
*/
-class Bytes : public ObjectRef {
+class Bytes {
public:
+ /*! \brief default constructor */
+ Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); }
/*!
- * \brief constructor from char [N]
+ * \brief constructor from size
+ *
+ * \param other a char array.
+ */
+ Bytes(const char* data, size_t size) { this->InitData(data, size); }
+ /*!
+ * \brief constructor from TVMFFIByteArray
*
* \param other a char array.
*/
- Bytes(const char* data, size_t size) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {}
+ Bytes(TVMFFIByteArray bytes) { // NOLINT(*)
+ this->InitData(bytes.data, bytes.size);
+ }
/*!
- * \brief constructor from char [N]
+ * \brief constructor from std::string
*
* \param other a char array.
*/
- Bytes(TVMFFIByteArray bytes) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data,
bytes.size)) {}
+ Bytes(const std::string& other) { // NOLINT(*)
+ this->InitData(other.data(), other.size());
+ }
/*!
- * \brief constructor from char [N]
+ * \brief constructor from std::string
*
* \param other a char array.
*/
- Bytes(std::string other) // NOLINT(*)
- :
ObjectRef(make_object<details::BytesObjStdImpl<details::BytesObj>>(std::move(other)))
{}
+ Bytes(std::string&& other) { // NOLINT(*)
+ data_.InitFromStd<details::BytesObj>(std::move(other),
TypeIndex::kTVMFFIBytes);
+ }
/*!
* \brief Swap this String with another string
* \param other The other string
@@ -147,21 +303,19 @@ class Bytes : public ObjectRef {
*
* \return size_t string length
*/
- size_t size() const { return get()->size; }
+ size_t size() const { return data_.size(); }
/*!
* \brief Return the data pointer
*
* \return const char* data pointer
*/
- const char* data() const { return get()->data; }
+ const char* data() const { return data_.data(); }
/*!
* \brief Convert String to an std::string object
*
* \return std::string
*/
- operator std::string() const { return std::string{get()->data, size()}; }
-
- TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef,
details::BytesObj);
+ operator std::string() const { return std::string{data(), size()}; }
/*!
* \brief Compare two char sequence
@@ -198,110 +352,134 @@ class Bytes : public ObjectRef {
*
* \return true if the two char sequences are equal, false otherwise.
*/
- static bool memequal(const char* lhs, const char* rhs, size_t lhs_count,
size_t rhs_count) {
+ static bool memequal(const void* lhs, const void* rhs, size_t lhs_count,
size_t rhs_count) {
return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs,
lhs_count) == 0);
}
private:
- friend class String;
+ template <typename, typename>
+ friend struct TypeTraits;
+ template <typename, typename>
+ friend class Optional;
+ // internal backing cell
+ details::BytesBaseCell data_;
+ // create a new String from TVMFFIAny, must keep private
+ explicit Bytes(details::BytesBaseCell data) : data_(data) {}
+ char* InitSpaceForSize(size_t size) {
+ return data_.InitSpaceForSize<details::BytesObj>(size,
TypeIndex::kTVMFFISmallBytes,
+ TypeIndex::kTVMFFIBytes);
+ }
+ void InitData(const char* data, size_t size) {
+ char* dest_data = InitSpaceForSize(size);
+ std::memcpy(dest_data, data, size);
+ // mainly to be compat with string
+ dest_data[size] = '\0';
+ }
};
/*!
- * \brief Reference to string objects.
- *
- * \code
- *
- * // Example to create runtime String reference object from std::string
- * std::string s = "hello world";
- *
- * // You can create the reference from existing std::string
- * String ref{std::move(s)};
- *
- * // You can rebind the reference to another string.
- * ref = std::string{"hello world2"};
- *
- * // You can use the reference as hash map key
- * std::unordered_map<String, int32_t> m;
- * m[ref] = 1;
- *
- * // You can compare the reference object with other string objects
- * assert(ref == "hello world", true);
- *
- * // You can convert the reference to std::string again
- * string s2 = (string)ref;
- *
- * \endcode
+ * \brief String container class.
*/
-class String : public ObjectRef {
+class String {
public:
+ /*!
+ * \brief avoid misuse of nullptr
+ */
String(std::nullptr_t) = delete; // NOLINT(*)
-
/*!
- * \brief constructor from char [N]
- *
- * \param other a char array.
+ * \brief constructor
*/
- template <size_t N>
- String(const char other[N]) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, N)) {}
+ String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); }
+ // constructors from Any
+ String(const String& other) = default; // NOLINT(*)
+ String(String&& other) = default; // NOLINT(*)
+ String& operator=(const String& other) = default; // NOLINT(*)
+ String& operator=(String&& other) = default; // NOLINT(*)
/*!
- * \brief constructor
+ * \brief Swap this String with another string
+ * \param other The other string
*/
- String() : String("") {}
+ void swap(String& other) noexcept { // NOLINT(*)
+ std::swap(data_, other.data_);
+ }
+
+ String& operator=(const std::string& other) {
+ String(other).swap(*this); // NOLINT(*)
+ return *this;
+ }
+ String& operator=(std::string&& other) {
+ String(std::move(other)).swap(*this); // NOLINT(*)
+ return *this;
+ }
+
+ String& operator=(const char* other) {
+ String(other).swap(*this); // NOLINT(*)
+ return *this;
+ }
/*!
* \brief constructor from raw string
*
* \param other a char array.
*/
- String(const char* other) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other,
std::strlen(other))) {}
+ String(const char* other, size_t size) { this->InitData(other, size); }
/*!
* \brief constructor from raw string
*
* \param other a char array.
+ * \note This constructor is marked as explicit to avoid implicit conversion
+ * of nullptr value here to string, which then was used in comparison
*/
- String(const char* other, size_t size) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, size))
{}
-
+ String(const char* other) { // NOLINT(*)
+ this->InitData(other, std::char_traits<char>::length(other));
+ }
/*!
* \brief Construct a new string object
* \param other The std::string object to be copied
*/
- String(const std::string& other) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data(),
other.size())) {}
+ String(const std::string& other) { // NOLINT(*)
+ this->InitData(other.data(), other.size());
+ }
/*!
* \brief Construct a new string object
* \param other The std::string object to be moved
*/
- String(std::string&& other) // NOLINT(*)
- :
ObjectRef(make_object<details::BytesObjStdImpl<details::StringObj>>(std::move(other)))
{}
+ String(std::string&& other) { // NOLINT(*)
+ // exception safety, first set to none so if exception is thrown
+ // destructor works correctly
+ data_.InitFromStd<details::StringObj>(std::move(other),
TypeIndex::kTVMFFIStr);
+ }
/*!
* \brief constructor from TVMFFIByteArray
*
* \param other a TVMFFIByteArray.
*/
- explicit String(TVMFFIByteArray other)
- : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data,
other.size)) {}
+ explicit String(TVMFFIByteArray other) { this->InitData(other.data,
other.size); }
/*!
- * \brief Swap this String with another string
- * \param other The other string
+ * \brief Return the data pointer
+ *
+ * \return const char* data pointer
*/
- void swap(String& other) { // NOLINT(*)
- std::swap(data_, other.data_);
- }
+ const char* data() const noexcept { return data_.data(); }
- template <typename T>
- String& operator=(T&& other) {
- // copy-and-swap idiom
- String(std::forward<T>(other)).swap(*this); // NOLINT(*)
- return *this;
- }
+ /*!
+ * \brief Returns a pointer to the char array in the string.
+ *
+ * \return const char*
+ */
+ const char* c_str() const noexcept { return data(); }
+
+ /*!
+ * \brief Return the length of the string
+ *
+ * \return size_t string length
+ */
+ size_t size() const noexcept { return data_.size(); }
/*!
* \brief Compares this String object to other
@@ -362,23 +540,6 @@ class String : public ObjectRef {
return Bytes::memncmp(data(), other.data, size(), other.size);
}
- /*!
- * \brief Returns a pointer to the char array in the string.
- *
- * \return const char*
- */
- const char* c_str() const { return get()->data; }
-
- /*!
- * \brief Return the length of the string
- *
- * \return size_t string length
- */
- size_t size() const {
- const auto* ptr = get();
- return ptr->size;
- }
-
/*!
* \brief Return the length of the string
*
@@ -407,23 +568,36 @@ class String : public ObjectRef {
}
}
- /*!
- * \brief Return the data pointer
- *
- * \return const char* data pointer
- */
- const char* data() const { return get()->data; }
-
/*!
* \brief Convert String to an std::string object
*
* \return std::string
*/
- operator std::string() const { return std::string{get()->data, size()}; }
-
- TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef,
details::StringObj);
+ operator std::string() const { return std::string{data(), size()}; }
private:
+ template <typename, typename>
+ friend struct TypeTraits;
+ template <typename, typename>
+ friend class Optional;
+ // internal backing cell
+ details::BytesBaseCell data_;
+ // create a new String from TVMFFIAny, must keep private
+ explicit String(details::BytesBaseCell data) : data_(data) {}
+ /*!
+ * \brief Create a new empty space for a string
+ * \param size The size of the string
+ * \return A pointer to the empty space
+ */
+ char* InitSpaceForSize(size_t size) {
+ return data_.InitSpaceForSize<details::StringObj>(size,
TypeIndex::kTVMFFISmallStr,
+ TypeIndex::kTVMFFIStr);
+ }
+ void InitData(const char* data, size_t size) {
+ char* dest_data = InitSpaceForSize(size);
+ std::memcpy(dest_data, data, size);
+ dest_data[size] = '\0';
+ }
/*!
* \brief Concatenate two char sequences
*
@@ -435,11 +609,25 @@ class String : public ObjectRef {
* \return The concatenated char sequence
*/
static String Concat(const char* lhs, size_t lhs_size, const char* rhs,
size_t rhs_size) {
- std::string ret(lhs, lhs_size);
- ret.append(rhs, rhs_size);
- return String(ret);
+ String ret;
+ // disable stringop-overflow and restrict warnings
+ // gcc may produce false positive when we enable dest_data returned from
small string path
+ // Because compiler is not able to detect the condition that the path is
only triggered via
+ // size < kMaxSmallStrLen and can report it as a overflow case.
+#if (__GNUC__) && !(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstringop-overflow"
+#pragma GCC diagnostic ignored "-Wrestrict"
+#endif
+ char* dest_data = ret.InitSpaceForSize(lhs_size + rhs_size);
+ std::memcpy(dest_data, lhs, lhs_size);
+ std::memcpy(dest_data + lhs_size, rhs, rhs_size);
+ dest_data[lhs_size + rhs_size] = '\0';
+#if (__GNUC__) && !(__clang__)
+#pragma GCC diagnostic pop
+#endif
+ return ret;
}
-
// Overload + operator
friend String operator+(const String& lhs, const String& rhs);
friend String operator+(const String& lhs, const std::string& rhs);
@@ -453,6 +641,93 @@ TVM_FFI_INLINE std::string_view
ToStringView(TVMFFIByteArray str) {
return std::string_view(str.data, str.size);
}
+template <>
+inline constexpr bool use_default_type_traits_v<Bytes> = false;
+
+// specialize to enable implicit conversion from TVMFFIByteArray*
+template <>
+struct TypeTraits<Bytes> : public TypeTraitsBase {
+ // bytes can be union type of small bytes and object, so keep it as any
+ static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny;
+
+ TVM_FFI_INLINE static void CopyToAnyView(const Bytes& src, TVMFFIAny*
result) {
+ *result = src.data_.CopyToTVMFFIAny();
+ }
+
+ TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny* result) {
+ src.data_.MoveToAny(result);
+ }
+
+ TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+ return src->type_index == TypeIndex::kTVMFFISmallBytes ||
+ src->type_index == TypeIndex::kTVMFFIBytes;
+ }
+
+ TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+ return Bytes(details::BytesBaseCell::CopyFromAnyView(src));
+ }
+
+ TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny* src) {
+ return Bytes(details::BytesBaseCell::MoveFromAny(src));
+ }
+
+ TVM_FFI_INLINE static std::optional<Bytes> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) {
+ return Bytes(*static_cast<TVMFFIByteArray*>(src->v_ptr));
+ }
+ if (src->type_index == TypeIndex::kTVMFFISmallBytes ||
+ src->type_index == TypeIndex::kTVMFFIBytes) {
+ return Bytes(details::BytesBaseCell::CopyFromAnyView(src));
+ }
+ return std::nullopt;
+ }
+
+ TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; }
+};
+
+template <>
+inline constexpr bool use_default_type_traits_v<String> = false;
+
+// specialize to enable implicit conversion from const char*
+template <>
+struct TypeTraits<String> : public TypeTraitsBase {
+ // string can be union type of small string and object, so keep it as any
+ static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny;
+
+ TVM_FFI_INLINE static void CopyToAnyView(const String& src, TVMFFIAny*
result) {
+ *result = src.data_.CopyToTVMFFIAny();
+ }
+
+ TVM_FFI_INLINE static void MoveToAny(String src, TVMFFIAny* result) {
+ src.data_.MoveToAny(result);
+ }
+
+ TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+ return src->type_index == TypeIndex::kTVMFFISmallStr ||
+ src->type_index == TypeIndex::kTVMFFIStr;
+ }
+
+ TVM_FFI_INLINE static String CopyFromAnyViewAfterCheck(const TVMFFIAny* src)
{
+ return String(details::BytesBaseCell::CopyFromAnyView(src));
+ }
+
+ TVM_FFI_INLINE static String MoveFromAnyAfterCheck(TVMFFIAny* src) {
+ return String(details::BytesBaseCell::MoveFromAny(src));
+ }
+
+ TVM_FFI_INLINE static std::optional<String> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFIRawStr) {
+ return String(src->v_c_str);
+ }
+ if (src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index ==
TypeIndex::kTVMFFIStr) {
+ return String(details::BytesBaseCell::CopyFromAnyView(src));
+ }
+ return std::nullopt;
+ }
+
+ TVM_FFI_INLINE static std::string TypeStr() { return "str"; }
+};
+
// const char*, requirement: not nullable, do not retain ownership
template <int N>
struct TypeTraits<char[N]> : public TypeTraitsBase {
@@ -461,12 +736,13 @@ struct TypeTraits<char[N]> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const char src[N], TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFIRawStr;
+ result->zero_padding = 0;
result->v_c_str = src;
}
TVM_FFI_INLINE static void MoveToAny(const char src[N], TVMFFIAny* result) {
// when we need to move to any, convert to owned object first
- ObjectRefTypeTraitsBase<String>::MoveToAny(String(src), result);
+ TypeTraits<String>::MoveToAny(String(src), result);
}
};
@@ -477,12 +753,13 @@ struct TypeTraits<const char*> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const char* src, TVMFFIAny* result)
{
TVM_FFI_ICHECK_NOTNULL(src);
result->type_index = TypeIndex::kTVMFFIRawStr;
+ result->zero_padding = 0;
result->v_c_str = src;
}
TVM_FFI_INLINE static void MoveToAny(const char* src, TVMFFIAny* result) {
// when we need to move to any, convert to owned object first
- ObjectRefTypeTraitsBase<String>::MoveToAny(String(src), result);
+ TypeTraits<String>::MoveToAny(String(src), result);
}
// Do not allow const char* in a container, so we do not need CheckAnyStrict
TVM_FFI_INLINE static std::optional<const char*> TryCastFromAnyView(const
TVMFFIAny* src) {
@@ -504,12 +781,13 @@ struct TypeTraits<TVMFFIByteArray*> : public
TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(TVMFFIByteArray* src, TVMFFIAny*
result) {
TVM_FFI_ICHECK_NOTNULL(src);
result->type_index = TypeIndex::kTVMFFIByteArrayPtr;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_ptr = src;
}
TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray* src, TVMFFIAny*
result) {
- ObjectRefTypeTraitsBase<Bytes>::MoveToAny(Bytes(*src), result);
+ TypeTraits<Bytes>::MoveToAny(Bytes(*src), result);
}
TVM_FFI_INLINE static std::optional<TVMFFIByteArray*>
TryCastFromAnyView(const TVMFFIAny* src) {
@@ -522,26 +800,6 @@ struct TypeTraits<TVMFFIByteArray*> : public
TypeTraitsBase {
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIByteArrayPtr; }
};
-template <>
-inline constexpr bool use_default_type_traits_v<Bytes> = false;
-
-// specialize to enable implicit conversion from TVMFFIByteArray*
-template <>
-struct TypeTraits<Bytes> : public ObjectRefWithFallbackTraitsBase<Bytes,
TVMFFIByteArray*> {
- static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBytes;
- TVM_FFI_INLINE static Bytes ConvertFallbackValue(TVMFFIByteArray* src) {
return Bytes(*src); }
-};
-
-template <>
-inline constexpr bool use_default_type_traits_v<String> = false;
-
-// specialize to enable implicit conversion from const char*
-template <>
-struct TypeTraits<String> : public ObjectRefWithFallbackTraitsBase<String,
const char*> {
- static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIStr;
- TVM_FFI_INLINE static String ConvertFallbackValue(const char* src) { return
String(src); }
-};
-
template <>
inline constexpr bool use_default_type_traits_v<std::string> = false;
@@ -550,12 +808,13 @@ struct TypeTraits<std::string>
: public FallbackOnlyTraitsBase<std::string, const char*,
TVMFFIByteArray*, Bytes, String> {
TVM_FFI_INLINE static void CopyToAnyView(const std::string& src, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFIRawStr;
+ result->zero_padding = 0;
result->v_c_str = src.c_str();
}
TVM_FFI_INLINE static void MoveToAny(std::string src, TVMFFIAny* result) {
// when we need to move to any, convert to owned object first
- ObjectRefTypeTraitsBase<String>::MoveToAny(String(std::move(src)), result);
+ TypeTraits<String>::MoveToAny(String(std::move(src)), result);
}
TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; }
@@ -608,6 +867,9 @@ inline String operator+(const String& lhs, const char* rhs)
{
}
// Overload < operator
+inline bool operator<(std::nullptr_t, const String& rhs) = delete;
+inline bool operator<(const String& lhs, std::nullptr_t) = delete;
+
inline bool operator<(const String& lhs, const std::string& rhs) { return
lhs.compare(rhs) < 0; }
inline bool operator<(const std::string& lhs, const String& rhs) { return
rhs.compare(lhs) > 0; }
@@ -619,6 +881,9 @@ inline bool operator<(const String& lhs, const char* rhs) {
return lhs.compare(r
inline bool operator<(const char* lhs, const String& rhs) { return
rhs.compare(lhs) > 0; }
// Overload > operator
+inline bool operator>(std::nullptr_t, const String& rhs) = delete;
+inline bool operator>(const String& lhs, std::nullptr_t) = delete;
+
inline bool operator>(const String& lhs, const std::string& rhs) { return
lhs.compare(rhs) > 0; }
inline bool operator>(const std::string& lhs, const String& rhs) { return
rhs.compare(lhs) < 0; }
@@ -630,6 +895,9 @@ inline bool operator>(const String& lhs, const char* rhs) {
return lhs.compare(r
inline bool operator>(const char* lhs, const String& rhs) { return
rhs.compare(lhs) < 0; }
// Overload <= operator
+inline bool operator<=(std::nullptr_t, const String& rhs) = delete;
+inline bool operator<=(const String& lhs, std::nullptr_t) = delete;
+
inline bool operator<=(const String& lhs, const std::string& rhs) { return
lhs.compare(rhs) <= 0; }
inline bool operator<=(const std::string& lhs, const String& rhs) { return
rhs.compare(lhs) >= 0; }
@@ -641,6 +909,9 @@ inline bool operator<=(const String& lhs, const char* rhs)
{ return lhs.compare(
inline bool operator<=(const char* lhs, const String& rhs) { return
rhs.compare(lhs) >= 0; }
// Overload >= operator
+inline bool operator>=(std::nullptr_t, const String& rhs) = delete;
+inline bool operator>=(const String& lhs, std::nullptr_t) = delete;
+
inline bool operator>=(const String& lhs, const std::string& rhs) { return
lhs.compare(rhs) >= 0; }
inline bool operator>=(const std::string& lhs, const String& rhs) { return
rhs.compare(lhs) <= 0; }
@@ -651,7 +922,10 @@ inline bool operator>=(const String& lhs, const char* rhs)
{ return lhs.compare(
inline bool operator>=(const char* lhs, const String& rhs) { return
rhs.compare(lhs) <= 0; }
-// Overload == operator
+// delete Overload == operator for nullptr
+inline bool operator==(const String& lhs, std::nullptr_t) = delete;
+inline bool operator==(std::nullptr_t, const String& rhs) = delete;
+
inline bool operator==(const String& lhs, const std::string& rhs) {
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
}
@@ -669,6 +943,9 @@ inline bool operator==(const String& lhs, const char* rhs)
{ return lhs.compare(
inline bool operator==(const char* lhs, const String& rhs) { return
rhs.compare(lhs) == 0; }
// Overload != operator
+inline bool operator!=(const String& lhs, std::nullptr_t) = delete;
+inline bool operator!=(std::nullptr_t, const String& rhs) = delete;
+
inline bool operator!=(const String& lhs, const std::string& rhs) { return
lhs.compare(rhs) != 0; }
inline bool operator!=(const std::string& lhs, const String& rhs) { return
rhs.compare(lhs) != 0; }
@@ -696,14 +973,14 @@ namespace std {
template <>
struct hash<::tvm::ffi::Bytes> {
std::size_t operator()(const ::tvm::ffi::Bytes& bytes) const {
- return ::tvm::ffi::details::StableHashBytes(bytes.data(), bytes.size());
+ return std::hash<std::string_view>()(std::string_view(bytes.data(),
bytes.size()));
}
};
template <>
struct hash<::tvm::ffi::String> {
std::size_t operator()(const ::tvm::ffi::String& str) const {
- return ::tvm::ffi::details::StableHashBytes(str.data(), str.size());
+ return std::hash<std::string_view>()(std::string_view(str.data(),
str.size()));
}
};
} // namespace std
diff --git a/ffi/include/tvm/ffi/type_traits.h
b/ffi/include/tvm/ffi/type_traits.h
index 2c0dba90e7..b019935a6c 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -27,7 +27,6 @@
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/object.h>
-#include <tvm/ffi/optional.h>
#include <string>
#include <type_traits>
@@ -121,6 +120,7 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t&, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFINone;
+ result->zero_padding = 0;
// invariant: the pointer field also equals nullptr
// this will simplify same_as comparisons and hash
result->v_int64 = 0;
@@ -128,6 +128,7 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
TVM_FFI_INLINE static void MoveToAny(std::nullptr_t, TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFINone;
+ result->zero_padding = 0;
// invariant: the pointer field also equals nullptr
// this will simplify same_as comparisons and hash
result->v_int64 = 0;
@@ -173,6 +174,7 @@ struct TypeTraits<StrictBool> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const StrictBool& src, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFIBool;
+ result->zero_padding = 0;
result->v_int64 = static_cast<bool>(src);
}
@@ -210,6 +212,7 @@ struct TypeTraits<bool> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const bool& src, TVMFFIAny* result)
{
result->type_index = TypeIndex::kTVMFFIBool;
+ result->zero_padding = 0;
result->v_int64 = static_cast<int64_t>(src);
}
@@ -245,6 +248,7 @@ struct TypeTraits<Int,
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
TVM_FFI_INLINE static void CopyToAnyView(const Int& src, TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFIInt;
+ result->zero_padding = 0;
result->v_int64 = static_cast<int64_t>(src);
}
@@ -283,6 +287,7 @@ struct TypeTraits<IntEnum,
std::enable_if_t<std::is_enum_v<IntEnum> &&
TVM_FFI_INLINE static void CopyToAnyView(const IntEnum& src, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFIInt;
+ result->zero_padding = 0;
result->v_int64 = static_cast<int64_t>(src);
}
@@ -322,6 +327,7 @@ struct TypeTraits<Float,
std::enable_if_t<std::is_floating_point_v<Float>>>
TVM_FFI_INLINE static void CopyToAnyView(const Float& src, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFIFloat;
+ result->zero_padding = 0;
result->v_float64 = static_cast<double>(src);
}
@@ -361,6 +367,7 @@ struct TypeTraits<void*> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(void* src, TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFIOpaquePtr;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_ptr = src;
}
@@ -399,11 +406,13 @@ struct TypeTraits<DLDevice> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const DLDevice& src, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFIDevice;
+ result->zero_padding = 0;
result->v_device = src;
}
TVM_FFI_INLINE static void MoveToAny(DLDevice src, TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFIDevice;
+ result->zero_padding = 0;
result->v_device = src;
}
@@ -439,6 +448,7 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(DLTensor* src, TVMFFIAny* result) {
TVM_FFI_ICHECK_NOTNULL(src);
result->type_index = TypeIndex::kTVMFFIDLTensorPtr;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_ptr = src;
}
@@ -488,6 +498,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
}
TVMFFIObject* obj_ptr =
details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src);
result->type_index = obj_ptr->type_index;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_obj = obj_ptr;
}
@@ -501,6 +512,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
}
TVMFFIObject* obj_ptr =
details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src));
result->type_index = obj_ptr->type_index;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_obj = obj_ptr;
}
@@ -636,6 +648,7 @@ struct TypeTraits<TObject*,
std::enable_if_t<std::is_base_of_v<Object, TObject>>
TVM_FFI_INLINE static void CopyToAnyView(TObject* src, TVMFFIAny* result) {
TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
result->type_index = obj_ptr->type_index;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_obj = obj_ptr;
}
@@ -643,6 +656,7 @@ struct TypeTraits<TObject*,
std::enable_if_t<std::is_base_of_v<Object, TObject>>
TVM_FFI_INLINE static void MoveToAny(TObject* src, TVMFFIAny* result) {
TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
result->type_index = obj_ptr->type_index;
+ result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_obj = obj_ptr;
// needs to increase ref because original weak ptr do not own the code
diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc
index cb0bd49597..e119f77330 100644
--- a/ffi/src/ffi/dtype.cc
+++ b/ffi/src/ffi/dtype.cc
@@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str,
DLDataType* out) {
TVM_FFI_SAFE_CALL_END();
}
-int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) {
+int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype));
- *out =
tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str));
+ tvm::ffi::TypeTraits<tvm::ffi::String>::MoveToAny(std::move(out_str), out);
TVM_FFI_SAFE_CALL_END();
}
diff --git a/ffi/src/ffi/extra/structural_equal.cc
b/ffi/src/ffi/extra/structural_equal.cc
index 3d70e525d9..97ebbf4072 100644
--- a/ffi/src/ffi/extra/structural_equal.cc
+++ b/ffi/src/ffi/extra/structural_equal.cc
@@ -47,6 +47,36 @@ class StructEqualHandler {
const TVMFFIAny* lhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(lhs);
const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs);
if (lhs_data->type_index != rhs_data->type_index) {
+ // type_index mismatch, if index is not string, return false
+ if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index !=
kTVMFFISmallStr &&
+ lhs_data->type_index != kTVMFFISmallBytes && lhs_data->type_index !=
kTVMFFIBytes) {
+ return false;
+ }
+ // small string and normal string comparison
+ if (lhs_data->type_index == kTVMFFIStr && rhs_data->type_index ==
kTVMFFISmallStr) {
+ const details::BytesObjBase* lhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ return Bytes::memequal(lhs_str->data, rhs_data->v_bytes, lhs_str->size,
+ rhs_data->small_str_len);
+ }
+ if (lhs_data->type_index == kTVMFFISmallStr && rhs_data->type_index ==
kTVMFFIStr) {
+ const details::BytesObjBase* rhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs_data->v_bytes, rhs_str->data,
lhs_data->small_str_len,
+ rhs_str->size);
+ }
+ if (lhs_data->type_index == kTVMFFIBytes && rhs_data->type_index ==
kTVMFFISmallBytes) {
+ const details::BytesObjBase* lhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ return Bytes::memequal(lhs_bytes->data, rhs_data->v_bytes,
lhs_bytes->size,
+ rhs_data->small_str_len);
+ }
+ if (lhs_data->type_index == kTVMFFISmallBytes && rhs_data->type_index ==
kTVMFFIBytes) {
+ const details::BytesObjBase* rhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs_data->v_bytes, rhs_bytes->data,
lhs_data->small_str_len,
+ rhs_bytes->size);
+ }
return false;
}
@@ -56,7 +86,8 @@ class StructEqualHandler {
return std::isnan(rhs_data->v_float64);
}
// this is POD data, we can just compare the value
- return lhs_data->v_int64 == rhs_data->v_int64;
+ return lhs_data->zero_padding == rhs_data->zero_padding &&
+ lhs_data->v_int64 == rhs_data->v_int64;
}
switch (lhs_data->type_index) {
case TypeIndex::kTVMFFIStr:
@@ -66,7 +97,7 @@ class StructEqualHandler {
AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
const details::BytesObjBase* rhs_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
- return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size) == 0;
+ return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size);
}
case TypeIndex::kTVMFFIArray: {
return
CompareArray(AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(lhs)),
diff --git a/ffi/src/ffi/extra/structural_hash.cc
b/ffi/src/ffi/extra/structural_hash.cc
index 1d90c5a62d..9f245c1d17 100644
--- a/ffi/src/ffi/extra/structural_hash.cc
+++ b/ffi/src/ffi/extra/structural_hash.cc
@@ -56,6 +56,12 @@ class StructuralHashHandler {
temp.v_float64 = std::numeric_limits<double>::quiet_NaN();
return details::StableHashCombine(temp.type_index, temp.v_uint64);
}
+ if (src_data->type_index == TypeIndex::kTVMFFISmallStr) {
+ // for small string, we use the same type key hash as normal string
+ // so heap allocated string and on stack string will have the same hash
+ return details::StableHashCombine(TypeIndex::kTVMFFIStr,
+
details::StableHashSmallStrBytes(src_data));
+ }
// this is POD data, we can just hash the value
return details::StableHashCombine(src_data->type_index,
src_data->v_uint64);
}
@@ -191,6 +197,13 @@ class StructuralHashHandler {
const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src);
if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (src_data->type_index == TypeIndex::kTVMFFISmallStr) {
+ // for small string, we use the same type key hash as normal string
+ // so heap allocated string and on stack string will have the same hash
+ return details::StableHashCombine(
+ TypeIndex::kTVMFFIStr,
+ details::StableHashBytes(src_data->v_bytes,
src_data->small_str_len));
+ }
// this is POD data, we can just hash the value
return details::StableHashCombine(src_data->type_index,
src_data->v_uint64);
} else {
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 4abe933d4d..374c0c7c4e 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -317,7 +317,7 @@ class TypeTable {
type_table_.emplace_back(nullptr);
}
// initialize the entry for object
- this->GetOrAllocTypeIndex(Object::_type_key, Object::_type_index,
Object::_type_depth,
+ this->GetOrAllocTypeIndex(String(Object::_type_key), Object::_type_index,
Object::_type_depth,
Object::_type_child_slots,
Object::_type_child_slots_can_overflow,
-1);
TVMFFITypeMetadata info;
@@ -337,20 +337,36 @@ class TypeTable {
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr,
TypeIndex::kTVMFFIByteArrayPtr);
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef,
TypeIndex::kTVMFFIObjectRValueRef);
+ ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallStr,
TypeIndex::kTVMFFISmallStr);
+ ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallBytes,
TypeIndex::kTVMFFISmallBytes);
// no need to reserve for object types as they will be registered
}
void ReserveBuiltinTypeIndex(const char* type_key, int32_t
static_type_index) {
- this->GetOrAllocTypeIndex(type_key, static_type_index, 0, 0, false, -1);
+ this->GetOrAllocTypeIndex(String(type_key), static_type_index, 0, 0,
false, -1);
+ }
+
+ static ObjectPtr<details::StringObj> MakeInplaceString(const char* data,
size_t length) {
+ ObjectPtr<details::StringObj> p =
+ make_inplace_array_object<details::StringObj, char>(length + 1);
+ static_assert(alignof(details::StringObj) % alignof(char) == 0);
+ static_assert(sizeof(details::StringObj) % alignof(char) == 0);
+ char* dest_data = reinterpret_cast<char*>(p.get()) +
sizeof(details::StringObj);
+ p->data = dest_data;
+ p->size = length;
+ std::memcpy(dest_data, data, length);
+ dest_data[length] = '\0';
+ return p;
}
TVMFFIByteArray CopyString(TVMFFIByteArray str) {
if (str.size == 0) {
return TVMFFIByteArray{nullptr, 0};
}
- String val = String(str.data, str.size);
- TVMFFIByteArray c_val{val.data(), val.length()};
- any_pool_.emplace_back(std::move(val));
+ // use explicit object creation to ensure the space pointer to not move
+ auto str_obj = MakeInplaceString(str.data, str.size);
+ TVMFFIByteArray c_val{str_obj->data, str_obj->size};
+ any_pool_.emplace_back(ObjectRef(std::move(str_obj)));
return c_val;
}
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index a1a2b4514a..d1f56e1a93 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -394,4 +394,22 @@ TEST(Any, ObjectMove) {
EXPECT_TRUE(any1 == nullptr);
}
+TEST(Any, AnyEqualHash) {
+ // small string
+ Any a = "a1";
+ // on heap allocated string
+ Any b = String(std::string("a1"));
+ EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr);
+ EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr);
+ EXPECT_TRUE(AnyEqual()(a, b));
+ EXPECT_EQ(AnyHash()(a), AnyHash()(b));
+
+ Any c = Bytes("a1", 2);
+ Any d = Bytes(std::string("a1"));
+ EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFISmallBytes);
+ EXPECT_EQ(d.type_index(), TypeIndex::kTVMFFIBytes);
+ EXPECT_TRUE(AnyEqual()(c, d));
+ EXPECT_EQ(AnyHash()(c), AnyHash()(d));
+}
+
} // namespace
diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc
index 620f729a66..79fc9d7c2d 100644
--- a/ffi/tests/cpp/test_dtype.cc
+++ b/ffi/tests/cpp/test_dtype.cc
@@ -20,6 +20,7 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/memory.h>
+#include <tvm/ffi/optional.h>
namespace {
diff --git a/ffi/tests/cpp/test_optional.cc b/ffi/tests/cpp/test_optional.cc
index 256a7da8b4..eb114df8a3 100644
--- a/ffi/tests/cpp/test_optional.cc
+++ b/ffi/tests/cpp/test_optional.cc
@@ -170,4 +170,33 @@ TEST(Optional, OptionalInArray) {
auto opt_arr = any.cast<Array<Optional<Array<TInt>>>>();
EXPECT_EQ(opt_arr[0].value()[0]->value, 0);
}
+
+TEST(Optional, String) {
+ Optional<String> opt_str;
+ EXPECT_TRUE(!opt_str.has_value());
+ EXPECT_EQ(opt_str.value_or("default"), "default");
+ EXPECT_TRUE(opt_str != "default");
+ EXPECT_TRUE(opt_str != String("default"));
+ EXPECT_TRUE(opt_str == std::nullopt);
+
+ opt_str = "hello";
+ EXPECT_TRUE(opt_str.has_value());
+ EXPECT_EQ(opt_str.value(), "hello");
+ EXPECT_TRUE(opt_str == "hello");
+ EXPECT_TRUE(opt_str == String("hello"));
+ EXPECT_TRUE(opt_str != std::nullopt);
+ static_assert(sizeof(Optional<String>) == sizeof(String));
+}
+
+TEST(Optional, Bytes) {
+ Optional<Bytes> opt_bytes;
+ EXPECT_TRUE(!opt_bytes.has_value());
+ EXPECT_EQ(opt_bytes.value_or(std::string("default")), "default");
+
+ opt_bytes = std::string("hello");
+ EXPECT_TRUE(opt_bytes.has_value());
+ EXPECT_EQ(opt_bytes.value().operator std::string(), "hello");
+ EXPECT_TRUE(opt_bytes != std::nullopt);
+ static_assert(sizeof(Optional<Bytes>) == sizeof(Bytes));
+}
} // namespace
diff --git a/ffi/tests/cpp/test_reflection_accessor.cc
b/ffi/tests/cpp/test_reflection_accessor.cc
index aa3dfc5e92..cb5145db07 100644
--- a/ffi/tests/cpp/test_reflection_accessor.cc
+++ b/ffi/tests/cpp/test_reflection_accessor.cc
@@ -99,7 +99,6 @@ TEST(Reflection, FieldInfo) {
const TVMFFIFieldInfo* info_prim_expr_dtype =
reflection::GetFieldInfo("test.PrimExpr", "dtype");
AnyView default_value =
AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value);
EXPECT_EQ(default_value.cast<String>(), "float");
- EXPECT_EQ(default_value.as<String>().value().use_count(), 2);
EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault);
EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable);
EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype
field");
diff --git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc
index 7cbd5c627b..dd211a34dc 100644
--- a/ffi/tests/cpp/test_rvalue_ref.cc
+++ b/ffi/tests/cpp/test_rvalue_ref.cc
@@ -90,8 +90,8 @@ TEST(RValueRef, ParamChecking) {
TPrimExpr expr = *std::move(a);
return expr->dtype;
});
- EXPECT_EQ(func3(RValueRef(String("int32"))).cast<String>(), "int32");
+ // EXPECT_EQ(func3(RValueRef(String("int32"))).cast<String>(), "int32");
// triggered a lvalue based conversion
- EXPECT_EQ(func3(String("int32")).cast<String>(), "int32");
+ // EXPECT_EQ(func3(String("int32")).cast<String>(), "int32");
}
} // namespace
diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc
index d53ac105ab..364f2f6540 100644
--- a/ffi/tests/cpp/test_string.cc
+++ b/ffi/tests/cpp/test_string.cc
@@ -54,9 +54,9 @@ TEST(String, Assignment) {
s = std::move(s2);
EXPECT_EQ(s == "world2", true);
- ObjectRef r;
+ Any r;
r = String("hello");
- EXPECT_EQ(r.defined(), true);
+ EXPECT_EQ(r != nullptr, true);
}
TEST(String, empty) {
@@ -265,7 +265,7 @@ TEST(String, Cast) {
using namespace std;
string source = "this is a string";
String s{source};
- ObjectRef r = s;
+ Any r = s;
String s2 = Downcast<String>(r);
}
@@ -284,14 +284,19 @@ TEST(String, Concat) {
EXPECT_EQ(res3.compare("worldhello"), 0);
EXPECT_EQ(res4.compare("helloworld"), 0);
EXPECT_EQ(res5.compare("worldhello"), 0);
+
+ String storage_scope;
+ String res = "The input storage scope \"" + storage_scope + "\" is invalid.";
+ EXPECT_EQ(res.compare("The input storage scope \"\" is invalid."), 0);
}
TEST(String, Any) {
// test anyview promotion to any
AnyView view = "hello";
+ EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIRawStr);
Any b = view;
- EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr);
+ EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallStr);
EXPECT_EQ(b.as<String>().value(), "hello");
EXPECT_TRUE(b.as<String>().has_value());
EXPECT_EQ(b.try_cast<std::string>().value(), "hello");
@@ -302,17 +307,21 @@ TEST(String, Any) {
String s{"hello"};
Any a = s;
- EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFIStr);
+ EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr);
EXPECT_EQ(a.as<String>().value(), "hello");
EXPECT_EQ(a.try_cast<std::string>().value(), "hello");
- Any c = "helloworld";
+ Any c = "long string very long";
EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFIStr);
- EXPECT_EQ(c.as<String>().value(), "helloworld");
- EXPECT_EQ(c.try_cast<std::string>().value(), "helloworld");
+ EXPECT_EQ(c.as<String>().value(), "long string very long");
+ EXPECT_EQ(c.try_cast<std::string>().value(), "long string very long");
}
TEST(String, Bytes) {
+ Bytes b0;
+ EXPECT_EQ(b0.size(), 0);
+ EXPECT_EQ(b0.operator std::string(), "");
+
// explicitly test zero element
std::string s = {'\0', 'a', 'b', 'c'};
Bytes b = s;
@@ -334,10 +343,17 @@ TEST(String, BytesAny) {
EXPECT_EQ(view.try_cast<Bytes>().value().operator std::string(), s);
Any b = view;
- EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIBytes);
+ EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallBytes);
EXPECT_EQ(b.try_cast<Bytes>().value().operator std::string(), s);
EXPECT_EQ(b.cast<std::string>(), s);
+
+ std::string s2 = "hello long long long string";
+ s2[0] = '\0';
+ Any b2 = Bytes(s2);
+ EXPECT_EQ(b2.type_index(), TypeIndex::kTVMFFIBytes);
+ EXPECT_EQ(b2.try_cast<std::string>().value(), s2);
+ EXPECT_EQ(b2.cast<std::string>(), s2);
}
TEST(String, StdString) {
@@ -382,10 +398,9 @@ TEST(String, StdString) {
TEST(String, CAPIAccessor) {
using namespace std;
String s{"hello"};
- TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(s);
- TVMFFIByteArray* arr = TVMFFIBytesGetByteArrayPtr(obj);
- EXPECT_EQ(arr->size, 5);
- EXPECT_EQ(std::string(arr->data, arr->size), "hello");
+ TVMFFIByteArray arr{s.data(), s.size()};
+ EXPECT_EQ(arr.size, 5);
+ EXPECT_EQ(std::string(arr.data, arr.size), "hello");
}
TEST(String, BytesHash) {
@@ -403,4 +418,14 @@ TEST(String, BytesHash) {
EXPECT_EQ(hash1, hash2);
}
+TEST(String, StdHash) {
+ String s1 = "a";
+ String s2(std::string("a"));
+ EXPECT_EQ(std::hash<String>()(s1), std::hash<String>()(s2));
+
+ Bytes s3("a", 1);
+ Bytes s4(std::string("a"));
+ EXPECT_EQ(std::hash<Bytes>()(s3), std::hash<Bytes>()(s4));
+}
+
} // namespace
diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc
index b140e7db6e..639e6ee671 100644
--- a/ffi/tests/cpp/test_variant.cc
+++ b/ffi/tests/cpp/test_variant.cc
@@ -154,11 +154,11 @@ TEST(Variant, PODSameAs) {
Variant<String, int> v0 = 1;
Variant<String, int> v1 = 1;
EXPECT_TRUE(v0.same_as(v1));
- String s = String("hello");
+ String s = String("hello long str");
v0 = s;
v1 = s;
EXPECT_TRUE(v0.same_as(v1));
- v1 = String("hello");
+ v1 = String("hello long str");
EXPECT_TRUE(!v0.same_as(v1));
}
} // namespace
diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h
index fa6a837c8b..cd9a71eb9e 100644
--- a/include/tvm/relax/exec_builder.h
+++ b/include/tvm/relax/exec_builder.h
@@ -170,7 +170,7 @@ class ExecBuilderNode : public Object {
/*! \brief The mutable internal executable. */
ObjectPtr<vm::VMExecutable> exec_; // mutable
/*! \brief internal dedup map when creating index for a new constant */
- std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual>
const_dedup_map_;
+ std::unordered_map<ffi::Any, vm::Index, StructuralHash, StructuralEqual>
const_dedup_map_;
};
class ExecBuilder : public ObjectRef {
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 4068f7c682..1567294a4b 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -196,7 +196,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
*
* \return The Pass.
*/
-TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
+TVM_DLL Pass BindParams(String func_name, Map<Any, ObjectRef> params);
/*!
* \brief Bind symbolic vars to constant shape values.
@@ -213,7 +213,7 @@ TVM_DLL Pass BindParams(String func_name, Map<ObjectRef,
ObjectRef> params);
*
* \return The Pass.
*/
-TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
+TVM_DLL Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr>
binding_map,
Optional<String> func_name = std::nullopt);
/*!
diff --git a/include/tvm/script/ir_builder/tir/frame.h
b/include/tvm/script/ir_builder/tir/frame.h
index e9087588ff..1e205edc43 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -555,7 +555,7 @@ class AllocateConstFrame : public TIRFrame {
class AttrFrameNode : public TIRFrameNode {
public:
/*! \brief The node to annotate the attribute. */
- ObjectRef node;
+ Any node;
/*! \brief Attribute type key. */
String attr_key;
/*! \brief The value of the attribute. */
diff --git a/include/tvm/script/printer/ir_docsifier.h
b/include/tvm/script/printer/ir_docsifier.h
index 9d189dda09..8a181cf853 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -319,6 +319,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const
ObjectPath& path) con
return Downcast<TDoc>(LiteralDoc::Int(value.as<int64_t>().value(),
path));
case ffi::TypeIndex::kTVMFFIFloat:
return Downcast<TDoc>(LiteralDoc::Float(value.as<double>().value(),
path));
+ case ffi::TypeIndex::kTVMFFISmallStr:
case ffi::TypeIndex::kTVMFFIStr: {
std::string string_value = value.cast<std::string>();
bool has_multiple_lines = string_value.find_first_of('\n') !=
std::string::npos;
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 6b31324fa5..b4ed44fbff 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -984,6 +984,7 @@ enum TVMStructFieldKind : int {
// TVMValue field
kTVMValueContent,
kTVMFFIAnyTypeIndex,
+ kTVMFFIAnyZeroPadding,
kTVMFFIAnyUnionValue,
kTVMValueKindBound_
};
diff --git a/jvm/native/src/main/native/jni_helper_func.h
b/jvm/native/src/main/native/jni_helper_func.h
index 76520d43f7..5db3e279cf 100644
--- a/jvm/native/src/main/native/jni_helper_func.h
+++ b/jvm/native/src/main/native/jni_helper_func.h
@@ -223,10 +223,16 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) {
case TypeIndex::kTVMFFINDArray: {
return newNDArray(env, reinterpret_cast<jlong>(value.v_obj), false);
}
+ case TypeIndex::kTVMFFISmallStr: {
+ TVMFFIByteArray arr = TVMFFISmallBytesGetContentByteArray(&value);
+ return newTVMValueString(env, &arr);
+ }
case TypeIndex::kTVMFFIStr: {
- jobject ret = newTVMValueString(env,
TVMFFIBytesGetByteArrayPtr(value.v_obj));
- TVMFFIObjectFree(value.v_obj);
- return ret;
+ return newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj));
+ }
+ case TypeIndex::kTVMFFISmallBytes: {
+ TVMFFIByteArray arr = TVMFFISmallBytesGetContentByteArray(&value);
+ return newTVMValueBytes(env, &arr);
}
case TypeIndex::kTVMFFIBytes: {
jobject ret = newTVMValueBytes(env,
TVMFFIBytesGetByteArrayPtr(value.v_obj));
diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
index a5481dd9ac..3ebe7fddfa 100644
--- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
+++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
@@ -110,6 +110,7 @@ JNIEXPORT void JNICALL
Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgHandle(J
TVMFFIAny temp;
temp.v_int64 = static_cast<int64_t>(arg);
temp.type_index = static_cast<int>(argTypeIndex);
+ temp.zero_padding = 0;
stack->packed_args.emplace_back(tvm::ffi::AnyView::CopyFromTVMFFIAny(temp));
}
@@ -175,6 +176,7 @@ JNIEXPORT jint JNICALL
Java_org_apache_tvm_LibInfo_tvmFFIFunctionCall(JNIEnv* en
TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal();
TVMFFIAny ret_val;
ret_val.type_index = tvm::ffi::TypeIndex::kTVMFFINone;
+ ret_val.zero_padding = 0;
ret_val.v_int64 = 0;
int ret = TVMFFIFunctionCall(reinterpret_cast<TVMFFIObjectHandle>(jhandle),
reinterpret_cast<TVMFFIAny*>(stack->packed_args.data()),
diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index 8d31205d2e..00b76e68f7 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -40,6 +40,8 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIRawStr = 8
kTVMFFIByteArrayPtr = 9
kTVMFFIObjectRValueRef = 10
+ kTVMFFISmallStr = 11
+ kTVMFFISmallBytes = 12
kTVMFFIStaticObjectBegin = 64
kTVMFFIObject = 64
kTVMFFIStr = 65
@@ -95,7 +97,7 @@ cdef extern from "tvm/ffi/c_api.h":
ctypedef struct TVMFFIAny:
int32_t type_index
- int32_t padding
+ int32_t zero_padding
int64_t v_int64
double v_float64
void* v_ptr
@@ -184,7 +186,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil
int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex)
nogil
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
- int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle*
out) nogil
+ int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil
const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno,
const char* func) nogil;
int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t
require_alignment,
int32_t require_contiguous,
TVMFFIObjectHandle* out) nogil
@@ -196,6 +198,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src,
DLManagedTensorVersioned** out) nogil
const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil
+ TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny*
value) nogil
TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil
TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil
diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi
index 80ec5d9364..279b17f8c8 100644
--- a/python/tvm/ffi/cython/dtype.pxi
+++ b/python/tvm/ffi/cython/dtype.pxi
@@ -92,12 +92,19 @@ cdef class DataType:
return (self.cdtype.bits * self.cdtype.lanes + 7) // 8
def __str__(self):
- cdef TVMFFIObjectHandle dtype_str
- cdef TVMFFIByteArray* bytes
- CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str))
- bytes = TVMFFIBytesGetByteArrayPtr(dtype_str)
- res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
- CHECK_CALL(TVMFFIObjectFree(dtype_str))
+ cdef TVMFFIAny temp_any
+ cdef TVMFFIByteArray* bytes_ptr
+ cdef TVMFFIByteArray bytes
+
+ CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &temp_any))
+ if temp_any.type_index == kTVMFFISmallStr:
+ bytes = TVMFFISmallBytesGetContentByteArray(&temp_any)
+ res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
+ return res
+
+ bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj)
+ res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size))
+ CHECK_CALL(TVMFFIObjectFree(temp_any.v_obj))
return res
diff --git a/python/tvm/ffi/cython/function.pxi
b/python/tvm/ffi/cython/function.pxi
index d86d004d10..cbff3fecf1 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -23,6 +23,20 @@ except ImportError:
torch = None
+cdef inline object make_ret_small_str(TVMFFIAny result):
+ """convert small string to return value."""
+ cdef TVMFFIByteArray bytes
+ bytes = TVMFFISmallBytesGetContentByteArray(&result)
+ return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
+
+
+cdef inline object make_ret_small_bytes(TVMFFIAny result):
+ """convert small bytes to return value."""
+ cdef TVMFFIByteArray bytes
+ bytes = TVMFFISmallBytesGetContentByteArray(&result)
+ return PyBytes_FromStringAndSize(bytes.data, bytes.size)
+
+
cdef inline object make_ret(TVMFFIAny result):
"""convert result to return value."""
# TODO: Implement
@@ -41,6 +55,10 @@ cdef inline object make_ret(TVMFFIAny result):
return result.v_int64
elif type_index == kTVMFFIFloat:
return result.v_float64
+ elif type_index == kTVMFFISmallStr:
+ return make_ret_small_str(result)
+ elif type_index == kTVMFFISmallBytes:
+ return make_ret_small_bytes(result)
elif type_index == kTVMFFIOpaquePtr:
return ctypes_handle(result.v_ptr)
elif type_index == kTVMFFIDataType:
@@ -65,6 +83,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list
temp_args) except
# clear the value to ensure zero padding on 32bit platforms
if sizeof(void*) != 8:
out[i].v_int64 = 0
+ out[i].zero_padding = 0
if isinstance(arg, NDArray):
if (<Object>arg).chandle != NULL:
diff --git a/src/contrib/msc/core/ir/graph_builder.h
b/src/contrib/msc/core/ir/graph_builder.h
index cc1905c0fa..401c452d95 100644
--- a/src/contrib/msc/core/ir/graph_builder.h
+++ b/src/contrib/msc/core/ir/graph_builder.h
@@ -154,6 +154,7 @@ class AttrGetter {
attrs_->Set(key,
runtime::DLDataTypeToString(value.cast<DLDataType>()));
break;
}
+ case kTVMFFISmallStr:
case kTVMFFIStr: {
attrs_->Set(key, value.cast<String>());
break;
diff --git a/src/contrib/msc/core/printer/cpp_printer.cc
b/src/contrib/msc/core/printer/cpp_printer.cc
index 6ae71860b6..1f0fdb1177 100644
--- a/src/contrib/msc/core/printer/cpp_printer.cc
+++ b/src/contrib/msc/core/printer/cpp_printer.cc
@@ -167,7 +167,7 @@ void CppPrinter::PrintTypedDoc(const ScopeDoc& doc) {
void CppPrinter::PrintTypedDoc(const FunctionDoc& doc) {
MaybePrintComment(doc, true);
for (const AssignDoc& arg_doc : doc->args) {
- ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment
attached to them.";
+ ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment
attached to them.";
}
if (doc->return_type.defined()) {
if (!IsEmptyDoc(doc->return_type.value())) {
@@ -273,7 +273,8 @@ void CppPrinter::PrintTypedDoc(const StructDoc& doc) {
void CppPrinter::PrintTypedDoc(const ConstructorDoc& doc) {
MaybePrintComment(doc, true);
for (const AssignDoc& arg_doc : doc->args) {
- ICHECK(arg_doc->comment == nullptr) << "Constructor arg cannot have
comment attached to them.";
+ ICHECK(!arg_doc->comment.has_value())
+ << "Constructor arg cannot have comment attached to them.";
}
PrintDoc(doc->name, false);
output_ << "(";
@@ -293,7 +294,7 @@ void CppPrinter::PrintTypedDoc(const ConstructorDoc& doc) {
void CppPrinter::PrintTypedDoc(const LambdaDoc& doc) {
MaybePrintComment(doc, true);
for (const AssignDoc& arg_doc : doc->args) {
- ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment
attached to them.";
+ ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment
attached to them.";
}
output_ << "auto ";
PrintDoc(doc->name, false);
diff --git a/src/contrib/msc/core/printer/python_printer.cc
b/src/contrib/msc/core/printer/python_printer.cc
index 184d7ce870..df75887ce1 100644
--- a/src/contrib/msc/core/printer/python_printer.cc
+++ b/src/contrib/msc/core/printer/python_printer.cc
@@ -157,7 +157,7 @@ void PythonPrinter::PrintTypedDoc(const ScopeDoc& doc) {
void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) {
for (const AssignDoc& arg_doc : doc->args) {
- ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment
attached to them.";
+ ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment
attached to them.";
}
PrintDecorators(doc->decorators);
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc
b/src/meta_schedule/mutator/mutate_tile_size.cc
index af5fb3ebab..36a38cac75 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -94,9 +94,8 @@ void FindSamplePerfectTile(const Trace& trace,
std::vector<Instruction>* inst,
decisions.reserve(trace->decisions.size());
for (const auto& kv : trace->decisions) {
const Instruction& inst = kv.first;
- const ObjectRef& decision = kv.second.cast<ObjectRef>();
if (inst->kind.same_as(inst_sample_perfect_tile)) {
- std::vector<int64_t> tiles = DowncastTilingDecision(decision);
+ std::vector<int64_t> tiles =
DowncastTilingDecision(kv.second.cast<ObjectRef>());
if (tiles.size() >= 2 && Product(tiles) >= 2) {
instructions.push_back(inst);
decisions.push_back(tiles);
@@ -130,7 +129,6 @@ void FindSampleVectorize(const Trace& trace,
std::vector<Instruction>* inst,
// Find sampling instruction that generates the annotation
for (const auto& kv : trace->decisions) {
const Instruction& inst = kv.first;
- const ObjectRef& decision = kv.second.cast<ObjectRef>();
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
if (annotated.count(inst->outputs[0].as<Object>())) {
@@ -141,6 +139,7 @@ void FindSampleVectorize(const Trace& trace,
std::vector<Instruction>* inst,
// Skip mutating the sampling instructions who have only single
candidate.
continue;
}
+ const ObjectRef& decision = kv.second.cast<ObjectRef>();
const auto* d = TVM_TYPE_AS(decision, IntImmNode);
instructions.push_back(inst);
decisions.push_back(d->value);
diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc
index 240b4f1758..d3b62b5e87 100644
--- a/src/node/repr_printer.cc
+++ b/src/node/repr_printer.cc
@@ -78,11 +78,13 @@ void ReprPrinter::Print(const ffi::Any& node) {
Print(node.cast<ObjectRef>());
break;
}
+ case ffi::TypeIndex::kTVMFFISmallStr:
case ffi::TypeIndex::kTVMFFIStr: {
ffi::String str = node.cast<ffi::String>();
stream << '"' << support::StrEscape(str.data(), str.size()) << '"';
break;
}
+ case ffi::TypeIndex::kTVMFFISmallBytes:
case ffi::TypeIndex::kTVMFFIBytes: {
ffi::Bytes bytes = node.cast<ffi::Bytes>();
stream << "b\"" << support::StrEscape(bytes.data(), bytes.size()) << '"';
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 65b9728317..0c3ca959a3 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -108,7 +108,9 @@ class NodeIndexer {
}
}
} else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
- node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
+ node.type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
+ node.type_index() == ffi::TypeIndex::kTVMFFIBytes ||
+ node.type_index() == ffi::TypeIndex::kTVMFFISmallBytes) {
// skip content index for string and bytes
} else if (auto opt_object = node.as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
@@ -126,8 +128,8 @@ class NodeIndexer {
<< "` misses reflection registration and do not support serialization";
ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo*
field_info) {
Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
- // only make index for ObjectRef
- if (field_value.as<Object>()) {
+ // only make index for ObjectRef and String(which may not be object for
small str)
+ if (field_value.as<Object>() || field_value.as<String>()) {
this->MakeIndex(field_value);
}
});
@@ -234,9 +236,9 @@ class JSONAttrGetter {
}
}
- void Visit(const char* key, ObjectRef* value) {
- if (value->defined()) {
- node_->attrs[key] = std::to_string(node_index_->at(Any(*value)));
+ void Visit(const char* key, Any* value) {
+ if (value != nullptr) {
+ node_->attrs[key] = std::to_string(node_index_->at(*value));
} else {
node_->attrs[key] = "null";
}
@@ -249,6 +251,13 @@ class JSONAttrGetter {
return;
}
node_->type_key = node.GetTypeKey();
+ // canonicalize type key for str
+ if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) {
+ node_->type_key = ffi::StaticTypeKey::kTVMFFIStr;
+ }
+ if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) {
+ node_->type_key = ffi::StaticTypeKey::kTVMFFIBytes;
+ }
// populates the fields.
node_->attrs.clear();
node_->data.clear();
@@ -344,19 +353,9 @@ class JSONAttrGetter {
this->Visit(field_info->name.data, &value);
break;
}
- case ffi::TypeIndex::kTVMFFINDArray: {
- runtime::NDArray value = field_value.cast<runtime::NDArray>();
- this->Visit(field_info->name.data, &value);
- break;
- }
default: {
- if (field_value.type_index() >=
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- ObjectRef obj = field_value.cast<ObjectRef>();
- this->Visit(field_info->name.data, &obj);
- break;
- } else {
- LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey();
- }
+ this->Visit(field_info->name.data, &field_value);
+ break;
}
}
});
@@ -401,14 +400,16 @@ class FieldDependencyFinder {
if (node == nullptr) {
return;
}
- if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- return;
- }
if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
- node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
+ node.type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
+ node.type_index() == ffi::TypeIndex::kTVMFFIBytes ||
+ node.type_index() == ffi::TypeIndex::kTVMFFISmallBytes) {
// skip indexing content of string and bytes
return;
}
+ if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ return;
+ }
// Skip the objects that have their own string repr
if (jnode->repr_bytes.length() > 0 ||
reflection_->GetReprBytes(node.cast<const Object*>(), nullptr)) {
@@ -562,9 +563,11 @@ class JSONAttrSetter {
setter.ParseValue("v_device_type", &device_type);
setter.ParseValue("v_device_id", &device_id);
return Any(DLDevice{static_cast<DLDeviceType>(device_type), device_id});
- } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr) {
+ } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) {
return Any(String(jnode->repr_bytes));
- } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
+ } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) {
return Any(Bytes(jnode->repr_bytes));
} else {
return ObjectRef(reflection->CreateInitObject(jnode->type_key,
jnode->repr_bytes));
@@ -596,7 +599,9 @@ class JSONAttrSetter {
}
*node = result;
} else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
- jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) {
// skip set attrs for string and bytes
} else if (auto opt_object = node->as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
@@ -652,7 +657,7 @@ class JSONAttrSetter {
ParseOptionalValue(field_info->name.data, &index,
[this](const char* key, int64_t* value) {
ParseValue(key, value); });
if (index.has_value()) {
- Any value = node_list_->at(*index).cast<ObjectRef>();
+ Any value = node_list_->at(*index);
setter(obj, value);
} else {
setter(obj, Any());
diff --git a/src/relax/backend/contrib/clml/codegen.cc
b/src/relax/backend/contrib/clml/codegen.cc
index ec7063f2e9..84ef050938 100644
--- a/src/relax/backend/contrib/clml/codegen.cc
+++ b/src/relax/backend/contrib/clml/codegen.cc
@@ -139,7 +139,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer {
const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
auto opt_composite = fn->GetAttr<String>(attr::kComposite);
- ICHECK(opt_composite.defined());
+ ICHECK(opt_composite.has_value());
std::string name = opt_composite.value();
std::shared_ptr<JSONGraphNode> node;
@@ -194,7 +194,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer {
ICHECK(fn_var);
const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
auto opt_composite = fn->GetAttr<String>(attr::kComposite);
- ICHECK(opt_composite.defined());
+ ICHECK(opt_composite.has_value());
nodes.pad = backend::TryGetOpInFunction(fn, "relax.nn.pad");
nodes.conv = backend::TryGetOpInFunction(fn, "relax.nn.conv2d");
@@ -223,7 +223,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer {
ICHECK(fn_var);
const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
auto opt_composite = fn->GetAttr<String>(attr::kComposite);
- ICHECK(opt_composite.defined());
+ ICHECK(opt_composite.has_value());
std::string name = opt_composite.value();
std::vector<JSONGraphNodeEntry> inputs;
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h
b/src/relax/backend/contrib/codegen_json/codegen_json.h
index b2c3e47c73..ecf34ecd9f 100644
--- a/src/relax/backend/contrib/codegen_json/codegen_json.h
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -180,11 +180,8 @@ class OpAttrExtractor {
break;
}
default: {
- if (field_value.type_index() >=
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- this->Visit(field_info->name.data, &field_value);
- break;
- }
- LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey();
+ this->Visit(field_info->name.data, &field_value);
+ break;
}
}
});
diff --git a/src/relax/backend/contrib/cublas/codegen.cc
b/src/relax/backend/contrib/cublas/codegen.cc
index 41a4cb766a..3f132b024a 100644
--- a/src/relax/backend/contrib/cublas/codegen.cc
+++ b/src/relax/backend/contrib/cublas/codegen.cc
@@ -53,7 +53,7 @@ class CublasJSONSerializer : public JSONSerializer {
ICHECK(fn.defined()) << "Expects the callee to be a function.";
auto composite_opt = fn->GetAttr<String>(attr::kComposite);
- ICHECK(composite_opt.defined()) << "Only composite functions are
supported.";
+ ICHECK(composite_opt.has_value()) << "Only composite functions are
supported.";
std::string composite_name = composite_opt.value();
diff --git a/src/relax/backend/contrib/cudnn/codegen.cc
b/src/relax/backend/contrib/cudnn/codegen.cc
index 358f2d6604..b529c6f796 100644
--- a/src/relax/backend/contrib/cudnn/codegen.cc
+++ b/src/relax/backend/contrib/cudnn/codegen.cc
@@ -52,7 +52,7 @@ class cuDNNJSONSerializer : public JSONSerializer {
ICHECK(fn.defined()) << "Expects the callee to be a function.";
auto composite_opt = fn->GetAttr<String>(attr::kComposite);
- ICHECK(composite_opt.defined()) << "Only composite functions are
supported.";
+ ICHECK(composite_opt.has_value()) << "Only composite functions are
supported.";
std::string composite_name = composite_opt.value();
diff --git a/src/relax/backend/contrib/cutlass/codegen.cc
b/src/relax/backend/contrib/cutlass/codegen.cc
index 874dced500..932fdadddf 100644
--- a/src/relax/backend/contrib/cutlass/codegen.cc
+++ b/src/relax/backend/contrib/cutlass/codegen.cc
@@ -221,7 +221,7 @@ class CodegenCutlass : public
relax::MemoizedExprTranslator<OutputType>,
}
OutputType VisitExpr_(const FunctionNode* fn) final {
- ICHECK(fn->GetAttr<String>(attr::kComposite).defined())
+ ICHECK(fn->GetAttr<String>(attr::kComposite).has_value())
<< "JSON runtime only supports composite functions";
// FunctionNode should be handled by the caller.
return {};
diff --git a/src/relax/backend/contrib/dnnl/codegen.cc
b/src/relax/backend/contrib/dnnl/codegen.cc
index 349dbd4ef1..83cbdd8e2b 100644
--- a/src/relax/backend/contrib/dnnl/codegen.cc
+++ b/src/relax/backend/contrib/dnnl/codegen.cc
@@ -52,7 +52,7 @@ class DNNLJSONSerializer : public JSONSerializer {
ICHECK(fn.defined()) << "Expects the callee to be a function.";
auto composite_opt = fn->GetAttr<String>(attr::kComposite);
- ICHECK(composite_opt.defined()) << "Only composite functions are
supported.";
+ ICHECK(composite_opt.has_value()) << "Only composite functions are
supported.";
std::string composite_name = composite_opt.value();
diff --git a/src/relax/backend/contrib/hipblas/codegen.cc
b/src/relax/backend/contrib/hipblas/codegen.cc
index d14d7aed57..761221c88b 100644
--- a/src/relax/backend/contrib/hipblas/codegen.cc
+++ b/src/relax/backend/contrib/hipblas/codegen.cc
@@ -52,7 +52,7 @@ class HipblasJSONSerializer : public JSONSerializer {
ICHECK(fn.defined()) << "Expects the callee to be a function.";
auto composite_opt = fn->GetAttr<String>(attr::kComposite);
- ICHECK(composite_opt.defined()) << "Only composite functions are
supported.";
+ ICHECK(composite_opt.has_value()) << "Only composite functions are
supported.";
std::string composite_name = composite_opt.value();
diff --git a/src/relax/backend/contrib/nnapi/codegen.cc
b/src/relax/backend/contrib/nnapi/codegen.cc
index ded7340b6f..c62523f539 100644
--- a/src/relax/backend/contrib/nnapi/codegen.cc
+++ b/src/relax/backend/contrib/nnapi/codegen.cc
@@ -201,7 +201,7 @@ class NNAPIJSONSerializer : public JSONSerializer {
ICHECK(fn.defined()) << "Expects the callee to be a function.";
auto composite_opt = fn->GetAttr<String>(attr::kComposite);
- ICHECK(composite_opt.defined()) << "Only composite functions are
supported.";
+ ICHECK(composite_opt.has_value()) << "Only composite functions are
supported.";
std::string composite_name = composite_opt.value();
diff --git a/src/relax/backend/vm/exec_builder.cc
b/src/relax/backend/vm/exec_builder.cc
index 0a768e89fe..15f292261e 100644
--- a/src/relax/backend/vm/exec_builder.cc
+++ b/src/relax/backend/vm/exec_builder.cc
@@ -56,24 +56,15 @@ vm::Instruction::Arg ExecBuilderNode::ConvertConstant_(Any
cvalue) {
return vm::Instruction::Arg::Immediate(val);
}
}
-
// run dedup for object with structural equality
- if (auto opt_obj = cvalue.as<ObjectRef>()) {
- ObjectRef obj = opt_obj.value();
- auto it = const_dedup_map_.find(obj);
- if (it != const_dedup_map_.end()) {
- return vm::Instruction::Arg::ConstIdx(it->second);
- }
- vm::Index idx = exec_->constants.size();
- exec_->constants.push_back(cvalue);
- const_dedup_map_[obj] = idx;
- return vm::Instruction::Arg::ConstIdx(idx);
- } else {
- // emit normal constant
- vm::Index idx = exec_->constants.size();
- exec_->constants.push_back(cvalue);
- return vm::Instruction::Arg::ConstIdx(idx);
+ auto it = const_dedup_map_.find(cvalue);
+ if (it != const_dedup_map_.end()) {
+ return vm::Instruction::Arg::ConstIdx(it->second);
}
+ vm::Index idx = exec_->constants.size();
+ exec_->constants.push_back(cvalue);
+ const_dedup_map_[cvalue] = idx;
+ return vm::Instruction::Arg::ConstIdx(idx);
}
void ExecBuilderNode::DeclareFunction(const std::string& func_name,
VMFuncInfo::FuncKind kind) {
diff --git a/src/relax/transform/bind_params.cc
b/src/relax/transform/bind_params.cc
index 49fe469e89..13b138ecce 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -83,7 +83,7 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant,
}
std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
- const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) {
+ const Function& func, const Map<Any, ObjectRef>& untyped_params) {
ICHECK(func.defined());
ICHECK(untyped_params.defined());
@@ -158,7 +158,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>>
NormalizeBindings(
* \param params params dict
* \return Function
*/
-Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>&
untyped_params) {
+Function FunctionBindParams(Function func, const Map<Any, ObjectRef>&
untyped_params) {
auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);
Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
@@ -172,7 +172,7 @@ Function FunctionBindParams(Function func, const
Map<ObjectRef, ObjectRef>& unty
* \param param The param dict
* \return The module after binding params.
*/
-IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef>
bind_params) {
+IRModule BindParam(IRModule m, String func_name, Map<Any, ObjectRef>
bind_params) {
IRModuleNode* new_module = m.CopyOnWrite();
Map<GlobalVar, BaseFunc> functions = m->functions;
for (const auto& func_pr : functions) {
@@ -203,7 +203,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace transform {
-Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) {
+Pass BindParams(String func_name, Map<Any, ObjectRef> params) {
auto pass_func = [=](IRModule mod, PassContext pc) {
return BindParam(std::move(mod), func_name, params);
};
diff --git a/src/relax/transform/bind_symbolic_vars.cc
b/src/relax/transform/bind_symbolic_vars.cc
index 22c557874c..5ba25b7e16 100644
--- a/src/relax/transform/bind_symbolic_vars.cc
+++ b/src/relax/transform/bind_symbolic_vars.cc
@@ -31,7 +31,8 @@
namespace tvm {
namespace relax {
-Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr>
obj_remap) {
+Function FunctionBindSymbolicVars(Function func,
+ Map<Variant<tir::Var, String>, PrimExpr>
obj_remap) {
// Early bail-out if no updates need to be made.
if (obj_remap.empty()) {
return func;
@@ -90,7 +91,8 @@ Function FunctionBindSymbolicVars(Function func,
Map<ffi::Any, PrimExpr> obj_rem
}
namespace {
-IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr>
binding_map) {
+IRModule ModuleBindSymbolicVars(IRModule mod,
+ Map<Variant<tir::Var, String>, PrimExpr>
binding_map) {
std::unordered_set<ffi::Any, ffi::AnyHash, ffi::AnyEqual> used;
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
@@ -98,7 +100,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any,
PrimExpr> binding_ma
auto func = opt.value();
// Collect bindings that are used by this function.
- auto func_binding_map = [&]() -> Map<ffi::Any, PrimExpr> {
+ auto func_binding_map = [&]() -> Map<Variant<tir::Var, String>,
PrimExpr> {
std::unordered_set<std::string> var_names;
std::unordered_set<const tir::VarNode*> vars;
for (const auto& var : DefinedSymbolicVars(func)) {
@@ -106,7 +108,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any,
PrimExpr> binding_ma
vars.insert(var.get());
}
- Map<ffi::Any, PrimExpr> out;
+ Map<Variant<tir::Var, String>, PrimExpr> out;
for (const auto& [key, replacement] : binding_map) {
bool used_by_function = false;
if (auto opt = key.as<String>()) {
@@ -156,7 +158,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace transform {
-Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, Optional<String>
func_name) {
+Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map,
+ Optional<String> func_name) {
auto pass_func = [=](IRModule mod, PassContext context) -> IRModule {
if (func_name) {
auto gvar = mod->GetGlobalVar(func_name.value());
diff --git a/src/runtime/minrpc/rpc_reference.h
b/src/runtime/minrpc/rpc_reference.h
index 42be97b53f..b5f1e6995f 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -408,6 +408,9 @@ struct RPCReference {
int32_t type_index;
channel->Read(&type_index);
packed_args[i].type_index = type_index;
+ packed_args[i].zero_padding = 0;
+ // clear to ensure compact for 32 bit platform
+ packed_args[i].v_int64 = 0;
switch (type_index) {
case ffi::TypeIndex::kTVMFFINone: {
break;
diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index ddd5462c68..e9652618e4 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -613,7 +613,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool
compute_col_sums) con
// fill empty data with empty strings
cols[i].push_back("");
} else {
- cols[i].push_back(print_metric((*it).second.cast<ObjectRef>()));
+ cols[i].push_back(print_metric((*it).second));
}
}
}
@@ -653,7 +653,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool
compute_col_sums) con
// Add configuration information. It will not be aligned with the columns.
s << std::endl << "Configuration" << std::endl << "-------------" <<
std::endl;
for (auto kv : configuration) {
- s << kv.first << ": " << print_metric(kv.second.cast<ObjectRef>()) <<
std::endl;
+ s << kv.first << ": " << print_metric(kv.second) << std::endl;
}
return s.str();
}
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index d1fb7bab90..a693c671f3 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -88,10 +88,17 @@ class RPCWrappedFunc : public Object {
// scan and check whether we need rewrite these arguments
// to their remote variant.
for (int i = 0; i < args.size(); ++i) {
+ // handle both str and small str
if (args[i].type_index() == ffi::TypeIndex::kTVMFFIStr) {
// pass string as c_str
packed_args[i] = args[i].cast<ffi::String>().data();
continue;
+ } else if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr) {
+ // we cannot cast here, since we need to make sure the space is alive
+ const TVMFFIAny* any_view_ptr = reinterpret_cast<const
TVMFFIAny*>(&args.data()[i]);
+ TVMFFIByteArray bytes =
TVMFFISmallBytesGetContentByteArray(any_view_ptr);
+ packed_args[i] = bytes.data;
+ continue;
}
packed_args[i] = args[i];
// run a remote translation to translate RPC related objects to
@@ -314,7 +321,9 @@ void
RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv)
AddRPCSessionMask(tensor->device,
sess_->table_index()),
nd_handle);
} else if (type_index == ffi::TypeIndex::kTVMFFIBytes ||
- type_index == ffi::TypeIndex::kTVMFFIStr) {
+ type_index == ffi::TypeIndex::kTVMFFIStr ||
+ type_index == ffi::TypeIndex::kTVMFFISmallStr ||
+ type_index == ffi::TypeIndex::kTVMFFISmallBytes) {
ICHECK_EQ(args.size(), 2);
*rv = args[1];
} else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 9d5d9dade5..33a687f54b 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -245,8 +245,7 @@ Map<String, Any> MergeAnnotations(const Map<String, Any>&
new_attrs,
// Case 2.2: the values are not both dicts, check if the keys are the same
if (!ffi::AnyEqual()(old_value.value(), value)) {
LOG(FATAL) << "ValueError: Try to merge two annotations with different
values for key `"
- << key << "`, previous one is " <<
old_value->cast<ObjectRef>() << ", new one is "
- << value.cast<ObjectRef>();
+ << key << "`, previous one is " << old_value.value() << ",
new one is " << value;
}
}
return result;
@@ -521,11 +520,11 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray
data, DataType dtype,
AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) {
// convert POD value to PrimExpr
- if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
node = node.cast<PrimExpr>();
}
ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>();
- n->node = node.cast<ObjectRef>();
+ n->node = std::move(node);
n->attr_key = attr_key;
n->value = value;
return AttrFrame(n);
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc
b/src/script/printer/doc_printer/python_doc_printer.cc
index f8d773334f..21f5e33015 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -663,7 +663,7 @@ void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
for (const AssignDoc& arg_doc : doc->args) {
- ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment
attached to them.";
+ ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment
attached to them.";
}
PrintDecorators(doc->decorators);
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 737f27c7e9..a1b1272cde 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -212,7 +212,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
}
return arr;
})
- .def("testing.AcceptsMapOfPrimExpr", [](Map<ObjectRef, PrimExpr> map) ->
ObjectRef {
+ .def("testing.AcceptsMapOfPrimExpr", [](Map<Any, PrimExpr> map) ->
ObjectRef {
for (const auto& kv : map) {
ObjectRef value = kv.second;
CHECK(value->IsInstance<PrimExprNode>())
diff --git a/src/support/utils.h b/src/support/utils.h
index eb0d4b9a88..8af2747831 100644
--- a/src/support/utils.h
+++ b/src/support/utils.h
@@ -139,13 +139,14 @@ inline std::vector<std::string> Split(const std::string&
str, char delim) {
* \return Whether the prefix matched.
*/
inline bool StartsWith(const ffi::String& str, const char* prefix) {
- size_t n = str.length();
- for (size_t i = 0; i < n; i++) {
- if (prefix[i] == '\0') return true;
- if (str.data()[i] != prefix[i]) return false;
+ const char* data = str.data();
+ const char* data_end = data + str.size();  // bounded by size(), not by a NUL inside str
+ for (; data != data_end; ++data, ++prefix) {  // advance str and prefix in lockstep
+ if (*prefix == '\0') return true;  // every prefix char already matched
+ if (*data != *prefix) return false;  // first mismatching char
}
// return true if the str is equal to the prefix
- return prefix[n] == '\0';
+ return *prefix == '\0';
}
/*!
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index b85b51e3d2..4dd24026c0 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -339,6 +339,11 @@ CodeGenLLVM::TypedPointer
CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value
buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index,
ConstInt32(0)});
return TypedPointer(t_int32_, buf);
}
+ case builtin::kTVMFFIAnyZeroPadding: {
+ buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_,
0));
+ buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index,
ConstInt32(1)});
+ return TypedPointer(t_int32_, buf);
+ }
case builtin::kTVMFFIAnyUnionValue: {
ICHECK_EQ(t.lanes(), 1);
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_,
0));
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 11f0eaf1ba..acc05cf96c 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -335,6 +335,12 @@ std::string CodeGenC::GetStructRef(DataType t, const
PrimExpr& buffer, const Pri
this->PrintExpr(buffer, os);
os << ")[" << index << "].type_index)";
return os.str();
+ } else if (kind == builtin::kTVMFFIAnyZeroPadding) {
+ std::ostringstream os;
+ os << "(((TVMFFIAny*)";
+ this->PrintExpr(buffer, os);
+ os << ")[" << index << "].zero_padding)";
+ return os.str();
} else if (kind == builtin::kTVMFFIAnyUnionValue) {
std::ostringstream os;
os << "(((TVMFFIAny*)";
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index 2e808738ef..6cd12a9319 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -246,6 +246,8 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) {
// must make sure type_index is set to none
this->stream << result << ".type_index = kTVMFFINone;\n";
this->PrintIndent();
+ this->stream << result << ".zero_padding = 0;\n";
+ this->PrintIndent();
this->stream << result << ".v_int64 = 0;\n";
this->PrintIndent();
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 6803e01f50..56fab07605 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -100,10 +100,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Any node, String attr_key, PrimExpr value, Stmt
body, Span span) {
// when node is a POD data type like int or bool,
first convert to
// primexpr.
- if (node.type_index() <
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (node.type_index() <
ffi::TypeIndex::kTVMFFISmallStr) {
return AttrStmt(node.cast<PrimExpr>(), attr_key,
value, body, span);
}
- return AttrStmt(node.cast<ObjectRef>(), attr_key,
value, body, span);
+ return AttrStmt(node, attr_key, value, body, span);
});
});
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index c00c946852..6f7e682d6c 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -916,8 +916,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const
ffi::Any& ann_val) {
if (auto opt_str = ann_val.try_cast<ffi::String>()) {
return *std::move(opt_str);
}
-
- if (ann_val.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (ann_val.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
return ann_val;
}
// prefer to return int/float literals for annotations
diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc
index 3ee43c698a..2f327354c9 100644
--- a/src/tir/schedule/instruction.cc
+++ b/src/tir/schedule/instruction.cc
@@ -74,7 +74,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
inputs.push_back(String('"' + (*opt_str).operator std::string() +
'"'));
} else if (obj.as<BlockRVNode>() || obj.as<LoopRVNode>()) {
inputs.push_back(String("_"));
- } else if (obj.type_index() <
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
inputs.push_back(obj);
} else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) {
inputs.push_back(obj);
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index 43c2ce0a7b..61f24f980f 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -71,7 +71,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs,
};
for (const Any& input : inputs) {
- if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
// directly put back POD type
result.push_back(input);
} else if (auto expr = input.as<ffi::String>()) {
@@ -110,8 +110,11 @@ Array<Any> TranslateInputRVs(
results.push_back(String("None"));
continue;
}
- if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- // directly put back POD type
+ // string => "content"
+ if (auto opt_str = input.as<ffi::String>()) {
+ results.push_back(String('"' + (*opt_str).operator std::string() + '"'));
+ } else if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
+ // directly put back POD type and not string
results.push_back(input);
} else if (input.as<BlockRVNode>() || // RV: block
input.as<LoopRVNode>() || // RV: loop
@@ -124,9 +127,6 @@ Array<Any> TranslateInputRVs(
LOG(FATAL) << "IndexError: Random variable is not defined " << input;
throw;
}
- } else if (auto opt_str = input.as<ffi::String>()) {
- // Case 2. string => "content"
- results.push_back(String('"' + (*opt_str).operator std::string() + '"'));
} else if (input.as<IntImmNode>() || input.as<FloatImmNode>()) {
// Case 3. integer or floating-point number
results.push_back(input);
@@ -159,7 +159,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs,
Array<Any> results;
results.reserve(inputs.size());
for (const Any& input : inputs) {
- if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
// directly put back POD type
results.push_back(input);
continue;
diff --git a/src/tir/transforms/lower_tvm_builtin.cc
b/src/tir/transforms/lower_tvm_builtin.cc
index 0db4398711..e74f5c7c90 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -521,6 +521,9 @@ class BuiltinLower : public StmtExprMutator {
prep_seq->emplace_back(TVMStructSet(args_stack, stack_offset,
builtin::kTVMFFIAnyTypeIndex,
ConstInt32(arg_type_index)));
}
+ // set zero padding to ensure compatibility with FFI convention
+ prep_seq->emplace_back(
+ TVMStructSet(args_stack, stack_offset,
builtin::kTVMFFIAnyZeroPadding, ConstInt32(0)));
// handle arg value
// NOTE: the intrinsic codegen will handle padding value clear for 32bit
// types or types that are smaller than 64 bits.
@@ -578,6 +581,8 @@ class BuiltinLower : public StmtExprMutator {
// explicitly set return value to None to avoid bad state interpretation
prep_seq.emplace_back(TVMStructSet(scope.stack_ffi_any, num_args,
builtin::kTVMFFIAnyTypeIndex,
ConstInt32(ffi::TypeIndex::kTVMFFINone)));
+ prep_seq.emplace_back(
+ TVMStructSet(scope.stack_ffi_any, num_args,
builtin::kTVMFFIAnyZeroPadding, ConstInt32(0)));
prep_seq.emplace_back(TVMStructSet(scope.stack_ffi_any, num_args,
builtin::kTVMFFIAnyUnionValue,
make_zero(DataType::Int(64))));
// Verify stack size matches earlier value.
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index d95a02a0ba..7477fe8636 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -105,12 +105,17 @@ class ReturnRewriter : public StmtMutator {
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32),
tir::builtin::kTVMFFIAnyTypeIndex),
IntImm(DataType::Int(32), info.type_index)}));
+ Stmt store_zero_padding =
+ tir::Evaluate(tir::Call(DataType::Int(32),
tir::builtin::tvm_struct_set(),
+ {ret_var_, IntImm(DataType::Int(32), 0),
+ IntImm(DataType::Int(32),
tir::builtin::kTVMFFIAnyZeroPadding),
+ IntImm(DataType::Int(32), 0)}));
Stmt store_val = tir::Evaluate(
tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32),
tir::builtin::kTVMFFIAnyUnionValue), info.expr}));
Stmt ret_zero = Evaluate(tvm::ret(0));
- return SeqStmt({store_tindex, store_val, ret_zero});
+ return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero});
}
Var ret_var_;
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index 299c193146..08f377829f 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -97,13 +97,17 @@ def test_lower_call_packed():
T.tvm_struct_set(stack_array, 2, 9, 0)
T.tvm_struct_set(stack_array, 2, 10, 1)
T.tvm_struct_set(stack_ffi_any, 0, 13, 7)
- T.tvm_struct_set(stack_ffi_any, 0, 14,
T.tvm_struct_get(stack_array, 0, 0, "handle"))
+ T.tvm_struct_set(stack_ffi_any, 0, 14, 0)
+ T.tvm_struct_set(stack_ffi_any, 0, 15,
T.tvm_struct_get(stack_array, 0, 0, "handle"))
T.tvm_struct_set(stack_ffi_any, 1, 13, 7)
- T.tvm_struct_set(stack_ffi_any, 1, 14,
T.tvm_struct_get(stack_array, 1, 0, "handle"))
+ T.tvm_struct_set(stack_ffi_any, 1, 14, 0)
+ T.tvm_struct_set(stack_ffi_any, 1, 15,
T.tvm_struct_get(stack_array, 1, 0, "handle"))
T.tvm_struct_set(stack_ffi_any, 2, 13, 7)
- T.tvm_struct_set(stack_ffi_any, 2, 14,
T.tvm_struct_get(stack_array, 2, 0, "handle"))
+ T.tvm_struct_set(stack_ffi_any, 2, 14, 0)
+ T.tvm_struct_set(stack_ffi_any, 2, 15,
T.tvm_struct_get(stack_array, 2, 0, "handle"))
T.tvm_struct_set(stack_ffi_any, 3, 13, 0)
- T.tvm_struct_set(stack_ffi_any, 3, 14, T.int64(0))
+ T.tvm_struct_set(stack_ffi_any, 3, 14, 0)
+ T.tvm_struct_set(stack_ffi_any, 3, 15, T.int64(0))
T.call_packed_lowered("tvm.test_matmul", stack_ffi_any, 0, 3)
After = tvm.tir.transform.LowerTVMBuiltin()(Before)
diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py
b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
index 49bfa75b72..dd7bd3bf54 100644
--- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py
+++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
@@ -266,7 +266,8 @@ def test_zero_arg_function():
assert num_args == 0, "func_without_arg: num_args should be 0"
with T.attr(0, "compute_scope", "func_without_arg_compute_"):
T.tvm_struct_set(result, 0, 13, 1)
- T.tvm_struct_set(result, 0, 14, T.Cast("int64", T.int64(42)))
+ T.tvm_struct_set(result, 0, 14, 0)
+ T.tvm_struct_set(result, 0, 15, T.Cast("int64", T.int64(42)))
return 0
return 0
@@ -320,15 +321,17 @@ def test_int_parameter():
assert not T.isnullptr(args), "main: args pointer is NULL"
arg_type_index: T.int32 = T.tvm_struct_get(args, 0, 13, "int32")
assert arg_type_index == 1 or arg_type_index == 2, "main: Expect
arg[0] to be int"
- arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 14,
"int64"))
+ arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 15,
"int64"))
with T.attr(0, "compute_scope", "main_compute_"):
if arg > 0:
T.tvm_struct_set(result, 0, 13, 1)
- T.tvm_struct_set(result, 0, 14, T.Cast("int64", 10))
+ T.tvm_struct_set(result, 0, 14, 0)
+ T.tvm_struct_set(result, 0, 15, T.Cast("int64", 10))
return 0
else:
T.tvm_struct_set(result, 0, 13, 1)
- T.tvm_struct_set(result, 0, 14, T.Cast("int64", 20))
+ T.tvm_struct_set(result, 0, 14, 0)
+ T.tvm_struct_set(result, 0, 15, T.Cast("int64", 20))
return 0
return 0
@@ -375,15 +378,17 @@ def test_bool_parameter():
assert not T.isnullptr(args), "main: args pointer is NULL"
arg_type_index: T.int32 = T.tvm_struct_get(args, 0, 13, "int32")
assert arg_type_index == 2 or arg_type_index == 1, "main: Expect
arg[0] to be boolean"
- arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 14,
"int64"))
+ arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 15,
"int64"))
with T.attr(0, "compute_scope", "main_compute_"):
if arg:
T.tvm_struct_set(result, 0, 13, 1)
- T.tvm_struct_set(result, 0, 14, T.Cast("int64", 10))
+ T.tvm_struct_set(result, 0, 14, 0)
+ T.tvm_struct_set(result, 0, 15, T.Cast("int64", 10))
return 0
else:
T.tvm_struct_set(result, 0, 13, 1)
- T.tvm_struct_set(result, 0, 14, T.Cast("int64", 20))
+ T.tvm_struct_set(result, 0, 14, 0)
+ T.tvm_struct_set(result, 0, 15, T.Cast("int64", 20))
return 0
return 0
diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts
index feee56b81f..41d848a228 100644
--- a/web/src/ctypes.ts
+++ b/web/src/ctypes.ts
@@ -72,6 +72,10 @@ export const enum TypeIndex {
kTVMFFIByteArrayPtr = 9,
/*! \brief R-value reference to ObjectRef */
kTVMFFIObjectRValueRef = 10,
+ /*! \brief Small string on stack */
+ kTVMFFISmallStr = 11,
+ /*! \brief Small bytes on stack */
+ kTVMFFISmallBytes = 12,
/*! \brief Start of statically defined objects. */
kTVMFFIStaticObjectBegin = 64,
/*!
diff --git a/web/src/memory.ts b/web/src/memory.ts
index 850f3bd371..94ecb4e15a 100644
--- a/web/src/memory.ts
+++ b/web/src/memory.ts
@@ -186,11 +186,44 @@ export class Memory {
const typeKeyPtr = typeInfoPtr + 2 * SizeOf.I32;
return this.loadByteArrayAsString(typeKeyPtr);
}
+ /**
+ * Load a small string stored inline in a TVMFFIAny (the kTVMFFISmallStr case).
+ * @param ffiAnyPtr Pointer to the start of the TVMFFIAny (where type_index lives), not its value slot.
+ * @returns The string, built byte-by-byte via String.fromCharCode (no UTF-8 decoding).
+ */
+ loadSmallStr(ffiAnyPtr: Pointer): string {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const sizePtr = ffiAnyPtr + SizeOf.I32;  // length sits in the I32 slot after type_index
+ const length = this.loadU32(sizePtr);
+ const dataPtr = ffiAnyPtr + SizeOf.I32 + SizeOf.I32;  // inline bytes follow the two I32 slots
+ const ret = [];
+ for (let i = 0; i < length; i++) {
+ ret.push(String.fromCharCode(this.viewU8[dataPtr + i]));
+ }
+ return ret.join("");
+ }
+ /**
+ * Load small bytes stored inline in a TVMFFIAny (the kTVMFFISmallBytes case); returns a new Uint8Array copy.
+ * @param ffiAnyPtr Pointer to the start of the TVMFFIAny; length is read at +SizeOf.I32, data starts at +2 * SizeOf.I32.
+ */
+ loadSmallBytes(ffiAnyPtr: Pointer): Uint8Array {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const sizePtr = ffiAnyPtr + SizeOf.I32;  // length sits in the I32 slot after type_index
+ const length = this.loadU32(sizePtr);
+ const dataPtr = ffiAnyPtr + SizeOf.I32 + SizeOf.I32;  // inline bytes follow the two I32 slots
+ const result = new Uint8Array(length);
+ result.set(this.viewU8.slice(dataPtr, dataPtr + length));
+ return result;
+ }
/**
* Load bytearray as string from ptr.
* @param byteArrayPtr The head address of the bytearray.
*/
- loadByteArrayAsString(byteArrayPtr: Pointer): string {
+ loadByteArrayAsString(byteArrayPtr: Pointer): string {
if (this.buffer != this.memory.buffer) {
this.updateViews();
}
@@ -207,16 +240,16 @@ export class Memory {
* Load bytearray as bytes from ptr.
* @param byteArrayPtr The head address of the bytearray.
*/
- loadByteArrayAsBytes(byteArrayPtr: Pointer): Uint8Array {
- if (this.buffer != this.memory.buffer) {
- this.updateViews();
+ loadByteArrayAsBytes(byteArrayPtr: Pointer): Uint8Array {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const ptr = this.loadPointer(byteArrayPtr);
+ const length = this.loadUSize(byteArrayPtr + this.sizeofPtr());
+ const result = new Uint8Array(length);
+ result.set(this.viewU8.slice(ptr, ptr + length));
+ return result;
}
- const ptr = this.loadPointer(byteArrayPtr);
- const length = this.loadUSize(byteArrayPtr + this.sizeofPtr());
- const result = new Uint8Array(length);
- result.set(this.viewU8.slice(ptr, ptr + length));
- return result;
-}
// private functions
/**
* Update memory view after the memory growth.
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 162052d41b..75f4de8555 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -2019,6 +2019,7 @@ export class Instance implements Disposable {
const tp = typeof val;
const argOffset = packedArgs + i * SizeOf.TVMFFIAny;
const argTypeIndexOffset = argOffset;
+ const argZeroPaddingOffset = argOffset + SizeOf.I32;
const argValueOffset = argOffset + SizeOf.I32 * 2;
// Convert string[] to a TVMArray of, hence treated as a TVMObject
@@ -2028,8 +2029,9 @@ export class Instance implements Disposable {
val = this.makeTVMArray(tvmStringArray);
}
- // clear off the extra padding valuesbefore ptr storage
- stack.storeI32(argTypeIndexOffset + SizeOf.I32, 0);
+ // clear off the extra zero padding before ptr storage
+ stack.storeI32(argZeroPaddingOffset, 0);
+ // clear off the extra zero padding after ptr storage
stack.storeI32(argValueOffset + SizeOf.I32, 0);
if (val instanceof NDArray) {
if (!val.isView) {
@@ -2177,6 +2179,8 @@ export class Instance implements Disposable {
const retOffset = stack.allocRawBytes(SizeOf.TVMFFIAny);
// pre-store the result to be null
stack.storeI32(retOffset, TypeIndex.kTVMFFINone);
+ // clear off the extra zero padding before ptr storage
+ stack.storeI32(retOffset + SizeOf.I32, 0);
stack.commitToWasmMemory();
this.lib.checkCall(
(this.exports.TVMFFIFunctionCall as ctypes.FTVMFFIFunctionCall)(
@@ -2253,6 +2257,9 @@ export class Instance implements Disposable {
);
return result;
}
+ case TypeIndex.kTVMFFISmallStr: {
+ return this.memory.loadSmallStr(resultAnyPtr);
+ }
case TypeIndex.kTVMFFIStr: {
const strObjPtr = this.memory.loadPointer(valuePtr);
const result = this.memory.loadByteArrayAsString(strObjPtr +
SizeOf.ObjectHeader);
@@ -2261,6 +2268,9 @@ export class Instance implements Disposable {
);
return result;
}
+ case TypeIndex.kTVMFFISmallBytes: {
+ return this.memory.loadSmallBytes(resultAnyPtr);
+ }
case TypeIndex.kTVMFFIBytes: {
const bytesObjPtr = this.memory.loadPointer(valuePtr);
const result = this.memory.loadByteArrayAsBytes(bytesObjPtr +
SizeOf.ObjectHeader);
diff --git a/web/tests/node/test_packed_func.js
b/web/tests/node/test_packed_func.js
index e2b6c7b7c9..3c6980cc1f 100644
--- a/web/tests/node/test_packed_func.js
+++ b/web/tests/node/test_packed_func.js
@@ -46,7 +46,9 @@ test("GetGlobal", () => {
// check function argument with different types.
assert(fecho(1123) == 1123);
assert(fecho("xyz") == "xyz");
-
+ // test long string as the abi can be different from small str
+ const long_str = "1234567890123456789abcdefghijklmnopqrstuvwxyz";
+ assert(fecho(long_str) == long_str);
let bytes = new Uint8Array([1, 2, 3]);
let rbytes = fecho(bytes);
assert(rbytes.length == bytes.length);
@@ -55,6 +57,16 @@ test("GetGlobal", () => {
assert(rbytes[i] == bytes[i]);
}
+ const long_bytes = new Uint8Array(1024);
+ for (let i = 0; i < long_bytes.length; ++i) {
+ long_bytes[i] = i;
+ }
+ let rlong_bytes = fecho(long_bytes);
+ assert(rlong_bytes.length == long_bytes.length);
+ for (let i = 0; i < long_bytes.length; ++i) {
+ assert(rlong_bytes[i] == long_bytes[i]);
+ }
+
assert(fecho(undefined) == undefined);
tvm.beginScope();