This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s3
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b7020daba9f367de5b67754fea9765fa55235388
Author: tqchen <[email protected]>
AuthorDate: Sun May 4 18:34:40 2025 -0400

    Fix cython pickle
---
 python/tvm/ffi/cython/object.pxi |  4 ++++
 python/tvm/ffi/cython/string.pxi |  2 ++
 src/node/structural_hash.cc      | 32 ++++++++++++++++++++++++++++++++
 tests/python/ffi/test_string.py  |  9 +++++++++
 4 files changed, 47 insertions(+)

diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index 1ac32c3bc6..d2eade8af9 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -100,6 +100,10 @@ cdef class Object:
         cls = type(self)
         return (_new_object, (cls,), self.__getstate__())
 
+    def __reduce_cython__(self):
+        """Mirror __reduce__ so Cython's pickle hook uses the same _new_object path."""
+        return _new_object, (type(self),), self.__getstate__()
+
     def __getstate__(self):
         if not self.__chandle__() == 0:
             # need to explicit convert to str in case String
diff --git a/python/tvm/ffi/cython/string.pxi b/python/tvm/ffi/cython/string.pxi
index 00ec92b7ec..733ea90301 100644
--- a/python/tvm/ffi/cython/string.pxi
+++ b/python/tvm/ffi/cython/string.pxi
@@ -65,6 +65,7 @@ class String(str, PyNativeObject):
         val.__tvm_ffi_object__ = obj
         return val
 
+
 _register_object_by_index(kTVMFFIStr, String)
 
 
@@ -89,6 +90,7 @@ class Bytes(bytes, PyNativeObject):
         val.__tvm_ffi_object__ = obj
         return val
 
+
 _register_object_by_index(kTVMFFIBytes, Bytes)
 
 # We special handle str/bytes constructor in cython to avoid extra cyclic deps
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index bba8a59647..cba27055ba 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -328,6 +328,22 @@ struct StringObjTrait {
   }
 };
 
+struct BytesObjTrait {  // hash/equality over the raw (data, size) payload of ffi::BytesObj
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static void SHashReduce(const ffi::BytesObj* key, SHashReducer hash_reduce) {
+    hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size));
+  }
+
+  static bool SEqualReduce(const ffi::BytesObj* lhs, const ffi::BytesObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->size != rhs->size) return false;
+    if (lhs->data == rhs->data) return true;  // same buffer, skip the byte compare
+    return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+  }
+};
+
 struct RefToObjectPtr : public ObjectRef {
   static ObjectPtr<Object> Get(const ObjectRef& ref) {
     return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(ref);
@@ -350,6 +366,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
     });
 
+TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj, BytesObjTrait)
+    // Rebuild a Bytes (not a String) so deserialization preserves the object type.
+    .set_creator([](const std::string& bytes) {
+      return RefToObjectPtr::Get(ffi::Bytes(bytes));
+    })
+    // Serialize through Bytes' conversion operator to std::string.
+    .set_repr_bytes([](const Object* n) -> std::string {
+      return GetRef<ffi::Bytes>(static_cast<const ffi::BytesObj*>(n)).operator std::string();
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)  // print Bytes as b"<escaped payload>"
+    .set_dispatch<ffi::BytesObj>([](const ObjectRef& ref, ReprPrinter* printer) {
+      const auto* bytes_obj = static_cast<const ffi::BytesObj*>(ref.get());
+      printer->stream << "b\"" << support::StrEscape(bytes_obj->data, bytes_obj->size) << '"';
+    });
+
 struct ModuleNodeTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
   static constexpr const std::nullptr_t SHashReduce = nullptr;
diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py
index 67c82d52f7..040934bce0 100644
--- a/tests/python/ffi/test_string.py
+++ b/tests/python/ffi/test_string.py
@@ -1,3 +1,4 @@
+import pickle
 from tvm import ffi as tvm_ffi
 
 
@@ -12,6 +13,10 @@ def test_string():
     assert isinstance(s3, tvm_ffi.String)
     assert isinstance(s3, str)
 
+    s4 = pickle.loads(pickle.dumps(s))
+    assert s4 == "hello"
+    assert isinstance(s4, tvm_ffi.String)
+
 
 def test_bytes():
     fecho = tvm_ffi.get_global_func("testing.echo")
@@ -27,3 +32,7 @@ def test_bytes():
     b4 = tvm_ffi.convert(bytearray(b"hello"))
     assert isinstance(b4, tvm_ffi.Bytes)
     assert isinstance(b4, bytes)
+
+    b5 = pickle.loads(pickle.dumps(b))
+    assert b5 == b"hello"
+    assert isinstance(b5, tvm_ffi.Bytes)
\ No newline at end of file

Reply via email to