This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch small-str-v1
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/small-str-v1 by this push:
new 5fb9e5c368 [FFI] Bring up small bytes along with smallstr
5fb9e5c368 is described below
commit 5fb9e5c3681d2fdd70ffbedf8c80bd87f4974d80
Author: tqchen <[email protected]>
AuthorDate: Sat Aug 2 12:35:43 2025 -0400
[FFI] Bring up small bytes along with smallstr
---
ffi/include/tvm/ffi/any.h | 19 +++-
ffi/include/tvm/ffi/c_api.h | 4 +-
ffi/include/tvm/ffi/object.h | 1 +
ffi/include/tvm/ffi/string.h | 149 +++++++++++++++++----------
ffi/src/ffi/extra/structural_equal.cc | 15 ++-
ffi/src/ffi/object.cc | 1 +
ffi/tests/cpp/test_any.cc | 7 ++
ffi/tests/cpp/test_string.cc | 13 ++-
jvm/native/src/main/native/jni_helper_func.h | 6 +-
python/tvm/ffi/cython/base.pxi | 3 +-
python/tvm/ffi/cython/dtype.pxi | 2 +-
python/tvm/ffi/cython/function.pxi | 11 +-
src/node/repr_printer.cc | 1 +
src/node/serialization.cc | 12 ++-
src/runtime/rpc/rpc_module.cc | 5 +-
src/script/ir_builder/tir/ir.cc | 3 +-
src/tir/ir/stmt.cc | 3 +-
src/tir/schedule/concrete_schedule.cc | 3 +-
src/tir/schedule/instruction.cc | 3 +-
src/tir/schedule/trace.cc | 9 +-
web/src/ctypes.ts | 2 +
21 files changed, 187 insertions(+), 85 deletions(-)
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 51d68303c2..55eff8802a 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -554,6 +554,10 @@ struct AnyHash {
// so heap allocated string and on stack string will have the same hash
return details::StableHashCombine(TypeIndex::kTVMFFIStr,
details::StableHashSmallStrBytes(&src.data_));
+ } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) {
+ // use the same type key as heap-allocated bytes, so small and heap bytes hash equally
+ return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
+
details::StableHashSmallStrBytes(&src.data_));
} else if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* src_str =
@@ -598,7 +602,8 @@ struct AnyEqual {
return false;
} else {
- // type_index mismatch, if index is not string, return false
+ // type_index mismatch, if index is not string or bytes, return false
- if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index !=
kTVMFFISmallStr) {
+ if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index !=
kTVMFFISmallStr &&
+ lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index !=
kTVMFFIBytes) {
return false;
}
// small string and normal string comparison
@@ -614,6 +619,18 @@ struct AnyEqual {
return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data,
lhs.data_.small_str_len,
rhs_str->size);
}
+ if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index ==
kTVMFFISmallBytes) {
+ const details::BytesObjBase* lhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes,
lhs_bytes->size,
+ rhs.data_.small_str_len);
+ }
+ if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index ==
kTVMFFIBytes) {
+ const details::BytesObjBase* rhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data,
lhs.data_.small_str_len,
+ rhs_bytes->size);
+ }
return false;
}
}
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index bb25dfed0d..11080a21f0 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -105,6 +105,8 @@ typedef enum {
kTVMFFIObjectRValueRef = 10,
/*! \brief Small string on stack */
kTVMFFISmallStr = 11,
+ /*! \brief Small bytes on stack */
+ kTVMFFISmallBytes = 12,
/*! \brief Start of statically defined objects. */
kTVMFFIStaticObjectBegin = 64,
/*!
@@ -917,7 +919,7 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle
obj) {
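- * \param obj The object handle.
- * \return The content of the small string in bytearray format.
+ * \param value The any value holding a small string or small bytes.
+ * \return The content of the small string or small bytes in bytearray format.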
*/
-inline TVMFFIByteArray TVMFFISmallStrGetContentByteArray(const TVMFFIAny*
value) {
+inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny*
value) {
return TVMFFIByteArray{value->v_bytes,
static_cast<size_t>(value->small_str_len)};
}
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 74977e0216..4b7b56209a 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -61,6 +61,7 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIArray = "ffi.Array";
static constexpr const char* kTVMFFIMap = "ffi.Map";
static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr";
+ static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes";
};
/*!
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index a679e68d75..26672b12cb 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -272,29 +272,40 @@ class BytesBaseCell {
/*!
* \brief Managed reference of byte array.
*/
-class Bytes : public ObjectRef {
+class Bytes {
public:
+ /*! \brief default constructor */
+ Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); }
/*!
* \brief constructor from size
*
* \param other a char array.
*/
- Bytes(const char* data, size_t size) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {}
+ Bytes(const char* data, size_t size) { this->InitData(data, size); }
/*!
* \brief constructor from TVMFFIByteArray
*
* \param other a char array.
*/
- Bytes(TVMFFIByteArray bytes) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data,
bytes.size)) {}
+ Bytes(TVMFFIByteArray bytes) { // NOLINT(*)
+ this->InitData(bytes.data, bytes.size);
+ }
/*!
* \brief constructor from std::string
*
* \param other a char array.
*/
- Bytes(std::string other) // NOLINT(*)
- :
ObjectRef(make_object<details::BytesObjStdImpl<details::BytesObj>>(std::move(other)))
{}
+ Bytes(const std::string& other) { // NOLINT(*)
+ this->InitData(other.data(), other.size());
+ }
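+ /*!
+ * \brief constructor from an rvalue std::string, taking ownership of it
+ *
+ * \param other the string to take over; stored as heap bytes (kTVMFFIBytes) regardless of size.
+ */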
+ Bytes(std::string&& other) { // NOLINT(*)
+ data_.InitFromStd<details::BytesObj>(std::move(other),
TypeIndex::kTVMFFIBytes);
+ }
/*!
* \brief Swap this String with another string
* \param other The other string
@@ -314,21 +325,19 @@ class Bytes : public ObjectRef {
*
* \return size_t string length
*/
- size_t size() const { return get()->size; }
+ size_t size() const { return data_.size(); }
/*!
* \brief Return the data pointer
*
* \return const char* data pointer
*/
- const char* data() const { return get()->data; }
+ const char* data() const { return data_.data(); }
/*!
* \brief Convert String to an std::string object
*
* \return std::string
*/
- operator std::string() const { return std::string{get()->data, size()}; }
-
- TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef,
details::BytesObj);
+ operator std::string() const { return std::string{data(), size()}; }
/*!
* \brief Compare two char sequence
@@ -371,40 +380,32 @@ class Bytes : public ObjectRef {
private:
friend class String;
+ template <typename, typename>
+ friend struct TypeTraits;
+ template <typename, typename>
+ friend struct Optional;
template <typename>
friend struct std::hash;
-
- static uint64_t AnyHash(const Bytes& bytes) {
- return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
- details::StableHashBytes(bytes.data(),
bytes.size()));
+ // internal backing cell
+ details::BytesBaseCell data_;
+ // create a new Bytes from an existing backing cell, must keep private
+ explicit Bytes(details::BytesBaseCell data) : data_(data) {}
+ char* InitSpaceForSize(size_t size) {
+ return data_.InitSpaceForSize<details::BytesObj>(size,
TypeIndex::kTVMFFISmallBytes,
+ TypeIndex::kTVMFFIBytes);
}
+ void InitData(const char* data, size_t size) {
+ char* dest_data = InitSpaceForSize(size);
+ std::memcpy(dest_data, data, size);
+ // mainly to be compat with string
+ dest_data[size] = '\0';
+ }
+
+ static uint64_t AnyHash(const Bytes& bytes) { return bytes.data_.AnyHash(); }
};
/*!
- * \brief Reference to string objects.
- *
- * \code
- *
- * // Example to create runtime String reference object from std::string
- * std::string s = "hello world";
- *
- * // You can create the reference from existing std::string
- * String ref{std::move(s)};
- *
- * // You can rebind the reference to another string.
- * ref = std::string{"hello world2"};
- *
- * // You can use the reference as hash map key
- * std::unordered_map<String, int32_t> m;
- * m[ref] = 1;
- *
- * // You can compare the reference object with other string objects
- * assert(ref == "hello world", true);
- *
- * // You can convert the reference to std::string again
- * string s2 = (string)ref;
- *
- * \endcode
+ * \brief String container class.
*/
class String {
public:
@@ -417,10 +418,10 @@ class String {
*/
String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); }
// constructors from Any
- String(const String& other) = default;
- String(String&& other) = default;
- String& operator=(const String& other) = default;
- String& operator=(String&& other) = default;
+ String(const String& other) = default; // NOLINT(*)
+ String(String&& other) = default; // NOLINT(*)
+ String& operator=(const String& other) = default; // NOLINT(*)
+ String& operator=(String&& other) = default; // NOLINT(*)
/*!
* \brief Swap this String with another string
@@ -602,7 +603,8 @@ class String {
operator std::string() const { return std::string{data(), size()}; }
private:
- friend struct TypeTraits<String>;
+ template <typename, typename>
+ friend struct TypeTraits;
template <typename, typename>
friend class Optional;
template <typename>
@@ -620,7 +622,6 @@ class String {
return data_.InitSpaceForSize<details::StringObj>(size,
TypeIndex::kTVMFFISmallStr,
TypeIndex::kTVMFFIStr);
}
- // create a new TVMFFIAny from the data and size
void InitData(const char* data, size_t size) {
char* dest_data = InitSpaceForSize(size);
std::memcpy(dest_data, data, size);
@@ -675,6 +676,50 @@ TVM_FFI_INLINE std::string_view
ToStringView(TVMFFIByteArray str) {
return std::string_view(str.data, str.size);
}
+template <>
+inline constexpr bool use_default_type_traits_v<Bytes> = false;
+
+// Bytes may be small bytes (inline) or heap bytes; also converts from TVMFFIByteArray*
+template <>
+struct TypeTraits<Bytes> : public TypeTraitsBase {
+  // bytes can be union type of small bytes and object, so keep the field as any
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Bytes& src, TVMFFIAny* result) {
+    *result = src.data_.CopyToTVMFFIAny();
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny* result) {
+    src.data_.MoveToAny(result);
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFISmallBytes ||
+           src->type_index == TypeIndex::kTVMFFIBytes;
+  }
+
+  TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    return Bytes(details::BytesBaseCell::CopyFromAnyView(src));
+  }
+
+  TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    return Bytes(details::BytesBaseCell::MoveFromAny(src));
+  }
+
+  TVM_FFI_INLINE static std::optional<Bytes> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) {
+      return Bytes(*static_cast<TVMFFIByteArray*>(src->v_ptr));
+    }
+    if (CheckAnyStrict(src)) {
+      return Bytes(details::BytesBaseCell::CopyFromAnyView(src));
+    }
+    return std::nullopt;
+  }
+
+  // report "bytes" (not "str") so cast-failure messages name the right type
+  TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; }
+};
+
template <>
inline constexpr bool use_default_type_traits_v<String> = false;
@@ -777,7 +822,7 @@ struct TypeTraits<TVMFFIByteArray*> : public TypeTraitsBase
{
}
TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray* src, TVMFFIAny*
result) {
- ObjectRefTypeTraitsBase<Bytes>::MoveToAny(Bytes(*src), result);
+ TypeTraits<Bytes>::MoveToAny(Bytes(*src), result);
}
TVM_FFI_INLINE static std::optional<TVMFFIByteArray*>
TryCastFromAnyView(const TVMFFIAny* src) {
@@ -790,16 +835,6 @@ struct TypeTraits<TVMFFIByteArray*> : public
TypeTraitsBase {
TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIByteArrayPtr; }
};
-template <>
-inline constexpr bool use_default_type_traits_v<Bytes> = false;
-
-// specialize to enable implicit conversion from TVMFFIByteArray*
-template <>
-struct TypeTraits<Bytes> : public ObjectRefWithFallbackTraitsBase<Bytes,
TVMFFIByteArray*> {
- static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBytes;
- TVM_FFI_INLINE static Bytes ConvertFallbackValue(TVMFFIByteArray* src) {
return Bytes(*src); }
-};
-
template <>
inline constexpr bool use_default_type_traits_v<std::string> = false;
diff --git a/ffi/src/ffi/extra/structural_equal.cc
b/ffi/src/ffi/extra/structural_equal.cc
index 90cc50ac5d..97ebbf4072 100644
--- a/ffi/src/ffi/extra/structural_equal.cc
+++ b/ffi/src/ffi/extra/structural_equal.cc
@@ -48,7 +48,8 @@ class StructEqualHandler {
const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs);
if (lhs_data->type_index != rhs_data->type_index) {
- // type_index mismatch, if index is not string, return false
+ // type_index mismatch, if index is not string or bytes, return false
- if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index !=
kTVMFFISmallStr) {
+ if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index !=
kTVMFFISmallStr &&
+ lhs_data->type_index != kTVMFFISmallBytes && lhs_data->type_index !=
kTVMFFIBytes) {
return false;
}
// small string and normal string comparison
@@ -64,6 +65,18 @@ class StructEqualHandler {
return Bytes::memequal(lhs_data->v_bytes, rhs_str->data,
lhs_data->small_str_len,
rhs_str->size);
}
+ if (lhs_data->type_index == kTVMFFIBytes && rhs_data->type_index ==
kTVMFFISmallBytes) {
+ const details::BytesObjBase* lhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ return Bytes::memequal(lhs_bytes->data, rhs_data->v_bytes,
lhs_bytes->size,
+ rhs_data->small_str_len);
+ }
+ if (lhs_data->type_index == kTVMFFISmallBytes && rhs_data->type_index ==
kTVMFFIBytes) {
+ const details::BytesObjBase* rhs_bytes =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
+ return Bytes::memequal(lhs_data->v_bytes, rhs_bytes->data,
lhs_data->small_str_len,
+ rhs_bytes->size);
+ }
return false;
}
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 20ad356f60..9948ceda6b 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -338,6 +338,7 @@ class TypeTable {
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef,
TypeIndex::kTVMFFIObjectRValueRef);
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallStr,
TypeIndex::kTVMFFISmallStr);
+ ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallBytes,
TypeIndex::kTVMFFISmallBytes);
// no need to reserve for object types as they will be registered
}
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index 1f393c42ab..d1f56e1a93 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -403,6 +403,13 @@ TEST(Any, AnyEqualHash) {
EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr);
EXPECT_TRUE(AnyEqual()(a, b));
EXPECT_EQ(AnyHash()(a), AnyHash()(b));
+
+ Any c = Bytes("a1", 2);
+ Any d = Bytes(std::string("a1"));
+ EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFISmallBytes);
+ EXPECT_EQ(d.type_index(), TypeIndex::kTVMFFIBytes);
+ EXPECT_TRUE(AnyEqual()(c, d));
+ EXPECT_EQ(AnyHash()(c), AnyHash()(d));
}
} // namespace
diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc
index 54105d29ea..66d409b453 100644
--- a/ffi/tests/cpp/test_string.cc
+++ b/ffi/tests/cpp/test_string.cc
@@ -318,6 +318,10 @@ TEST(String, Any) {
}
TEST(String, Bytes) {
+ Bytes b0;
+ EXPECT_EQ(b0.size(), 0);
+ EXPECT_EQ(b0.operator std::string(), "");
+
// explicitly test zero element
std::string s = {'\0', 'a', 'b', 'c'};
Bytes b = s;
@@ -339,10 +343,17 @@ TEST(String, BytesAny) {
EXPECT_EQ(view.try_cast<Bytes>().value().operator std::string(), s);
Any b = view;
- EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIBytes);
+ EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallBytes);
EXPECT_EQ(b.try_cast<Bytes>().value().operator std::string(), s);
EXPECT_EQ(b.cast<std::string>(), s);
+
+ std::string s2 = "hello long long long string";
+ s2[0] = '\0';
+ Any b2 = Bytes(s2);
+ EXPECT_EQ(b2.type_index(), TypeIndex::kTVMFFIBytes);
+ EXPECT_EQ(b2.try_cast<std::string>().value(), s2);
+ EXPECT_EQ(b2.cast<std::string>(), s2);
}
TEST(String, StdString) {
diff --git a/jvm/native/src/main/native/jni_helper_func.h
b/jvm/native/src/main/native/jni_helper_func.h
index ab043028d3..5db3e279cf 100644
--- a/jvm/native/src/main/native/jni_helper_func.h
+++ b/jvm/native/src/main/native/jni_helper_func.h
@@ -224,12 +224,16 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) {
return newNDArray(env, reinterpret_cast<jlong>(value.v_obj), false);
}
case TypeIndex::kTVMFFISmallStr: {
- TVMFFIByteArray arr = TVMFFISmallStrGetContentByteArray(&value);
+ TVMFFIByteArray arr = TVMFFISmallBytesGetContentByteArray(&value);
return newTVMValueString(env, &arr);
}
case TypeIndex::kTVMFFIStr: {
return newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj));
}
+ case TypeIndex::kTVMFFISmallBytes: {
+ TVMFFIByteArray arr = TVMFFISmallBytesGetContentByteArray(&value);
+ return newTVMValueBytes(env, &arr);
+ }
case TypeIndex::kTVMFFIBytes: {
jobject ret = newTVMValueBytes(env,
TVMFFIBytesGetByteArrayPtr(value.v_obj));
TVMFFIObjectFree(value.v_obj);
diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index 70db76207d..00b76e68f7 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -41,6 +41,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIByteArrayPtr = 9
kTVMFFIObjectRValueRef = 10
kTVMFFISmallStr = 11
+ kTVMFFISmallBytes = 12
kTVMFFIStaticObjectBegin = 64
kTVMFFIObject = 64
kTVMFFIStr = 65
@@ -197,7 +198,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src,
DLManagedTensorVersioned** out) nogil
const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil
- TVMFFIByteArray TVMFFISmallStrGetContentByteArray(const TVMFFIAny* value)
nogil
+ TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny*
value) nogil
TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil
TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil
diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi
index b98c3e1107..279b17f8c8 100644
--- a/python/tvm/ffi/cython/dtype.pxi
+++ b/python/tvm/ffi/cython/dtype.pxi
@@ -98,7 +98,7 @@ cdef class DataType:
CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &temp_any))
if temp_any.type_index == kTVMFFISmallStr:
- bytes = TVMFFISmallStrGetContentByteArray(&temp_any)
+ bytes = TVMFFISmallBytesGetContentByteArray(&temp_any)
res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
return res
diff --git a/python/tvm/ffi/cython/function.pxi
b/python/tvm/ffi/cython/function.pxi
index e8e6987dd9..cbff3fecf1 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -26,10 +26,17 @@ except ImportError:
cdef inline object make_ret_small_str(TVMFFIAny result):
"""convert small string to return value."""
cdef TVMFFIByteArray bytes
- bytes = TVMFFISmallStrGetContentByteArray(&result)
+ bytes = TVMFFISmallBytesGetContentByteArray(&result)
return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
+cdef inline object make_ret_small_bytes(TVMFFIAny result):
+    """Build a Python bytes object from small bytes stored inline in result."""
+    cdef TVMFFIByteArray content = TVMFFISmallBytesGetContentByteArray(&result)
+    # returned as raw bytes, without the py_str conversion used by make_ret_small_str
+    return PyBytes_FromStringAndSize(content.data, content.size)
+
+
cdef inline object make_ret(TVMFFIAny result):
"""convert result to return value."""
# TODO: Implement
@@ -50,6 +57,8 @@ cdef inline object make_ret(TVMFFIAny result):
return result.v_float64
elif type_index == kTVMFFISmallStr:
return make_ret_small_str(result)
+ elif type_index == kTVMFFISmallBytes:
+ return make_ret_small_bytes(result)
elif type_index == kTVMFFIOpaquePtr:
return ctypes_handle(result.v_ptr)
elif type_index == kTVMFFIDataType:
diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc
index ec62a3c95a..d3b62b5e87 100644
--- a/src/node/repr_printer.cc
+++ b/src/node/repr_printer.cc
@@ -84,6 +84,7 @@ void ReprPrinter::Print(const ffi::Any& node) {
stream << '"' << support::StrEscape(str.data(), str.size()) << '"';
break;
}
+ case ffi::TypeIndex::kTVMFFISmallBytes:
case ffi::TypeIndex::kTVMFFIBytes: {
ffi::Bytes bytes = node.cast<ffi::Bytes>();
stream << "b\"" << support::StrEscape(bytes.data(), bytes.size()) << '"';
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 3d0175bcfa..1ded51e59c 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -109,7 +109,8 @@ class NodeIndexer {
}
} else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
node.type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
- node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
+ node.type_index() == ffi::TypeIndex::kTVMFFIBytes ||
+ node.type_index() == ffi::TypeIndex::kTVMFFISmallBytes) {
// skip content index for string and bytes
} else if (auto opt_object = node.as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
@@ -398,7 +399,8 @@ class FieldDependencyFinder {
}
if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
node.type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
- node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
+ node.type_index() == ffi::TypeIndex::kTVMFFIBytes ||
+ node.type_index() == ffi::TypeIndex::kTVMFFISmallBytes) {
// skip indexing content of string and bytes
return;
}
@@ -561,7 +563,8 @@ class JSONAttrSetter {
} else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) {
return Any(String(jnode->repr_bytes));
- } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
+ } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) {
return Any(Bytes(jnode->repr_bytes));
} else {
return ObjectRef(reflection->CreateInitObject(jnode->type_key,
jnode->repr_bytes));
@@ -594,7 +597,8 @@ class JSONAttrSetter {
*node = result;
} else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr ||
- jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes ||
+ jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallBytes) {
// skip set attrs for string and bytes
} else if (auto opt_object = node->as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 42450e3322..a693c671f3 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -96,7 +96,7 @@ class RPCWrappedFunc : public Object {
} else if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr) {
// we cannot cast here, since we need to make sure the space is alive
const TVMFFIAny* any_view_ptr = reinterpret_cast<const
TVMFFIAny*>(&args.data()[i]);
- TVMFFIByteArray bytes =
TVMFFISmallStrGetContentByteArray(any_view_ptr);
+ TVMFFIByteArray bytes =
TVMFFISmallBytesGetContentByteArray(any_view_ptr);
packed_args[i] = bytes.data;
continue;
}
@@ -322,7 +322,8 @@ void
RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv)
nd_handle);
} else if (type_index == ffi::TypeIndex::kTVMFFIBytes ||
type_index == ffi::TypeIndex::kTVMFFIStr ||
- type_index == ffi::TypeIndex::kTVMFFISmallStr) {
+ type_index == ffi::TypeIndex::kTVMFFISmallStr ||
+ type_index == ffi::TypeIndex::kTVMFFISmallBytes) {
ICHECK_EQ(args.size(), 2);
*rv = args[1];
} else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 78bccb829c..33a687f54b 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -520,8 +520,7 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray
data, DataType dtype,
AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) {
// convert POD value to PrimExpr
- if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin &&
- node.type_index() != ffi::TypeIndex::kTVMFFISmallStr) {
+ if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
node = node.cast<PrimExpr>();
}
ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>();
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 5a2b95844b..56fab07605 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -100,8 +100,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Any node, String attr_key, PrimExpr value, Stmt
body, Span span) {
// when node is a POD data type like int or bool,
first convert to
// primexpr.
- if (node.type_index() <
ffi::TypeIndex::kTVMFFIStaticObjectBegin &&
- node.type_index() !=
ffi::TypeIndex::kTVMFFISmallStr) {
+ if (node.type_index() <
ffi::TypeIndex::kTVMFFISmallStr) {
return AttrStmt(node.cast<PrimExpr>(), attr_key,
value, body, span);
}
return AttrStmt(node, attr_key, value, body, span);
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index db175c77f2..6f7e682d6c 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -916,8 +916,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const
ffi::Any& ann_val) {
if (auto opt_str = ann_val.try_cast<ffi::String>()) {
return *std::move(opt_str);
}
- if (ann_val.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin &&
- ann_val.type_index() != ffi::TypeIndex::kTVMFFISmallStr) {
+ if (ann_val.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
return ann_val;
}
// prefer to return int/float literals for annotations
diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc
index fdc0dd41c4..2f327354c9 100644
--- a/src/tir/schedule/instruction.cc
+++ b/src/tir/schedule/instruction.cc
@@ -74,8 +74,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
inputs.push_back(String('"' + (*opt_str).operator std::string() +
'"'));
} else if (obj.as<BlockRVNode>() || obj.as<LoopRVNode>()) {
inputs.push_back(String("_"));
- } else if (obj.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin
&&
- obj.type_index() != ffi::TypeIndex::kTVMFFISmallStr) {
+ } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
inputs.push_back(obj);
} else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) {
inputs.push_back(obj);
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index b1fb7881a6..61f24f980f 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -71,8 +71,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs,
};
for (const Any& input : inputs) {
- if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin &&
- input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) {
+ if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
// directly put back POD type
result.push_back(input);
} else if (auto expr = input.as<ffi::String>()) {
@@ -114,8 +113,7 @@ Array<Any> TranslateInputRVs(
// string => "content"
if (auto opt_str = input.as<ffi::String>()) {
results.push_back(String('"' + (*opt_str).operator std::string() + '"'));
- } else if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin &&
- input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) {
+ } else if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
// directly put back POD type and not string
results.push_back(input);
} else if (input.as<BlockRVNode>() || // RV: block
@@ -161,8 +159,7 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs,
Array<Any> results;
results.reserve(inputs.size());
for (const Any& input : inputs) {
- if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin &&
- input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) {
+ if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
// directly put back POD type
results.push_back(input);
continue;
diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts
index 50e82e445c..41d848a228 100644
--- a/web/src/ctypes.ts
+++ b/web/src/ctypes.ts
@@ -74,6 +74,8 @@ export const enum TypeIndex {
kTVMFFIObjectRValueRef = 10,
/*! \brief Small string on stack */
kTVMFFISmallStr = 11,
+ /*! \brief Small bytes on stack */
+ kTVMFFISmallBytes = 12,
/*! \brief Start of statically defined objects. */
kTVMFFIStaticObjectBegin = 64,
/*!