This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3c189f015c [FFI][REFACTOR] Hide StringObj/BytesObj into details
(#18184)
3c189f015c is described below
commit 3c189f015c2b5835db589c8f8686da17cab02bba
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Aug 1 21:42:55 2025 -0400
[FFI][REFACTOR] Hide StringObj/BytesObj into details (#18184)
This PR moves StringObj/BytesObj into the `details` namespace and brings
implementations to focus directly on String/Bytes.
This change prepares for future changes such as SmallStr support.
It also moves more ObjectRef usages in RPC over to Any.
---
ffi/include/tvm/ffi/any.h | 12 ++--
ffi/include/tvm/ffi/string.h | 26 ++++----
ffi/src/ffi/extra/structural_equal.cc | 8 +--
ffi/src/ffi/extra/structural_hash.cc | 8 +--
include/tvm/ir/transform.h | 46 +++++++++-----
include/tvm/runtime/profiling.h | 2 +-
include/tvm/target/target_kind.h | 22 +++++--
python/tvm/exec/disco_worker.py | 4 +-
python/tvm/ffi/cython/function.pxi | 2 +-
python/tvm/ffi/cython/string.pxi | 4 +-
python/tvm/meta_schedule/utils.py | 4 +-
python/tvm/relax/utils.py | 3 +-
python/tvm/tir/schedule/trace.py | 3 +-
src/contrib/msc/core/printer/msc_base_printer.cc | 4 +-
src/contrib/msc/core/printer/prototxt_printer.cc | 4 +-
src/ir/transform.cc | 14 ++---
src/meta_schedule/database/database_utils.cc | 6 +-
src/node/repr_printer.cc | 10 +++
src/node/serialization.cc | 10 ++-
src/node/structural_hash.cc | 12 ----
src/runtime/disco/protocol.h | 71 +++++++++++-----------
src/runtime/minrpc/rpc_reference.h | 10 +--
src/runtime/profiling.cc | 14 ++---
src/runtime/rpc/rpc_endpoint.cc | 19 +++---
src/runtime/rpc/rpc_local_session.cc | 11 ++--
src/runtime/rpc/rpc_module.cc | 5 +-
.../printer/doc_printer/python_doc_printer.cc | 4 +-
src/script/printer/ir/misc.cc | 2 +-
src/support/ffi_testing.cc | 2 +-
src/target/target.cc | 8 +--
src/tir/schedule/instruction.cc | 4 +-
src/tir/schedule/instruction_traits.h | 4 +-
src/tir/schedule/trace.cc | 22 +++----
tests/python/contrib/test_popen_pool.py | 1 -
tests/python/disco/test_session.py | 2 +-
tests/python/ffi/test_function.py | 2 -
tests/python/ffi/test_string.py | 11 ++--
37 files changed, 207 insertions(+), 189 deletions(-)
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index c570a61e4b..d94185c064 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -546,8 +546,8 @@ struct AnyHash {
uint64_t val_hash = [&]() -> uint64_t {
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
- const BytesObjBase* src_str =
- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
BytesObjBase*>(src);
+ const details::BytesObjBase* src_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(src);
return details::StableHashBytes(src_str->data, src_str->size);
} else {
return src.data_.v_uint64;
@@ -572,10 +572,10 @@ struct AnyEqual {
// specialy handle string hash
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
- const BytesObjBase* lhs_str =
- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
BytesObjBase*>(lhs);
- const BytesObjBase* rhs_str =
- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
BytesObjBase*>(rhs);
+ const details::BytesObjBase* lhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ const details::BytesObjBase* rhs_str =
+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size);
}
return false;
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index e77b27b268..481b704436 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -46,7 +46,7 @@
namespace tvm {
namespace ffi {
-
+namespace details {
/*! \brief Base class for bytes and string. */
class BytesObjBase : public Object, public TVMFFIByteArray {};
@@ -73,8 +73,6 @@ class StringObj : public BytesObjBase {
TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
};
-namespace details {
-
// String moved from std::string
// without having to trigger a copy
template <typename Base>
@@ -115,21 +113,21 @@ class Bytes : public ObjectRef {
* \param other a char array.
*/
Bytes(const char* data, size_t size) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<BytesObj>(data, size)) {}
+ : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {}
/*!
* \brief constructor from char [N]
*
* \param other a char array.
*/
Bytes(TVMFFIByteArray bytes) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<BytesObj>(bytes.data, bytes.size))
{}
+ : ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data,
bytes.size)) {}
/*!
* \brief constructor from char [N]
*
* \param other a char array.
*/
Bytes(std::string other) // NOLINT(*)
- :
ObjectRef(make_object<details::BytesObjStdImpl<BytesObj>>(std::move(other))) {}
+ :
ObjectRef(make_object<details::BytesObjStdImpl<details::BytesObj>>(std::move(other)))
{}
/*!
* \brief Swap this String with another string
* \param other The other string
@@ -163,7 +161,7 @@ class Bytes : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }
- TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, BytesObj);
+ TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef,
details::BytesObj);
/*!
* \brief Compare two char sequence
@@ -245,7 +243,7 @@ class String : public ObjectRef {
*/
template <size_t N>
String(const char other[N]) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<StringObj>(other, N)) {}
+ : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, N)) {}
/*!
* \brief constructor
@@ -258,7 +256,7 @@ class String : public ObjectRef {
* \param other a char array.
*/
String(const char* other) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<StringObj>(other,
std::strlen(other))) {}
+ : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other,
std::strlen(other))) {}
/*!
* \brief constructor from raw string
@@ -266,21 +264,21 @@ class String : public ObjectRef {
* \param other a char array.
*/
String(const char* other, size_t size) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<StringObj>(other, size)) {}
+ : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, size))
{}
/*!
* \brief Construct a new string object
* \param other The std::string object to be copied
*/
String(const std::string& other) // NOLINT(*)
- : ObjectRef(details::MakeInplaceBytes<StringObj>(other.data(),
other.size())) {}
+ : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data(),
other.size())) {}
/*!
* \brief Construct a new string object
* \param other The std::string object to be moved
*/
String(std::string&& other) // NOLINT(*)
- :
ObjectRef(make_object<details::BytesObjStdImpl<StringObj>>(std::move(other))) {}
+ :
ObjectRef(make_object<details::BytesObjStdImpl<details::StringObj>>(std::move(other)))
{}
/*!
* \brief constructor from TVMFFIByteArray
@@ -288,7 +286,7 @@ class String : public ObjectRef {
* \param other a TVMFFIByteArray.
*/
explicit String(TVMFFIByteArray other)
- : ObjectRef(details::MakeInplaceBytes<StringObj>(other.data,
other.size)) {}
+ : ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data,
other.size)) {}
/*!
* \brief Swap this String with another string
@@ -423,7 +421,7 @@ class String : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }
- TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
+ TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef,
details::StringObj);
private:
/*!
diff --git a/ffi/src/ffi/extra/structural_equal.cc
b/ffi/src/ffi/extra/structural_equal.cc
index a73c07713f..3d70e525d9 100644
--- a/ffi/src/ffi/extra/structural_equal.cc
+++ b/ffi/src/ffi/extra/structural_equal.cc
@@ -62,10 +62,10 @@ class StructEqualHandler {
case TypeIndex::kTVMFFIStr:
case TypeIndex::kTVMFFIBytes: {
// compare bytes
- const BytesObjBase* lhs_str =
- AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
- const BytesObjBase* rhs_str =
- AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
+ const details::BytesObjBase* lhs_str =
+ AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(lhs);
+ const details::BytesObjBase* rhs_str =
+ AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size) == 0;
}
case TypeIndex::kTVMFFIArray: {
diff --git a/ffi/src/ffi/extra/structural_hash.cc
b/ffi/src/ffi/extra/structural_hash.cc
index e47fbbacc8..1d90c5a62d 100644
--- a/ffi/src/ffi/extra/structural_hash.cc
+++ b/ffi/src/ffi/extra/structural_hash.cc
@@ -64,8 +64,8 @@ class StructuralHashHandler {
case TypeIndex::kTVMFFIStr:
case TypeIndex::kTVMFFIBytes: {
// return same hash as AnyHash
- const BytesObjBase* src_str =
- AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
+ const details::BytesObjBase* src_str =
+ AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(src);
return details::StableHashCombine(src_data->type_index,
details::StableHashBytes(src_str->data, src_str->size));
}
@@ -196,8 +196,8 @@ class StructuralHashHandler {
} else {
if (src_data->type_index == TypeIndex::kTVMFFIStr ||
src_data->type_index == TypeIndex::kTVMFFIBytes) {
- const BytesObjBase* src_str =
- AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
+ const details::BytesObjBase* src_str =
+ AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(src);
// return same hash as AnyHash
return details::StableHashCombine(src_data->type_index,
details::StableHashBytes(src_str->data, src_str->size));
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 4f9004fba5..425cec1425 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -242,26 +242,40 @@ class PassContext : public ObjectRef {
template <typename ValueType>
static int32_t RegisterConfigOption(const char* key) {
// NOTE: we could further update the function later.
- int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
- auto* reflection = ReflectionVTable::Global();
- auto type_key = ffi::TypeIndexToTypeKey(tindex);
-
- auto legalization = [=](ffi::Any value) -> ffi::Any {
- if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
- return reflection->CreateObject(type_key, opt_map.value());
- } else {
+ if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
+ int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
+ auto* reflection = ReflectionVTable::Global();
+ auto type_key = ffi::TypeIndexToTypeKey(tindex);
+ auto legalization = [=](ffi::Any value) -> ffi::Any {
+ if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
+ return reflection->CreateObject(type_key, opt_map.value());
+ } else {
+ auto opt_val = value.try_cast<ValueType>();
+ if (!opt_val.has_value()) {
+ TVM_FFI_THROW(AttributeError)
+ << "Expect config " << key << " to have type " << type_key <<
", but instead get "
+ <<
ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
+ }
+ return *opt_val;
+ }
+ };
+ RegisterConfigOption(key, type_key, legalization);
+ } else {
+ // non-object type, do not support implicit conversion from map
+ std::string type_str = ffi::TypeTraits<ValueType>::TypeStr();
+ auto legalization = [=](ffi::Any value) -> ffi::Any {
auto opt_val = value.try_cast<ValueType>();
if (!opt_val.has_value()) {
TVM_FFI_THROW(AttributeError)
- << "Expect config " << key << " to have type " << type_key << ",
but instead get "
+ << "Expect config " << key << " to have type " << type_str << ",
but instead get "
<<
ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
+ } else {
+ return *opt_val;
}
- return value;
- }
- };
-
- RegisterConfigOption(key, tindex, legalization);
- return tindex;
+ };
+ RegisterConfigOption(key, type_str, legalization);
+ }
+ return 0;
}
// accessor.
@@ -274,7 +288,7 @@ class PassContext : public ObjectRef {
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Register configuration key value type.
- TVM_DLL static void RegisterConfigOption(const char* key, uint32_t
value_type_index,
+ TVM_DLL static void RegisterConfigOption(const char* key, String
value_type_str,
std::function<ffi::Any(ffi::Any)>
legalization);
// Classes to get the Python `with` like syntax.
diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h
index 66c0b64f18..78f79475b8 100644
--- a/include/tvm/runtime/profiling.h
+++ b/include/tvm/runtime/profiling.h
@@ -315,7 +315,7 @@ class MetricCollectorNode : public Object {
/*! \brief Stop collecting metrics.
* \param obj The object created by the corresponding `Start` call.
* \returns A set of metric names and the associated values. Values must be
- * one of DurationNode, PercentNode, CountNode, or StringObj.
+ * one of DurationNode, PercentNode, CountNode, or String.
*/
virtual Map<String, ffi::Any> Stop(ffi::ObjectRef obj) = 0;
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 15b5f62cd5..d89148964b 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -287,13 +287,27 @@ struct ValueTypeInfoMaker<ValueType, std::false_type,
std::false_type> {
using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
ValueTypeInfo operator()() const {
- int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
ValueTypeInfo info;
- info.type_index = tindex;
- info.type_key = runtime::Object::TypeIndex2Key(tindex);
info.key = nullptr;
info.val = nullptr;
- return info;
+ if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
+ int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
+ info.type_index = tindex;
+ info.type_key = runtime::Object::TypeIndex2Key(tindex);
+ return info;
+ } else if constexpr (std::is_same_v<ValueType, String>) {
+ // special handle string since it can be backed by multiple types.
+ info.type_index = ffi::TypeIndex::kTVMFFIStr;
+ info.type_key = ffi::TypeTraits<ValueType>::TypeStr();
+ return info;
+ } else {
+ // TODO(tqchen) consider upgrade to leverage any system to support union
type
+ constexpr int32_t tindex =
ffi::TypeToFieldStaticTypeIndex<ValueType>::value;
+ static_assert(tindex != ffi::TypeIndex::kTVMFFIAny, "Do not support
union type for now");
+ info.type_index = tindex;
+ info.type_key = runtime::Object::TypeIndex2Key(tindex);
+ return info;
+ }
}
};
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index ecfeaa1ebb..6d1d4b7f33 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -48,8 +48,8 @@ def _str_func(x: str):
@register_func("tests.disco.str_obj", override=True)
-def _str_obj_func(x: String):
- assert isinstance(x, String)
+def _str_obj_func(x: str):
+ assert isinstance(x, str)
return String(x + "_suffix")
diff --git a/python/tvm/ffi/cython/function.pxi
b/python/tvm/ffi/cython/function.pxi
index 8591918db6..d86d004d10 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -87,7 +87,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list
temp_args) except
out[i].type_index = kTVMFFINDArray
out[i].v_ptr = (<NDArray>arg).chandle
temp_args.append(arg)
- elif isinstance(arg, PyNativeObject):
+ elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not
None:
arg = arg.__tvm_ffi_object__
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
out[i].v_ptr = (<Object>arg).chandle
diff --git a/python/tvm/ffi/cython/string.pxi b/python/tvm/ffi/cython/string.pxi
index 512aa7bace..4ab5c48ce0 100644
--- a/python/tvm/ffi/cython/string.pxi
+++ b/python/tvm/ffi/cython/string.pxi
@@ -40,7 +40,7 @@ class String(str, PyNativeObject):
"""
def __new__(cls, value):
val = str.__new__(cls, value)
- val.__init_tvm_ffi_object_by_constructor__(_STR_CONSTRUCTOR, value)
+ val.__tvm_ffi_object__ = None
return val
# pylint: disable=no-self-argument
@@ -65,7 +65,7 @@ class Bytes(bytes, PyNativeObject):
"""
def __new__(cls, value):
val = bytes.__new__(cls, value)
- val.__init_tvm_ffi_object_by_constructor__(_BYTES_CONSTRUCTOR, value)
+ val.__tvm_ffi_object__ = None
return val
# pylint: disable=no-self-argument
diff --git a/python/tvm/meta_schedule/utils.py
b/python/tvm/meta_schedule/utils.py
index 61b32e1e32..2f18f54a81 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -26,7 +26,7 @@ from tvm.ffi import get_global_func, register_func
from tvm.error import TVMError
from tvm.ir import Array, IRModule, Map
from tvm.rpc import RPCSession
-from tvm.runtime import PackedFunc, String
+from tvm.runtime import PackedFunc
from tvm.tir import FloatImm, IntImm
@@ -352,7 +352,7 @@ def _json_de_tvm(obj: Any) -> Any:
return obj
if isinstance(obj, (IntImm, FloatImm)):
return obj.value
- if isinstance(obj, (str, String)):
+ if isinstance(obj, (str,)):
return str(obj)
if isinstance(obj, Array):
return [_json_de_tvm(i) for i in obj]
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 9795631fbe..192235d595 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -28,7 +28,6 @@ from typing import Any, Callable, List, Dict, Optional
import tvm
from .. import tir
from ..tir import PrimExpr
-from ..runtime import String
from . import _ffi_api
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
@@ -114,7 +113,7 @@ def convert_to_expr(value: Any) -> Expr:
if isinstance(tvm_value, PrimExpr):
return PrimValue(value)
# Case 3
- if isinstance(tvm_value, (str, String)):
+ if isinstance(tvm_value, (str,)):
return StringImm(value)
# Case 4
if isinstance(value, (tuple, list)):
diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py
index 15bb201ae6..da3508a42e 100644
--- a/python/tvm/tir/schedule/trace.py
+++ b/python/tvm/tir/schedule/trace.py
@@ -22,7 +22,6 @@ from tvm.ffi import register_object as _register_object
from tvm.runtime import Object
from ...ir import Array, Map, save_json
-from ...runtime import String
from ..expr import FloatImm, IntImm
from ..function import IndexMap
from . import _ffi_api
@@ -45,7 +44,7 @@ def _json_from_tvm(obj):
return [_json_from_tvm(i) for i in obj]
elif isinstance(obj, Map):
return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()}
- elif isinstance(obj, String):
+ elif isinstance(obj, str):
return str(obj)
elif isinstance(obj, (IntImm, FloatImm)):
return obj
diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc
b/src/contrib/msc/core/printer/msc_base_printer.cc
index 31869f29bb..644692aa6b 100644
--- a/src/contrib/msc/core/printer/msc_base_printer.cc
+++ b/src/contrib/msc/core/printer/msc_base_printer.cc
@@ -113,8 +113,8 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) {
} else {
output_ << float_imm->value;
}
- } else if (const auto* string_obj = value.as<ffi::StringObj>()) {
- output_ << "\"" << tvm::support::StrEscape(string_obj->data,
string_obj->size) << "\"";
+ } else if (auto opt_str = value.as<ffi::String>()) {
+ output_ << "\"" << tvm::support::StrEscape((*opt_str).data(),
(*opt_str).size()) << "\"";
} else {
LOG(FATAL) << "TypeError: Unsupported literal value type: " <<
value.GetTypeKey();
}
diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc
b/src/contrib/msc/core/printer/prototxt_printer.cc
index 82d15dc718..d62e5ac2a8 100644
--- a/src/contrib/msc/core/printer/prototxt_printer.cc
+++ b/src/contrib/msc/core/printer/prototxt_printer.cc
@@ -31,8 +31,8 @@ namespace contrib {
namespace msc {
LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) {
- if (obj.as<ffi::StringObj>()) {
- return LiteralDoc::Str(Downcast<String>(obj), std::nullopt);
+ if (auto opt_str = obj.as<ffi::String>()) {
+ return LiteralDoc::Str(*opt_str, std::nullopt);
} else if (obj.as<IntImmNode>()) {
return LiteralDoc::Int(Downcast<IntImm>(obj)->value, std::nullopt);
} else if (obj.as<FloatImmNode>()) {
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index d069bd7fee..6fb809e25a 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -107,12 +107,11 @@ bool PassContext::PassEnabled(const PassInfo& info) const
{
class PassConfigManager {
public:
- void Register(std::string key, uint32_t value_type_index,
+ void Register(std::string key, String value_type_str,
std::function<ffi::Any(ffi::Any)> legalization) {
ICHECK_EQ(key2vtype_.count(key), 0U);
ValueTypeInfo info;
- info.type_index = value_type_index;
- info.type_key = runtime::Object::TypeIndex2Key(value_type_index);
+ info.type_str = value_type_str;
info.legalization = legalization;
key2vtype_[key] = info;
}
@@ -154,7 +153,7 @@ class PassConfigManager {
Map<String, Map<String, String>> configs;
for (const auto& kv : key2vtype_) {
Map<String, String> metadata;
- metadata.Set("type", kv.second.type_key);
+ metadata.Set("type", kv.second.type_str);
configs.Set(kv.first, metadata);
}
return configs;
@@ -167,17 +166,16 @@ class PassConfigManager {
private:
struct ValueTypeInfo {
- std::string type_key;
- uint32_t type_index;
+ std::string type_str;
std::function<ffi::Any(ffi::Any)> legalization;
};
std::unordered_map<std::string, ValueTypeInfo> key2vtype_;
};
-void PassContext::RegisterConfigOption(const char* key, uint32_t
value_type_index,
+void PassContext::RegisterConfigOption(const char* key, String value_type_str,
std::function<ffi::Any(ffi::Any)>
legalization) {
- PassConfigManager::Global()->Register(key, value_type_index, legalization);
+ PassConfigManager::Global()->Register(key, value_type_str, legalization);
}
Map<String, Map<String, String>> PassContext::ListConfigs() {
diff --git a/src/meta_schedule/database/database_utils.cc
b/src/meta_schedule/database/database_utils.cc
index 230e4d3509..fd24072aae 100644
--- a/src/meta_schedule/database/database_utils.cc
+++ b/src/meta_schedule/database/database_utils.cc
@@ -43,8 +43,8 @@ void JSONDumps(Any json_obj, std::ostringstream& os) {
} else if (auto opt_float_imm = json_obj.try_cast<FloatImm>()) {
FloatImm float_imm = *std::move(opt_float_imm);
os << std::setprecision(20) << float_imm->value;
- } else if (const auto* str = json_obj.as<ffi::StringObj>()) {
- os << '"' << support::StrEscape(str->data, str->size) << '"';
+ } else if (auto opt_str = json_obj.as<ffi::String>()) {
+ os << '"' << support::StrEscape((*opt_str).data(), (*opt_str).size()) <<
'"';
} else if (const auto* array = json_obj.as<ffi::ArrayObj>()) {
os << "[";
int n = array->size();
@@ -371,7 +371,7 @@ class JSONParser {
}
// Case 3
Any key = ParseObject(std::move(token));
- ICHECK(key.as<ffi::StringObj>()) << "ValueError: key must be a string,
but gets: " << key;
+ ICHECK(key.as<ffi::String>()) << "ValueError: key must be a string,
but gets: " << key;
token = tokenizer_.Next();
CHECK(token.type == TokenType::kColon)
<< "ValueError: Unexpected token before: " << tokenizer_.cur_;
diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc
index 34b08994d0..240b4f1758 100644
--- a/src/node/repr_printer.cc
+++ b/src/node/repr_printer.cc
@@ -78,6 +78,16 @@ void ReprPrinter::Print(const ffi::Any& node) {
Print(node.cast<ObjectRef>());
break;
}
+ case ffi::TypeIndex::kTVMFFIStr: {
+ ffi::String str = node.cast<ffi::String>();
+ stream << '"' << support::StrEscape(str.data(), str.size()) << '"';
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIBytes: {
+ ffi::Bytes bytes = node.cast<ffi::Bytes>();
+ stream << "b\"" << support::StrEscape(bytes.data(), bytes.size()) << '"';
+ break;
+ }
default: {
if (auto opt_obj = node.as<ObjectRef>()) {
Print(opt_obj.value());
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index c3060fc91f..65b9728317 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -95,9 +95,8 @@ class NodeIndexer {
}
} else if (auto opt_map = node.as<const ffi::MapObj*>()) {
const ffi::MapObj* n = opt_map.value();
- bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
- return v.first.template as<const ffi::StringObj*>();
- });
+ bool is_str_map = std::all_of(
+ n->begin(), n->end(), [](const auto& v) { return v.first.template
as<ffi::String>(); });
if (is_str_map) {
for (const auto& kv : *n) {
MakeIndex(kv.second);
@@ -261,9 +260,8 @@ class JSONAttrGetter {
}
} else if (auto opt_map = node.as<const ffi::MapObj*>()) {
const ffi::MapObj* n = opt_map.value();
- bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
- return v.first.template as<const ffi::StringObj*>();
- });
+ bool is_str_map = std::all_of(
+ n->begin(), n->end(), [](const auto& v) { return v.first.template
as<ffi::String>(); });
if (is_str_map) {
for (const auto& kv : *n) {
node_->keys.push_back(kv.first.cast<String>());
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index bf9d7b23d5..4be01004cb 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -58,18 +58,6 @@ struct RefToObjectPtr : public ObjectRef {
}
};
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ffi::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ffi::StringObj*>(node.get());
- p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ffi::BytesObj>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ffi::BytesObj*>(node.get());
- p->stream << "b\"" << support::StrEscape(op->data, op->size) << '"';
- });
-
TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode)
.set_creator([](const std::string& blob) {
runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob);
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
index a6af311e6e..ee6d5bf32c 100644
--- a/src/runtime/disco/protocol.h
+++ b/src/runtime/disco/protocol.h
@@ -54,13 +54,13 @@ struct DiscoProtocol {
}
/*! \brief Get the length of the object being serialized. Used by
RPCReference. */
- inline uint64_t GetObjectBytes(Object* obj);
+ inline uint64_t GetFFIAnyProtocolBytes(const TVMFFIAny* obj);
/*! \brief Write the object to stream. Used by RPCReference. */
- inline void WriteObject(Object* obj);
+ inline void WriteFFIAny(const TVMFFIAny* obj);
/*! \brief Read the object from stream. Used by RPCReference. */
- inline void ReadObject(TVMFFIAny* out);
+ inline void ReadFFIAny(TVMFFIAny* out);
/*! \brief Callback method used when starting a new message. Used by
RPCReference. */
void MessageStart(uint64_t packet_nbytes) {}
@@ -113,67 +113,70 @@ struct DiscoDebugObject : public Object {
/*! \brief Deserialize the debug object from string */
static inline ObjectPtr<DiscoDebugObject> LoadFromStr(std::string json_str);
/*! \brief Get the size of the debug object in bytes */
- inline uint64_t GetObjectBytes() const { return sizeof(uint64_t) +
this->SaveToStr().size(); }
+ inline uint64_t GetFFIAnyProtocolBytes() const {
+ return sizeof(uint64_t) + this->SaveToStr().size();
+ }
static constexpr const char* _type_key = "runtime.disco.DiscoDebugObject";
TVM_DECLARE_FINAL_OBJECT_INFO(DiscoDebugObject, SessionObj);
};
template <class SubClassType>
-inline uint64_t DiscoProtocol<SubClassType>::GetObjectBytes(Object* obj) {
- if (obj->IsInstance<DRefObj>()) {
+inline uint64_t DiscoProtocol<SubClassType>::GetFFIAnyProtocolBytes(const
TVMFFIAny* value) {
+ const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(value);
+ if (any_view_ptr->as<DRefObj>()) {
return sizeof(uint32_t) + sizeof(int64_t);
- } else if (obj->IsInstance<ffi::StringObj>()) {
- uint64_t size = static_cast<ffi::StringObj*>(obj)->size;
+ } else if (const auto opt_str = any_view_ptr->as<ffi::String>()) {
+ uint64_t size = (*opt_str).size();
return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);
- } else if (obj->IsInstance<ffi::BytesObj>()) {
- uint64_t size = static_cast<ffi::BytesObj*>(obj)->size;
+ } else if (const auto opt_bytes = any_view_ptr->as<ffi::Bytes>()) {
+ uint64_t size = (*opt_bytes).size();
return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);
- } else if (obj->IsInstance<ffi::ShapeObj>()) {
- uint64_t ndim = static_cast<ffi::ShapeObj*>(obj)->size;
+ } else if (const auto opt_shape = any_view_ptr->as<ffi::Shape>()) {
+ uint64_t ndim = (*opt_shape).size();
return sizeof(uint32_t) + sizeof(uint64_t) + ndim *
sizeof(ffi::ShapeObj::index_type);
- } else if (obj->IsInstance<DiscoDebugObject>()) {
- return sizeof(uint32_t) +
static_cast<DiscoDebugObject*>(obj)->GetObjectBytes();
+ } else if (const auto opt_debug_obj = any_view_ptr->as<DiscoDebugObject>()) {
+ return sizeof(uint32_t) + (*opt_debug_obj).GetFFIAnyProtocolBytes();
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
- << obj->GetTypeKey() << " (type_index = " << obj->type_index()
<< ")";
+ << any_view_ptr->GetTypeKey() << " (type_index = " <<
any_view_ptr->type_index()
+ << ")";
}
}
template <class SubClassType>
-inline void DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
+inline void DiscoProtocol<SubClassType>::WriteFFIAny(const TVMFFIAny* value) {
SubClassType* self = static_cast<SubClassType*>(this);
- if (obj->IsInstance<DRefObj>()) {
- int64_t reg_id = static_cast<DRefObj*>(obj)->reg_id;
+ const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(value);
+ if (const auto* ref = any_view_ptr->as<DRefObj>()) {
+ int64_t reg_id = ref->reg_id;
self->template Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
self->template Write<int64_t>(reg_id);
- } else if (obj->IsInstance<ffi::StringObj>()) {
- ffi::StringObj* str = static_cast<ffi::StringObj*>(obj);
+ } else if (const auto opt_str = any_view_ptr->as<ffi::String>()) {
self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIStr);
- self->template Write<uint64_t>(str->size);
- self->template WriteArray<char>(str->data, str->size);
- } else if (obj->IsInstance<ffi::BytesObj>()) {
- ffi::BytesObj* bytes = static_cast<ffi::BytesObj*>(obj);
+ self->template Write<uint64_t>((*opt_str).size());
+ self->template WriteArray<char>((*opt_str).data(), (*opt_str).size());
+ } else if (const auto opt_bytes = any_view_ptr->as<ffi::Bytes>()) {
self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIBytes);
- self->template Write<uint64_t>(bytes->size);
- self->template WriteArray<char>(bytes->data, bytes->size);
- } else if (obj->IsInstance<ffi::ShapeObj>()) {
- ffi::ShapeObj* shape = static_cast<ffi::ShapeObj*>(obj);
+ self->template Write<uint64_t>((*opt_bytes).size());
+ self->template WriteArray<char>((*opt_bytes).data(), (*opt_bytes).size());
+ } else if (const auto opt_shape = any_view_ptr->as<ffi::Shape>()) {
self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIShape);
- self->template Write<uint64_t>(shape->size);
- self->template WriteArray<ffi::ShapeObj::index_type>(shape->data,
shape->size);
- } else if (obj->IsInstance<DiscoDebugObject>()) {
+ self->template Write<uint64_t>((*opt_shape).size());
+ self->template WriteArray<ffi::ShapeObj::index_type>((*opt_shape).data(),
(*opt_shape).size());
+ } else if (const auto opt_debug_obj = any_view_ptr->as<DiscoDebugObject>()) {
self->template Write<uint32_t>(0);
- std::string str = static_cast<DiscoDebugObject*>(obj)->SaveToStr();
+ std::string str = (*opt_debug_obj).SaveToStr();
self->template Write<uint64_t>(str.size());
self->template WriteArray<char>(str.data(), str.size());
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
- << obj->GetTypeKey() << " (type_index = " << obj->type_index()
<< ")";
+ << any_view_ptr->GetTypeKey() << " (type_index = " <<
any_view_ptr->type_index()
+ << ")";
}
}
template <class SubClassType>
-inline void DiscoProtocol<SubClassType>::ReadObject(TVMFFIAny* out) {
+inline void DiscoProtocol<SubClassType>::ReadFFIAny(TVMFFIAny* out) {
SubClassType* self = static_cast<SubClassType*>(this);
ffi::Any result{nullptr};
uint32_t type_index;
diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h
index 41bb40b3f2..42be97b53f 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -200,7 +200,7 @@ struct RPCReference {
num_bytes_ += sizeof(T) * num;
}
- void WriteObject(ffi::Object* obj) { num_bytes_ +=
channel_->GetObjectBytes(obj); }
+ void WriteFFIAny(const TVMFFIAny* obj) { num_bytes_ +=
channel_->GetFFIAnyProtocolBytes(obj); }
void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); }
@@ -373,11 +373,7 @@ struct RPCReference {
break;
}
default: {
- if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
-
channel->WriteObject(reinterpret_cast<ffi::Object*>(packed_args[i].v_obj));
- } else {
- channel->ThrowError(RPCServerStatus::kUnknownTypeIndex);
- }
+ channel->WriteFFIAny(&(packed_args[i]));
break;
}
}
@@ -472,7 +468,7 @@ struct RPCReference {
}
default: {
if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
- channel->ReadObject(&(packed_args[i]));
+ channel->ReadFFIAny(&(packed_args[i]));
} else {
channel->ThrowError(RPCServerStatus::kUnknownTypeIndex);
}
diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index 4e29dcc392..ddd5462c68 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -289,8 +289,8 @@ String ReportNode::AsCSV() const {
s << (*it).second.as<PercentNode>()->percent;
} else if ((*it).second.as<RatioNode>()) {
s << (*it).second.as<RatioNode>()->ratio;
- } else if ((*it).second.as<ffi::StringObj>()) {
- s << "\"" << Downcast<String>((*it).second) << "\"";
+ } else if (auto opt_str = (*it).second.as<ffi::String>()) {
+ s << "\"" << *opt_str << "\"";
}
}
if (i < headers.size() - 1) {
@@ -418,9 +418,9 @@ Any AggregateMetric(const std::vector<ffi::Any>& metrics) {
sum += metric.as<RatioNode>()->ratio;
}
return ObjectRef(make_object<RatioNode>(sum / metrics.size()));
- } else if (metrics[0].as<ffi::StringObj>()) {
+ } else if (auto opt_str = metrics[0].as<ffi::String>()) {
for (auto& m : metrics) {
- if (Downcast<String>(metrics[0]) != Downcast<String>(m)) {
+ if (*opt_str != m.as<ffi::String>()) {
return String("");
}
}
@@ -428,7 +428,7 @@ Any AggregateMetric(const std::vector<ffi::Any>& metrics) {
return metrics[0];
} else {
LOG(FATAL) << "Can only aggregate metrics with types DurationNode,
CountNode, "
- "PercentNode, RatioNode, and StringObj, but got "
+ "PercentNode, RatioNode, and String, but got "
<< metrics[0].GetTypeKey();
return ffi::Any(); // To silence warnings
}
@@ -467,8 +467,8 @@ static String print_metric(ffi::Any metric) {
set_locale_for_separators(s);
s << std::setprecision(2) << metric.as<RatioNode>()->ratio;
val = s.str();
- } else if (metric.as<ffi::StringObj>()) {
- val = Downcast<String>(metric);
+ } else if (auto opt_str = metric.as<ffi::String>()) {
+ val = *opt_str;
} else {
LOG(FATAL) << "Cannot print metric of type " << metric.GetTypeKey();
}
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index 9b9816a4d9..3dea9dc822 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -218,33 +218,36 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
this->Write(cdata);
}
- void WriteObject(Object* obj) {
+ void WriteFFIAny(const TVMFFIAny* in) {
// NOTE: for now all remote object are encoded as RPCObjectRef
// follow the same disco protocol in case we would like to upgrade later
//
// Rationale note: Only handle remote object allows the same mechanism to
work for minRPC
// which is needed for wasm and other env that goes through C API
- if (obj->IsInstance<RPCObjectRefObj>()) {
- auto* ref = static_cast<RPCObjectRefObj*>(obj);
+ const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(in);
+ if (const auto* ref = any_view_ptr->as<RPCObjectRefObj>()) {
this->template Write<uint32_t>(runtime::TypeIndex::kRuntimeRPCObjectRef);
uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle());
this->template Write<int64_t>(handle);
} else {
LOG(FATAL) << "ValueError: Object type is not supported in RPC calling
convention: "
- << obj->GetTypeKey() << " (type_index = " <<
obj->type_index() << ")";
+ << any_view_ptr->GetTypeKey() << " (type_index = " <<
any_view_ptr->type_index()
+ << ")";
}
}
- uint64_t GetObjectBytes(Object* obj) {
- if (obj->IsInstance<RPCObjectRefObj>()) {
+ uint64_t GetFFIAnyProtocolBytes(const TVMFFIAny* in) {
+ const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(in);
+ if (any_view_ptr->as<RPCObjectRefObj>()) {
return sizeof(uint32_t) + sizeof(int64_t);
} else {
LOG(FATAL) << "ValueError: Object type is not supported in RPC calling
convention: "
- << obj->GetTypeKey() << " (type_index = " <<
obj->type_index() << ")";
+ << any_view_ptr->GetTypeKey() << " (type_index = " <<
any_view_ptr->type_index()
+ << ")";
TVM_FFI_UNREACHABLE();
}
}
- void ReadObject(TVMFFIAny* out) {
+ void ReadFFIAny(TVMFFIAny* out) {
// NOTE: for now all remote object are encoded as RPCObjectRef
// follow the same disco protocol in case we would like to upgrade later
//
diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc
index 3fcd38dfec..3d4928f8b4 100644
--- a/src/runtime/rpc/rpc_local_session.cc
+++ b/src/runtime/rpc/rpc_local_session.cc
@@ -63,15 +63,16 @@ void LocalSession::EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return)
packed_args[1] = TVMFFINDArrayGetDLTensorPtr(opaque_handle);
packed_args[2] = opaque_handle;
encode_return(ffi::PackedArgs(packed_args, 3));
- } else if (const auto* bytes = rv.as<ffi::BytesObj>()) {
+ } else if (const auto opt_bytes = rv.as<ffi::Bytes>()) {
// always pass bytes as byte array
TVMFFIByteArray byte_arr;
- byte_arr.data = bytes->data;
- byte_arr.size = bytes->size;
+ byte_arr.data = (*opt_bytes).data();
+ byte_arr.size = (*opt_bytes).size();
packed_args[1] = &byte_arr;
encode_return(ffi::PackedArgs(packed_args, 2));
- } else if (const auto* str = rv.as<ffi::StringObj>()) {
- packed_args[1] = str->data;
+ } else if (auto opt_str = rv.as<ffi::String>()) {
+ // encode string as c_str
+ packed_args[1] = (*opt_str).data();
encode_return(ffi::PackedArgs(packed_args, 2));
} else if (rv.as<ffi::ObjectRef>()) {
TVMFFIAny ret_any =
ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv));
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 42643de2d1..d1fb7bab90 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -88,8 +88,9 @@ class RPCWrappedFunc : public Object {
// scan and check whether we need rewrite these arguments
// to their remote variant.
for (int i = 0; i < args.size(); ++i) {
- if (const auto* str = args[i].as<ffi::StringObj>()) {
- packed_args[i] = str->data;
+ if (args[i].type_index() == ffi::TypeIndex::kTVMFFIStr) {
+ // pass string as c_str
+ packed_args[i] = args[i].cast<ffi::String>().data();
continue;
}
packed_args[i] = args[i];
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc
index 8c352298c1..f8d773334f 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -351,8 +351,8 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
output_ << float_imm->value;
}
- } else if (const auto* string_obj = value.as<ffi::StringObj>()) {
- output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size)
<< "\"";
+ } else if (const auto opt_str = value.as<ffi::String>()) {
+ output_ << "\"" << support::StrEscape((*opt_str).data(),
(*opt_str).size()) << "\"";
} else {
LOG(FATAL) << "TypeError: Unsupported literal value type: " <<
value.GetTypeKey();
}
diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc
index 8288016d3e..63e703be55 100644
--- a/src/script/printer/ir/misc.cc
+++ b/src/script/printer/ir/misc.cc
@@ -41,7 +41,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
std::vector<POO> items{dict.begin(), dict.end()};
bool is_str_map = true;
for (const auto& kv : items) {
- if (!kv.first.as<ffi::StringObj>()) {
+ if (!kv.first.as<ffi::String>()) {
is_str_map = false;
break;
}
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 0becec1f3f..737f27c7e9 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -184,7 +184,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("testing.AcceptsVariant",
[](Variant<String, Integer> arg) -> String {
if (auto opt_str = arg.as<String>()) {
- return ffi::StringObj::_type_key;
+ return ffi::StaticTypeKey::kTVMFFIStr;
} else {
return arg.get<Integer>().GetTypeKey();
}
diff --git a/src/target/target.cc b/src/target/target.cc
index ac82c51b1b..a337d4a161 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -431,7 +431,7 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf
return Target(TargetInternal::FromString(str.value()));
} else if (const auto* ptr = obj.as<ffi::MapObj>()) {
for (const auto& kv : *ptr) {
- if (!kv.first.as<ffi::StringObj>()) {
+ if (!kv.first.as<ffi::String>()) {
TVM_FFI_THROW(TypeError)
<< "Target object requires key of dict to be str, but get: " <<
kv.first.GetTypeKey();
}
@@ -444,16 +444,16 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf
} else if (info.type_index == ffi::ArrayObj::RuntimeTypeIndex()) {
// Parsing array
const auto* array = ObjTypeCheck<const ffi::ArrayObj*>(obj, "Array");
- std::vector<ObjectRef> result;
+ std::vector<Any> result;
for (const Any& e : *array) {
try {
- result.push_back(TargetInternal::ParseType(e,
*info.key).cast<ObjectRef>());
+ result.push_back(TargetInternal::ParseType(e, *info.key));
} catch (const Error& e) {
std::string index = '[' + std::to_string(result.size()) + ']';
throw Error(e.kind(), index + e.message(), e.traceback());
}
}
- return Array<ObjectRef>(result);
+ return Array<Any>(result);
} else if (info.type_index == ffi::MapObj::RuntimeTypeIndex()) {
// Parsing map
const auto* map = ObjTypeCheck<const ffi::MapObj*>(obj, "Map");
diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc
index 68e7bbbf83..3ee43c698a 100644
--- a/src/tir/schedule/instruction.cc
+++ b/src/tir/schedule/instruction.cc
@@ -70,10 +70,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
for (const Any& obj : self->inputs) {
if (obj == nullptr) {
inputs.push_back(String("None"));
+ } else if (auto opt_str = obj.as<ffi::String>()) {
+ inputs.push_back(String('"' + (*opt_str).operator std::string() +
'"'));
} else if (obj.as<BlockRVNode>() || obj.as<LoopRVNode>()) {
inputs.push_back(String("_"));
- } else if (const auto* str_obj = obj.as<ffi::StringObj>()) {
- inputs.push_back(String('"' + std::string(str_obj->data) + '"'));
} else if (obj.type_index() <
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
inputs.push_back(obj);
} else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) {
diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h
index 5507c02bfe..bff619ca49 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -418,8 +418,8 @@ TVM_ALWAYS_INLINE Array<Any> UnpackedInstTraits<TTraits>::_ConvertOutputs(const
inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) {
if (obj == nullptr) {
os << "None";
- } else if (const auto* str = obj.as<ffi::StringObj>()) {
- os << str->data;
+ } else if (auto opt_str = obj.as<ffi::String>()) {
+ os << *opt_str;
} else if (const auto opt_int_imm = obj.try_cast<IntImm>()) {
os << (*opt_int_imm)->value;
} else if (const auto opt_float_imm = obj.try_cast<FloatImm>()) {
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index 6efb17de25..43c2ce0a7b 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -124,9 +124,9 @@ Array<Any> TranslateInputRVs(
LOG(FATAL) << "IndexError: Random variable is not defined " << input;
throw;
}
- } else if (const auto* str_obj = input.as<ffi::StringObj>()) {
+ } else if (auto opt_str = input.as<ffi::String>()) {
// Case 2. string => "content"
- results.push_back(String('"' + std::string(str_obj->data) + '"'));
+ results.push_back(String('"' + (*opt_str).operator std::string() + '"'));
} else if (input.as<IntImmNode>() || input.as<FloatImmNode>()) {
// Case 3. integer or floating-point number
results.push_back(input);
@@ -179,11 +179,11 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs,
results.push_back(input);
continue;
}
- const auto* str = input.as<ffi::StringObj>();
- CHECK(str) << "TypeError: Expect String, but gets: " << input.GetTypeKey();
- CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in
input names";
- const char* name = str->data;
- int64_t size = str->size;
+ auto opt_str = input.as<ffi::String>();
+ CHECK(opt_str) << "TypeError: Expect String, but gets: " <<
input.GetTypeKey();
+ CHECK_GT((*opt_str).size(), 0) << "ValueError: Empty string is not allowed
in input names";
+ const char* name = (*opt_str).data();
+ int64_t size = (*opt_str).size();
if (name[0] == '{' && name[size - 1] == '}') {
Any obj = LoadJSON(name);
// Case 6. IndexMap
@@ -363,8 +363,8 @@ Array<String> TraceNode::AsPython(bool remove_postproc) const {
Array<Any> attrs;
attrs.reserve(inst->attrs.size());
for (const Any& obj : inst->attrs) {
- if (const auto* str = obj.as<ffi::StringObj>()) {
- attrs.push_back(String('"' + std::string(str->data) + '"'));
+ if (auto opt_str = obj.as<ffi::String>()) {
+ attrs.push_back(String('"' + (*opt_str).operator std::string() + '"'));
} else {
attrs.push_back(obj);
}
@@ -428,8 +428,8 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) {
try {
const auto* arr = inst_entry.as<ffi::ArrayObj>();
ICHECK(arr && arr->size() == 4);
- const auto* arr0 = arr->at(0).as<ffi::StringObj>();
- kind = InstructionKind::Get(arr0->data);
+ ffi::String arr0 = arr->at(0).cast<ffi::String>();
+ kind = InstructionKind::Get(arr0);
inputs = arr->at(1).cast<Array<Any>>();
attrs = arr->at(2).cast<Array<Any>>();
outputs = arr->at(3).cast<Array<String>>();
diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/contrib/test_popen_pool.py
index 7ac3c42dcb..cc1fa95734 100644
--- a/tests/python/contrib/test_popen_pool.py
+++ b/tests/python/contrib/test_popen_pool.py
@@ -96,7 +96,6 @@ def test_popen_pool_executor():
assert value3.result() == 3
value = value4.result()
- assert isinstance(value, tvm.runtime.String)
assert value == "xyz"
pool = PopenPoolExecutor(max_workers=4, timeout=None)
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index 83002e971c..db357c5439 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -174,7 +174,7 @@ def test_string_obj(session_kind):
for i in range(num_workers):
value = result.debug_get_from_remote(i)
- assert isinstance(value, String)
+ assert isinstance(value, str)
assert value == "hello_suffix"
diff --git a/tests/python/ffi/test_function.py b/tests/python/ffi/test_function.py
index 4dbc54cd9d..5a8b4acb1f 100644
--- a/tests/python/ffi/test_function.py
+++ b/tests/python/ffi/test_function.py
@@ -43,13 +43,11 @@ def test_echo():
str_result = fecho("hello")
assert isinstance(str_result, str)
assert str_result == "hello"
- assert isinstance(str_result, tvm_ffi.String)
# test bytes
bytes_result = fecho(b"abc")
assert isinstance(bytes_result, bytes)
assert bytes_result == b"abc"
- assert isinstance(bytes_result, tvm_ffi.Bytes)
# test dtype
dtype_result = fecho(tvm_ffi.dtype("float32"))
diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py
index cac948b53d..85fed5670c 100644
--- a/tests/python/ffi/test_string.py
+++ b/tests/python/ffi/test_string.py
@@ -22,17 +22,16 @@ from tvm import ffi as tvm_ffi
def test_string():
fecho = tvm_ffi.get_global_func("testing.echo")
s = tvm_ffi.String("hello")
- assert isinstance(s, tvm_ffi.String)
s2 = fecho(s)
- assert s2.__tvm_ffi_object__.same_as(s.__tvm_ffi_object__)
-
+ assert s2 == "hello"
s3 = tvm_ffi.convert("hello")
- assert isinstance(s3, tvm_ffi.String)
assert isinstance(s3, str)
+ x = "hello long string"
+ assert fecho(x) == x
+
s4 = pickle.loads(pickle.dumps(s))
assert s4 == "hello"
- assert isinstance(s4, tvm_ffi.String)
def test_bytes():
@@ -40,7 +39,7 @@ def test_bytes():
b = tvm_ffi.Bytes(b"hello")
assert isinstance(b, tvm_ffi.Bytes)
b2 = fecho(b)
- assert b2.__tvm_ffi_object__.same_as(b.__tvm_ffi_object__)
+ assert b2 == b"hello"
b3 = tvm_ffi.convert(b"hello")
assert isinstance(b3, tvm_ffi.Bytes)