This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s3 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit b7020daba9f367de5b67754fea9765fa55235388
Author: tqchen <[email protected]>
AuthorDate: Sun May 4 18:34:40 2025 -0400

    Fix cython pickle
---
 python/tvm/ffi/cython/object.pxi |  4 ++++
 python/tvm/ffi/cython/string.pxi |  2 ++
 src/node/structural_hash.cc      | 32 ++++++++++++++++++++++++++++++++
 tests/python/ffi/test_string.py  |  9 +++++++++
 4 files changed, 47 insertions(+)

diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index 1ac32c3bc6..d2eade8af9 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -100,6 +100,10 @@ cdef class Object:
         cls = type(self)
         return (_new_object, (cls,), self.__getstate__())
 
+    def __reduce_cython__(self):
+        cls = type(self)
+        return (_new_object, (cls,), self.__getstate__())
+
     def __getstate__(self):
         if not self.__chandle__() == 0:
             # need to explicit convert to str in case String
diff --git a/python/tvm/ffi/cython/string.pxi b/python/tvm/ffi/cython/string.pxi
index 00ec92b7ec..733ea90301 100644
--- a/python/tvm/ffi/cython/string.pxi
+++ b/python/tvm/ffi/cython/string.pxi
@@ -65,6 +65,7 @@ class String(str, PyNativeObject):
         val.__tvm_ffi_object__ = obj
         return val
 
+
 _register_object_by_index(kTVMFFIStr, String)
 
 
@@ -89,6 +90,7 @@ class Bytes(bytes, PyNativeObject):
         val.__tvm_ffi_object__ = obj
         return val
 
+
 _register_object_by_index(kTVMFFIBytes, Bytes)
 
 # We special handle str/bytes constructor in cython to avoid extra cyclic deps
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index bba8a59647..cba27055ba 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -328,6 +328,22 @@ struct StringObjTrait {
   }
 };
 
+struct BytesObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static void SHashReduce(const ffi::BytesObj* key, SHashReducer hash_reduce) {
+    hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size));
+  }
+
+  static bool SEqualReduce(const ffi::BytesObj* lhs, const ffi::BytesObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->size != rhs->size) return false;
+    if (lhs->data == rhs->data) return true;
+    return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+  }
+};
+
 struct RefToObjectPtr : public ObjectRef {
   static ObjectPtr<Object> Get(const ObjectRef& ref) {
     return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(ref);
@@ -350,6 +366,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
     });
 
+TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj, BytesObjTrait)
+    .set_creator([](const std::string& bytes) {
+      return RefToObjectPtr::Get(runtime::String(bytes));
+    })
+    .set_repr_bytes([](const Object* n) -> std::string {
+      return GetRef<ffi::Bytes>(static_cast<const ffi::BytesObj*>(n))
+          .
+          operator std::string();
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ffi::BytesObj>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const ffi::BytesObj*>(node.get());
+      p->stream << "b\"" << support::StrEscape(op->data, op->size) << '"';
+    });
+
 struct ModuleNodeTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
   static constexpr const std::nullptr_t SHashReduce = nullptr;
diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py
index 67c82d52f7..040934bce0 100644
--- a/tests/python/ffi/test_string.py
+++ b/tests/python/ffi/test_string.py
@@ -1,3 +1,4 @@
+import pickle
 from tvm import ffi as tvm_ffi
 
 
@@ -12,6 +13,10 @@ def test_string():
     assert isinstance(s3, tvm_ffi.String)
     assert isinstance(s3, str)
 
+    s4 = pickle.loads(pickle.dumps(s))
+    assert s4 == "hello"
+    assert isinstance(s4, tvm_ffi.String)
+
 
 def test_bytes():
     fecho = tvm_ffi.get_global_func("testing.echo")
@@ -27,3 +32,7 @@ def test_bytes():
     b4 = tvm_ffi.convert(bytearray(b"hello"))
     assert isinstance(b4, tvm_ffi.Bytes)
     assert isinstance(b4, bytes)
+
+    b5 = pickle.loads(pickle.dumps(b))
+    assert b5 == b"hello"
+    assert isinstance(b5, tvm_ffi.Bytes)
\ No newline at end of file
