This is an automated email from the ASF dual-hosted git repository. chaokunyang pushed a commit to branch releases-0.12 in repository https://gitbox.apache.org/repos/asf/fory.git
commit 1c0df657a83f64acb80995f617a1822d820baa86 Author: Shawn Yang <[email protected]> AuthorDate: Fri Sep 19 18:57:48 2025 +0800 feat(python): drop-in replacement for pickle serialization (#2629) <!-- **Thanks for contributing to Apache Fory™.** **If this is your first time opening a PR on fory, you can refer to [CONTRIBUTING.md](https://github.com/apache/fory/blob/main/CONTRIBUTING.md).** Contribution Checklist - The **Apache Fory™** community has requirements on the naming of PR titles. You can also find instructions in [CONTRIBUTING.md](https://github.com/apache/fory/blob/main/CONTRIBUTING.md). - Apache Fory™ has a strong focus on performance. If the PR you submit will have an impact on performance, please benchmark it first and provide the benchmark result here. --> Implement serialization for any picklable object, so that pyfory can be used as a drop-in replacement for pickle with smaller serialized size and faster speed. <!-- Describe the details of this PR. --> Closes #2417 <!-- If any user-facing interface changes, please [open an issue](https://github.com/apache/fory/issues/new/choose) describing the need to do so and update the document if necessary. Delete section if not applicable. --> - [ ] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? <!-- When the PR has an impact on performance (if you don't know whether the PR will have an impact on performance, you can submit the PR first, and if it will have impact on performance, the code reviewer will explain it), be sure to attach benchmark data here. Delete section if not applicable. 
--> --- python/pyfory/_fory.py | 63 +----- python/pyfory/_registry.py | 101 +++++----- python/pyfory/_serialization.pyx | 74 +------ python/pyfory/_serializer.py | 8 +- python/pyfory/_struct.py | 5 - python/pyfory/serializer.py | 340 ++++++++++++++++++++------------- python/pyfory/tests/test_serializer.py | 49 +---- python/pyproject.toml | 1 - 8 files changed, 276 insertions(+), 365 deletions(-) diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 60562f9b4..2894fe7a9 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -18,7 +18,6 @@ import enum import logging import os -import warnings from abc import ABC, abstractmethod from typing import Union, Iterable, TypeVar @@ -37,9 +36,6 @@ try: except ImportError: np = None -from cloudpickle import Pickler - -from pickle import Unpickler logger = logging.getLogger(__name__) @@ -101,11 +97,8 @@ class Fory: "ref_tracking", "ref_resolver", "type_resolver", - "serialization_context", "require_type_registration", "buffer", - "pickler", - "unpickler", "_buffer_callback", "_buffers", "metastring_resolver", @@ -115,7 +108,6 @@ class Fory: "max_depth", "depth", ) - serialization_context: "SerializationContext" def __init__( self, @@ -152,19 +144,7 @@ class Fory: self.metastring_resolver = MetaStringResolver() self.type_resolver = TypeResolver(self) self.type_resolver.initialize() - self.serialization_context = SerializationContext() self.buffer = Buffer.allocate(32) - if not require_type_registration: - warnings.warn( - "Type registration is disabled, unknown types can be deserialized which may be insecure.", - RuntimeWarning, - stacklevel=2, - ) - self.pickler = Pickler(self.buffer) - self.unpickler = None - else: - self.pickler = _PicklerStub() - self.unpickler = _UnpicklerStub() self._buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -231,9 +211,7 @@ class Fory: ) -> Union[Buffer, bytes]: self._buffer_callback = buffer_callback self._unsupported_callback = 
unsupported_callback - if buffer is not None: - self.pickler = Pickler(buffer) - else: + if buffer is None: self.buffer.writer_index = 0 buffer = self.buffer if self.language == Language.XLANG: @@ -463,21 +441,11 @@ class Fory: def handle_unsupported_write(self, buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): - buffer.write_bool(True) - self.pickler.dump(obj) - else: - buffer.write_bool(False) + raise NotImplementedError(f"{type(obj)} is not supported for write") def handle_unsupported_read(self, buffer): - in_band = buffer.read_bool() - if in_band: - unpickler = self.unpickler - if unpickler is None: - self.unpickler = unpickler = Unpickler(buffer) - return unpickler.load() - else: - assert self._unsupported_objects is not None - return next(self._unsupported_objects) + assert self._unsupported_objects is not None + return next(self._unsupported_objects) def write_ref_pyobject(self, buffer, value, typeinfo=None): if self.ref_resolver.write_ref_or_null(buffer, value): @@ -488,7 +456,7 @@ class Fory: typeinfo.serializer.write(buffer, value) def read_ref_pyobject(self, buffer): - return self.deserialize_ref(buffer) + return self.deserialize_ref(buffer) def inc_depth(self): self.depth += 1 @@ -507,9 +475,7 @@ class Fory: def reset_write(self): self.ref_resolver.reset_write() self.type_resolver.reset_write() - self.serialization_context.reset() self.metastring_resolver.reset_write() - self.pickler.clear_memo() self._buffer_callback = None self._unsupported_callback = None @@ -517,9 +483,7 @@ class Fory: self.depth = 0 self.ref_resolver.reset_read() self.type_resolver.reset_read() - self.serialization_context.reset() self.metastring_resolver.reset_write() - self.unpickler = None self._buffers = None self._unsupported_objects = None @@ -562,20 +526,3 @@ _ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBL "1", "true", } - - -class _PicklerStub: - def dump(self, o): - raise ValueError( - f"Type {type(o)} 
is not registered, " - f"pickle is not allowed when type registration enabled, " - f"Please register the type or pass unsupported_callback" - ) - - def clear_memo(self): - pass - - -class _UnpicklerStub: - def load(self): - raise ValueError("pickle is not allowed when type registration enabled, Please register the type or pass unsupported_callback") diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py index f72baac63..e97f9ced7 100644 --- a/python/pyfory/_registry.py +++ b/python/pyfory/_registry.py @@ -25,7 +25,7 @@ import types from typing import TypeVar, Union from enum import Enum -from pyfory._serialization import ENABLE_FORY_CYTHON_SERIALIZATION +from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory import Language from pyfory.error import TypeUnregisteredError @@ -35,9 +35,6 @@ from pyfory.serializer import ( NDArraySerializer, PyArraySerializer, DynamicPyArraySerializer, - _PickleStub, - PickleStrongCacheStub, - PickleCacheStub, NoneSerializer, BooleanSerializer, ByteSerializer, @@ -56,14 +53,15 @@ from pyfory.serializer import ( SetSerializer, EnumSerializer, SliceSerializer, - PickleCacheSerializer, - PickleStrongCacheSerializer, - PickleSerializer, DataClassSerializer, StatefulSerializer, ReduceSerializer, FunctionSerializer, ObjectSerializer, + TypeSerializer, + MethodSerializer, + UnsupportedSerializer, + NativeFuncMethodSerializer, ) from pyfory.meta.metastring import MetaStringEncoder, MetaStringDecoder from pyfory.type import ( @@ -75,6 +73,7 @@ from pyfory.type import ( Float32Type, Float64Type, load_class, + record_class_factory, ) from pyfory._fory import ( DYNAMIC_TYPE_ID, @@ -158,9 +157,10 @@ class TypeResolver: "metastring_resolver", "language", "_type_id_to_typeinfo", + "_internal_py_serializer_map", ) - def __init__(self, fory): + def __init__(self, fory, meta_share=False): self.fory = fory self.metastring_resolver = fory.metastring_resolver self.language = fory.language @@ -182,34 +182,41 @@ class TypeResolver: 
self.namespace_decoder = MetaStringDecoder(".", "_") self.typename_encoder = MetaStringEncoder("$", "_") self.typename_decoder = MetaStringDecoder("$", "_") + self._internal_py_serializer_map = {} def initialize(self): - self._initialize_xlang() + self._initialize_common() if self.fory.language == Language.PYTHON: self._initialize_py() + else: + self._initialize_xlang() def _initialize_py(self): register = functools.partial(self._register_type, internal=True) - register( - _PickleStub, - type_id=PickleSerializer.PICKLE_TYPE_ID, - serializer=PickleSerializer, - ) - register( - PickleStrongCacheStub, - type_id=97, - serializer=PickleStrongCacheSerializer(self.fory), - ) - register( - PickleCacheStub, - type_id=98, - serializer=PickleCacheSerializer(self.fory), - ) register(type(None), serializer=NoneSerializer) register(tuple, serializer=TupleSerializer) register(slice, serializer=SliceSerializer) + register(np.ndarray, serializer=NDArraySerializer) + register(array.array, serializer=DynamicPyArraySerializer) + self._internal_py_serializer_map = { + ReduceSerializer: (self._stub_cls("__Reduce__"), self._next_type_id()), + TypeSerializer: (self._stub_cls("__Type__"), self._next_type_id()), + MethodSerializer: (self._stub_cls("__Method__"), self._next_type_id()), + NativeFuncMethodSerializer: (self._stub_cls("__NativeFunction__"), self._next_type_id()), + } + for serializer, (stub_cls, type_id) in self._internal_py_serializer_map.items(): + register(stub_cls, serializer=serializer, type_id=type_id) + + @staticmethod + def _stub_cls(name: str): + return record_class_factory(name, []) def _initialize_xlang(self): + register = functools.partial(self._register_type, internal=True) + register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer) + register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer) + + def _initialize_common(self): register = functools.partial(self._register_type, internal=True) register(None, type_id=TypeId.NA, 
serializer=NoneSerializer) register(bool, type_id=TypeId.BOOL, serializer=BooleanSerializer) @@ -240,7 +247,6 @@ class TypeResolver: type_id=typeid, serializer=PyArraySerializer(self.fory, ftype, typeid), ) - register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer) if np: # overwrite pyarray with same type id. # if pyarray are needed, one must annotate that value with XXXArrayType @@ -256,7 +262,6 @@ class TypeResolver: type_id=typeid, serializer=Numpy1DArraySerializer(self.fory, ftype, dtype), ) - register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer) register(list, type_id=TypeId.LIST, serializer=ListSerializer) register(set, type_id=TypeId.SET, serializer=SetSerializer) register(dict, type_id=TypeId.MAP, serializer=MapSerializer) @@ -416,7 +421,7 @@ class TypeResolver: self._named_type_to_typeinfo[(namespace, typename)] = typeinfo self._ns_type_to_typeinfo[(ns_meta_bytes, type_meta_bytes)] = typeinfo self._types_info[cls] = typeinfo - if type_id > 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)): + if type_id is not None and type_id != 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)): if type_id not in self._type_id_to_typeinfo or not internal: self._type_id_to_typeinfo[type_id] = typeinfo self._types_info[cls] = typeinfo @@ -469,12 +474,12 @@ class TypeResolver: if self.language == Language.PYTHON: if isinstance(serializer, EnumSerializer): type_id = TypeId.NAMED_ENUM - elif type(serializer) is PickleSerializer: - type_id = PickleSerializer.PICKLE_TYPE_ID elif isinstance(serializer, FunctionSerializer): type_id = TypeId.NAMED_EXT - elif isinstance(serializer, (ObjectSerializer, StatefulSerializer, ReduceSerializer)): + elif isinstance(serializer, (ObjectSerializer, StatefulSerializer)): type_id = TypeId.NAMED_EXT + elif self._internal_py_serializer_map.get(type(serializer)) is not None: + type_id = 
self._internal_py_serializer_map.get(type(serializer))[1] if not self.require_registration: if isinstance(serializer, DataClassSerializer): type_id = TypeId.NAMED_STRUCT @@ -502,35 +507,33 @@ class TypeResolver: serializer = DataClassSerializer(self.fory, cls) elif issubclass(cls, enum.Enum): serializer = EnumSerializer(self.fory, cls) + elif ("builtin_function_or_method" in str(cls) or "cython_function_or_method" in str(cls)) and "<locals>" not in str(cls): + serializer = NativeFuncMethodSerializer(self.fory, cls) + elif cls is type(self.initialize): + # Handle bound method objects + serializer = MethodSerializer(self.fory, cls) + elif issubclass(cls, type): + # Handle Python type objects and metaclass such as numpy._DTypeMeta(i.e. np.dtype) + serializer = TypeSerializer(self.fory, cls) + elif cls is array.array: + # Handle array.array objects with DynamicPyArraySerializer + # Note: This will use DynamicPyArraySerializer for all array.array objects + serializer = DynamicPyArraySerializer(self.fory, cls) elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or ( hasattr(cls, "__reduce_ex__") and cls.__reduce_ex__ is not object.__reduce_ex__ ): # Use ReduceSerializer for objects that have custom __reduce__ or __reduce_ex__ methods # This has higher precedence than StatefulSerializer and ObjectSerializer # Only use it for objects with custom reduce methods, not default ones from the object - module_name = getattr(cls, "__module__", "") - if module_name.startswith("pandas.") or module_name == "builtins" or cls.__name__ in ("type", "function", "method"): - # Exclude pandas, built-ins, and certain system types - serializer = PickleSerializer(self.fory, cls) - else: - serializer = ReduceSerializer(self.fory, cls) + serializer = ReduceSerializer(self.fory, cls) elif hasattr(cls, "__getstate__") and hasattr(cls, "__setstate__"): # Use StatefulSerializer for objects that support __getstate__ and __setstate__ - # But exclude certain types that have 
incompatible state methods - module_name = getattr(cls, "__module__", "") - if module_name.startswith("pandas."): - # Pandas objects have __getstate__/__setstate__ but use incompatible pickle formats - serializer = PickleSerializer(self.fory, cls) - else: - serializer = StatefulSerializer(self.fory, cls) - elif ( - cls is not type - and (hasattr(cls, "__dict__") or hasattr(cls, "__slots__")) - and not (np and (issubclass(cls, np.dtype) or cls is type(np.dtype))) - ): + serializer = StatefulSerializer(self.fory, cls) + elif hasattr(cls, "__dict__") or hasattr(cls, "__slots__"): serializer = ObjectSerializer(self.fory, cls) else: - serializer = PickleSerializer(self.fory, cls) + # c-extension types will go to here + serializer = UnsupportedSerializer(self.fory, cls) return serializer def _load_metabytes_to_typeinfo(self, ns_metabytes, type_metabytes): diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx index 833fbed24..3360610dd 100644 --- a/python/pyfory/_serialization.pyx +++ b/python/pyfory/_serialization.pyx @@ -30,7 +30,6 @@ from typing import TypeVar, Union, Iterable from pyfory._util import get_bit, set_bit, clear_bit from pyfory import _fory as fmod from pyfory._fory import Language -from pyfory._fory import _PicklerStub, _UnpicklerStub, Pickler, Unpickler from pyfory._fory import _ENABLE_TYPE_REGISTRATION_FORCIBLY from pyfory.lib import mmh3 from pyfory.meta.metastring import Encoding @@ -590,10 +589,7 @@ cdef class Fory: cdef readonly MapRefResolver ref_resolver cdef readonly TypeResolver type_resolver cdef readonly MetaStringResolver metastring_resolver - cdef readonly SerializationContext serialization_context cdef Buffer buffer - cdef public object pickler # pickle.Pickler - cdef public object unpickler # Optional[pickle.Unpickler] cdef object _buffer_callback cdef object _buffers # iterator cdef object _unsupported_callback @@ -634,20 +630,7 @@ cdef class Fory: self.metastring_resolver = MetaStringResolver() 
self.type_resolver = TypeResolver(self) self.type_resolver.initialize() - self.serialization_context = SerializationContext() self.buffer = Buffer.allocate(32) - if not require_type_registration: - warnings.warn( - "Type registration is disabled, unknown types can be deserialized " - "which may be insecure.", - RuntimeWarning, - stacklevel=2, - ) - self.pickler = Pickler(self.buffer) - else: - self.pickler = _PicklerStub() - self.unpickler = _UnpicklerStub() - self.unpickler = None self._buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -702,9 +685,7 @@ cdef class Fory: self, obj, Buffer buffer, buffer_callback=None, unsupported_callback=None): self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback - if buffer is not None: - self.pickler = Pickler(self.buffer) - else: + if buffer is None: self.buffer.writer_index = 0 buffer = self.buffer if self.language == Language.XLANG: @@ -832,8 +813,6 @@ cdef class Fory: cpdef inline _deserialize( self, Buffer buffer, buffers=None, unsupported_objects=None): - if not self.require_type_registration: - self.unpickler = Unpickler(buffer) if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) if self.language == Language.XLANG: @@ -939,7 +918,6 @@ cdef class Fory: self, Buffer buffer, Serializer serializer=None): if serializer is None: serializer = self.type_resolver.read_typeinfo(buffer).serializer - self.depth += 1 self.inc_depth() o = serializer.xread(buffer) self.depth -= 1 @@ -983,22 +961,13 @@ cdef class Fory: buffer.reader_index += size return buf - cpdef inline handle_unsupported_write(self, Buffer buffer, obj): + cpdef handle_unsupported_write(self, buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): - buffer.write_bool(True) - self.pickler.dump(obj) - else: - buffer.write_bool(False) + raise NotImplementedError(f"{type(obj)} is not supported for write") - cpdef inline 
handle_unsupported_read(self, Buffer buffer): - cdef c_bool in_band = buffer.read_bool() - if in_band: - if self.unpickler is None: - self.unpickler.buffer = Unpickler(buffer) - return self.unpickler.load() - else: - assert self._unsupported_objects is not None - return next(self._unsupported_objects) + cpdef handle_unsupported_read(self, buffer): + assert self._unsupported_objects is not None + return next(self._unsupported_objects) cpdef inline write_ref_pyobject( self, Buffer buffer, value, TypeInfo typeinfo=None): @@ -1016,7 +985,9 @@ cdef class Fory: return ref_resolver.get_read_object() # indicates that the object is first read. cdef TypeInfo typeinfo = self.type_resolver.read_typeinfo(buffer) + self.inc_depth() o = typeinfo.serializer.read(buffer) + self.depth -= 1 ref_resolver.set_read_object(ref_id, o) return o @@ -1024,8 +995,6 @@ cdef class Fory: self.ref_resolver.reset_write() self.type_resolver.reset_write() self.metastring_resolver.reset_write() - self.serialization_context.reset() - self.pickler.clear_memo() self._unsupported_callback = None cpdef inline reset_read(self): @@ -1033,9 +1002,7 @@ cdef class Fory: self.ref_resolver.reset_read() self.type_resolver.reset_read() self.metastring_resolver.reset_read() - self.serialization_context.reset() self._buffers = None - self.unpickler = None self._unsupported_objects = None cpdef inline reset(self): @@ -1095,29 +1062,6 @@ cpdef inline read_nullable_pystr(Buffer buffer): return None -@cython.final -cdef class SerializationContext: - cdef dict objects - - def __init__(self): - self.objects = dict() - - def add(self, key, obj): - self.objects[id(key)] = obj - - def __contains__(self, key): - return id(key) in self.objects - - def __getitem__(self, key): - return self.objects[id(key)] - - def get(self, key): - return self.objects.get(id(key)) - - def reset(self): - if len(self.objects) > 0: - self.objects.clear() - cdef class Serializer: cdef readonly Fory fory cdef readonly object type_ @@ -1650,10 
+1594,12 @@ cdef class TupleSerializer(CollectionSerializer): else: self._read_same_type_ref(buffer, len_, tuple_, typeinfo) else: + self.fory.inc_depth() for i in range(len_): elem = get_next_element(buffer, ref_resolver, type_resolver, is_py) Py_INCREF(elem) PyTuple_SET_ITEM(tuple_, i, elem) + self.fory.dec_depth() return tuple_ cpdef inline _add_element(self, object collection_, int64_t index, object element): diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index 77691c6a8..1a0666ed8 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -19,7 +19,7 @@ import datetime import logging import platform import time -from abc import ABC, abstractmethod +from abc import ABC from typing import Dict from pyfory._fory import NOT_NULL_INT64_FLAG @@ -74,13 +74,11 @@ class Serializer(ABC): def read(self, buffer): raise NotImplementedError - @abstractmethod def xwrite(self, buffer, value): - pass + raise NotImplementedError - @abstractmethod def xread(self, buffer): - pass + raise NotImplementedError @classmethod def support_subclass(cls) -> bool: diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py index 6c455424e..dd7b91add 100644 --- a/python/pyfory/_struct.py +++ b/python/pyfory/_struct.py @@ -90,14 +90,12 @@ class ComplexTypeVisitor(TypeVisitor): return None def visit_other(self, field_name, type_, types_path=None): - from pyfory.serializer import PickleSerializer # Local import if is_subclass(type_, enum.Enum): return self.fory.type_resolver.get_serializer(type_) if type_ not in basic_types and not is_py_array_type(type_): return None serializer = self.fory.type_resolver.get_serializer(type_) - assert not isinstance(serializer, (PickleSerializer,)) return serializer @@ -199,14 +197,11 @@ class StructHashVisitor(TypeVisitor): self._hash = self._compute_field_hash(self._hash, hash_value) def visit_other(self, field_name, type_, types_path=None): - from pyfory.serializer import PickleSerializer # Local import - 
typeinfo = self.fory.type_resolver.get_typeinfo(type_, create=False) if typeinfo is None: id_ = 0 else: serializer = typeinfo.serializer - assert not isinstance(serializer, (PickleSerializer,)) id_ = typeinfo.type_id assert id_ is not None, serializer id_ = abs(id_) diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index aba3809c4..6e211fb89 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -17,17 +17,17 @@ import array import builtins +import dataclasses +import importlib +import inspect import itertools import marshal import logging import os -import pickle import types import typing import warnings -from weakref import WeakValueDictionary -import pyfory.lib.mmh3 from pyfory.buffer import Buffer from pyfory.codegen import ( gen_write_nullable_basic_stmts, @@ -35,9 +35,9 @@ from pyfory.codegen import ( compile_function, ) from pyfory.error import TypeNotCompatibleError -from pyfory.lib.collection import WeakIdentityKeyDictionary from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG from pyfory import Language +from typing import List try: import numpy as np @@ -137,82 +137,31 @@ class NoneSerializer(Serializer): return None -class _PickleStub: - pass +class TypeSerializer(Serializer): + """Serializer for Python type objects (classes).""" - -class PickleStrongCacheStub: - pass - - -class PickleCacheStub: - pass - - -class PickleStrongCacheSerializer(Serializer): - """If we can't create weak ref to object, use this cache serializer instead. 
- clear cache by threshold to avoid memory leak.""" - - __slots__ = "_cached", "_clear_threshold", "_counter" - - def __init__(self, fory, clear_threshold: int = 1000): - super().__init__(fory, PickleStrongCacheStub) - self._cached = {} - self._clear_threshold = clear_threshold + def __init__(self, fory, cls): + super().__init__(fory, cls) + self.cls = cls def write(self, buffer, value): - serialized = self._cached.get(value) - if serialized is None: - serialized = pickle.dumps(value) - self._cached[value] = serialized - buffer.write_bytes_and_size(serialized) - if len(self._cached) == self._clear_threshold: - self._cached.clear() + # Serialize the type by its module and name + module_name = getattr(value, "__module__", "") + type_name = getattr(value, "__name__", "") + buffer.write_string(module_name) + buffer.write_string(type_name) def read(self, buffer): - return pickle.loads(buffer.read_bytes_and_size()) + module_name = buffer.read_string() + type_name = buffer.read_string() - def xwrite(self, buffer, value): - raise NotImplementedError - - def xread(self, buffer): - raise NotImplementedError - - -class PickleCacheSerializer(Serializer): - __slots__ = "_cached", "_reverse_cached" - - def __init__(self, fory): - super().__init__(fory, PickleCacheStub) - self._cached = WeakIdentityKeyDictionary() - self._reverse_cached = WeakValueDictionary() - - def write(self, buffer, value): - cache = self._cached.get(value) - if cache is None: - serialized = pickle.dumps(value) - value_hash = pyfory.lib.mmh3.hash_buffer(serialized)[0] - cache = value_hash, serialized - self._cached[value] = cache - buffer.write_int64(cache[0]) - buffer.write_bytes_and_size(cache[1]) - - def read(self, buffer): - value_hash = buffer.read_int64() - value = self._reverse_cached.get(value_hash) - if value is None: - value = pickle.loads(buffer.read_bytes_and_size()) - self._reverse_cached[value_hash] = value + # Import the module and get the type + if module_name and module_name != "builtins": + 
module = __import__(module_name, fromlist=[type_name]) + return getattr(module, type_name) else: - size = buffer.read_int32() - buffer.skip(size) - return value - - def xwrite(self, buffer, value): - raise NotImplementedError - - def xread(self, buffer): - raise NotImplementedError + # Handle built-in types + return getattr(builtins, type_name, type) class PandasRangeIndexSerializer(Serializer): @@ -290,27 +239,26 @@ _ENABLE_FORY_PYTHON_JIT = os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower "1", ) -# Moved from L32 to here, after all Serializer base classes and specific serializers -# like ListSerializer, MapSerializer, PickleSerializer are defined or imported -# and before DataClassSerializer which uses ComplexTypeVisitor from _struct. + from pyfory._struct import _get_hash, _sort_fields, ComplexTypeVisitor class DataClassSerializer(Serializer): - def __init__(self, fory, clz: type, xlang: bool = False): + def __init__(self, fory, clz: type, xlang: bool = False, field_names: List[str] = None, serializers: List[Serializer] = None): super().__init__(fory, clz) self._xlang = xlang # This will get superclass type hints too. 
self._type_hints = typing.get_type_hints(clz) - self._field_names = self._get_field_names(clz) + self._field_names = field_names or self._get_field_names(clz) self._has_slots = hasattr(clz, "__slots__") if self._xlang: - self._serializers = [None] * len(self._field_names) - visitor = ComplexTypeVisitor(fory) - for index, key in enumerate(self._field_names): - serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) - self._serializers[index] = serializer + self._serializers = serializers or [None] * len(self._field_names) + if serializers is None: + visitor = ComplexTypeVisitor(fory) + for index, key in enumerate(self._field_names): + serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) + self._serializers[index] = serializer self._serializers, self._field_names = _sort_fields(fory.type_resolver, self._field_names, self._serializers) self._hash = 0 # Will be computed on first xwrite/xread self._generated_xwrite_method = self._gen_xwrite_method() @@ -443,13 +391,13 @@ class DataClassSerializer(Serializer): context["_field_names"] = self._field_names context["_type_hints"] = self._type_hints context["_serializers"] = self._serializers - # Compute hash at generation time since we're in xlang mode - if self._hash == 0: - self._hash = _get_hash(self.fory, self._field_names, self._type_hints) stmts = [ f'"""xwrite method for {self.type_}"""', - f"{buffer}.write_int32({self._hash})", ] + # Compute hash at generation time since we're in xlang mode + if self._hash == 0: + self._hash = _get_hash(self.fory, self._field_names, self._type_hints) + stmts.append(f"{buffer}.write_int32({self._hash})") if not self._has_slots: stmts.append(f"{value_dict} = {value}.__dict__") for index, field_name in enumerate(self._field_names): @@ -487,18 +435,28 @@ class DataClassSerializer(Serializer): context["_field_names"] = self._field_names context["_type_hints"] = self._type_hints context["_serializers"] = self._serializers + 
current_class_field_names = set(self._get_field_names(self.type_)) + stmts = [ + f'"""xread method for {self.type_}"""', + ] # Compute hash at generation time since we're in xlang mode if self._hash == 0: self._hash = _get_hash(self.fory, self._field_names, self._type_hints) - stmts = [ - f'"""xread method for {self.type_}"""', - f"read_hash = {buffer}.read_int32()", - f"if read_hash != {self._hash}:", - f""" raise TypeNotCompatibleError( + stmts.extend( + [ + f"read_hash = {buffer}.read_int32()", + f"if read_hash != {self._hash}:", + f""" raise TypeNotCompatibleError( f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""", - f"{obj} = {obj_class}.__new__({obj_class})", - f"{ref_resolver}.reference({obj})", - ] + ] + ) + stmts.extend( + [ + f"{obj} = {obj_class}.__new__({obj_class})", + f"{ref_resolver}.reference({obj})", + ] + ) + if not self._has_slots: stmts.append(f"{obj_dict} = {obj}.__dict__") @@ -666,12 +624,18 @@ class PyArraySerializer(CrossLanguageCompatibleSerializer): def read(self, buffer): typecode = buffer.read_string() data = buffer.read_bytes_and_size() - arr = array.array(typecode, []) + arr = array.array(typecode[0], []) # Take first character arr.frombytes(data) return arr class DynamicPyArraySerializer(Serializer): + """Serializer for dynamic Python arrays that handles any typecode.""" + + def __init__(self, fory, cls): + super().__init__(fory, cls) + self._serializer = ReduceSerializer(fory, cls) + def xwrite(self, buffer, value): itemsize, ftype, type_id = typecode_dict[value.typecode] view = memoryview(value) @@ -692,11 +656,10 @@ class DynamicPyArraySerializer(Serializer): return arr def write(self, buffer, value): - buffer.write_varuint32(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + self._serializer.write(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + return self._serializer.read(buffer) if np: @@ -731,6 +694,7 @@ class 
Numpy1DArraySerializer(Serializer): super().__init__(fory, ftype) self.dtype = dtype self.itemsize, self.format, self.typecode, self.type_id = _np_dtypes_dict[self.dtype] + self._serializer = ReduceSerializer(fory, np.ndarray) def xwrite(self, buffer, value): assert value.itemsize == self.itemsize @@ -752,11 +716,10 @@ class Numpy1DArraySerializer(Serializer): return np.frombuffer(data, dtype=self.dtype) def write(self, buffer, value): - buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + self._serializer.write(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + return self._serializer.read(buffer) class NDArraySerializer(Serializer): @@ -775,11 +738,32 @@ class NDArraySerializer(Serializer): raise NotImplementedError("Multi-dimensional array not supported currently") def write(self, buffer, value): - buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + # Serialize numpy ND array using native format + dtype = value.dtype + fory = self.fory + fory.serialize_ref(buffer, dtype) + buffer.write_varuint32(len(value.shape)) + for dim in value.shape: + buffer.write_varuint32(dim) + if dtype.kind == "O": + buffer.write_varint32(len(value)) + for item in value: + fory.serialize_ref(buffer, item) + else: + data = value.tobytes() + buffer.write_bytes_and_size(data) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + fory = self.fory + dtype = fory.deserialize_ref(buffer) + ndim = buffer.read_varuint32() + shape = tuple(buffer.read_varuint32() for _ in range(ndim)) + if dtype.kind == "O": + length = buffer.read_varint32() + items = [fory.deserialize_ref(buffer) for _ in range(length)] + return np.array(items, dtype=object) + data = buffer.read_bytes_and_size() + return np.frombuffer(data, dtype=dtype).reshape(shape) class BytesSerializer(CrossLanguageCompatibleSerializer): @@ -886,39 +870,58 @@ class 
ReduceSerializer(CrossLanguageCompatibleSerializer): # Handle different __reduce__ return formats if isinstance(reduce_result, str): # Case 1: Just a global name (simple case) - self.fory.serialize_ref(buffer, ("global", reduce_result, None, None, None)) + reduce_data = ("global", reduce_result) elif isinstance(reduce_result, tuple): if len(reduce_result) == 2: # Case 2: (callable, args) callable_obj, args = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, None, None)) + reduce_data = ("callable", callable_obj, args) elif len(reduce_result) == 3: # Case 3: (callable, args, state) callable_obj, args, state = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, state, None)) + reduce_data = ("callable", callable_obj, args, state) elif len(reduce_result) == 4: # Case 4: (callable, args, state, listitems) callable_obj, args, state, listitems = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, state, listitems)) + reduce_data = ("callable", callable_obj, args, state, listitems) elif len(reduce_result) == 5: # Case 5: (callable, args, state, listitems, dictitems) callable_obj, args, state, listitems, dictitems = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, state, listitems, dictitems)) + reduce_data = ("callable", callable_obj, args, state, listitems, dictitems) else: raise ValueError(f"Invalid __reduce__ result length: {len(reduce_result)}") else: raise ValueError(f"Invalid __reduce__ result type: {type(reduce_result)}") + buffer.write_varuint32(len(reduce_data)) + fory = self.fory + for item in reduce_data: + fory.serialize_ref(buffer, item) def read(self, buffer): - reduce_data = self.fory.deserialize_ref(buffer) + reduce_data_num_items = buffer.read_varuint32() + assert reduce_data_num_items <= 6, buffer + reduce_data = [None] * 6 + fory = self.fory + for i in range(reduce_data_num_items): + reduce_data[i] = fory.deserialize_ref(buffer) if 
reduce_data[0] == "global": # Case 1: Global name global_name = reduce_data[1] # Import and return the global object - module_name, obj_name = global_name.rsplit(".", 1) - module = __import__(module_name, fromlist=[obj_name]) - return getattr(module, obj_name) + if "." in global_name: + module_name, obj_name = global_name.rsplit(".", 1) + module = __import__(module_name, fromlist=[obj_name]) + return getattr(module, obj_name) + else: + # Handle case where global_name doesn't contain a dot + # This might be a built-in type or a simple name + try: + import builtins + + return getattr(builtins, global_name) + except AttributeError: + raise ValueError(f"Cannot resolve global name: {global_name}") elif reduce_data[0] == "callable": # Case 2-5: Callable with args and optional state/items callable_obj = reduce_data[1] @@ -977,33 +980,41 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer): def _serialize_function(self, buffer, func): """Serialize a function by capturing all its components.""" # Get function metadata - is_method = hasattr(func, "__self__") - if is_method: + instance = getattr(func, "__self__", None) + if instance is not None and not inspect.ismodule(instance): # Handle bound methods - self_obj = func.__self__ + self_obj = instance func_name = func.__name__ # Serialize as a tuple (is_method, self_obj, method_name) - buffer.write_bool(True) # is a method + buffer.write_int8(0) # is a method # For the 'self' object, we need to use fory's serialization self.fory.serialize_ref(buffer, self_obj) buffer.write_string(func_name) return + import types # Regular function or lambda code = func.__code__ name = func.__name__ - defaults = func.__defaults__ - closure = func.__closure__ - globals_dict = func.__globals__ module = func.__module__ qualname = func.__qualname__ + if "<locals>" not in qualname and module != "__main__": + buffer.write_int8(1) # Not a method + buffer.write_string(name) + buffer.write_string(module) + return + # Serialize function 
metadata - buffer.write_bool(False) # Not a method + buffer.write_int8(2) # Not a method buffer.write_string(name) buffer.write_string(module) buffer.write_string(qualname) + defaults = func.__defaults__ + closure = func.__closure__ + globals_dict = func.__globals__ + # Instead of trying to serialize the code object in parts, use marshal # which is specifically designed for code objects marshalled_code = marshal.dumps(code) @@ -1071,16 +1082,21 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer): def _deserialize_function(self, buffer): """Deserialize a function from its components.""" - import sys # Check if it's a method - is_method = buffer.read_bool() - if is_method: + func_type_id = buffer.read_int8() + if func_type_id == 0: # Handle bound methods self_obj = self.fory.deserialize_ref(buffer) method_name = buffer.read_string() return getattr(self_obj, method_name) + if func_type_id == 1: + name = buffer.read_string() + module = buffer.read_string() + mod = importlib.import_module(module) + return getattr(mod, name) + # Regular function or lambda name = buffer.read_string() module = buffer.read_string() @@ -1128,7 +1144,7 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer): # Create a globals dictionary with module's globals as the base func_globals = {} try: - mod = sys.modules.get(module) + mod = importlib.import_module(module) if mod: func_globals.update(mod.__dict__) except (KeyError, AttributeError): @@ -1156,12 +1172,10 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer): return func def xwrite(self, buffer, value): - """Serialize a function for cross-language compatibility.""" - self._serialize_function(buffer, value) + raise NotImplementedError() def xread(self, buffer): - """Deserialize a function for cross-language compatibility.""" - return self._deserialize_function(buffer) + raise NotImplementedError() def write(self, buffer, value): """Serialize a function for Python-only mode.""" @@ -1172,20 +1186,56 @@ class 
FunctionSerializer(CrossLanguageCompatibleSerializer): return self._deserialize_function(buffer) -class PickleSerializer(Serializer): - PICKLE_TYPE_ID = 96 +class NativeFuncMethodSerializer(Serializer): + def write(self, buffer, func): + name = func.__name__ + buffer.write_string(name) + obj = getattr(func, "__self__", None) + if obj is None or inspect.ismodule(obj): + buffer.write_bool(True) + module = func.__module__ + buffer.write_string(module) + else: + buffer.write_bool(False) + self.fory.serialize_ref(buffer, obj) - def xwrite(self, buffer, value): - raise NotImplementedError + def read(self, buffer): + name = buffer.read_string() + if buffer.read_bool(): + module = buffer.read_string() + mod = importlib.import_module(module) + return getattr(mod, name) + else: + obj = self.fory.deserialize_ref(buffer) + return getattr(obj, name) - def xread(self, buffer): - raise NotImplementedError + +class MethodSerializer(Serializer): + """Serializer for bound method objects.""" + + def __init__(self, fory, cls): + super().__init__(fory, cls) + self.cls = cls def write(self, buffer, value): - self.fory.handle_unsupported_write(buffer, value) + # Serialize bound method as (instance, method_name) + instance = value.__self__ + method_name = value.__func__.__name__ + + self.fory.serialize_ref(buffer, instance) + buffer.write_string(method_name) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + instance = self.fory.deserialize_ref(buffer) + method_name = buffer.read_string() + + return getattr(instance, method_name) + + def xwrite(self, buffer, value): + return self.write(buffer, value) + + def xread(self, buffer): + return self.read(buffer) class ObjectSerializer(Serializer): @@ -1245,3 +1295,17 @@ class ComplexObjectSerializer(DataClassSerializer): stacklevel=2, ) return DataClassSerializer(fory, clz, xlang=True) + + +class UnsupportedSerializer(Serializer): + def write(self, buffer, value): + self.fory.handle_unsupported_write(value) + + def 
read(self, buffer): + return self.fory.handle_unsupported_read(buffer) + + def xwrite(self, buffer, value): + raise NotImplementedError(f"{self.type_} is not supported for xwrite") + + def xread(self, buffer): + raise NotImplementedError(f"{self.type_} is not supported for xread") diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index 6df8875b4..6ba56c865 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -308,37 +308,6 @@ def ser_de(fory, obj): return fory.deserialize(binary) -def test_pickle(): - buf = Buffer.allocate(32) - pickler = pickle.Pickler(buf) - pickler.dump(b"abc") - buf.write_int32(-1) - pickler.dump("abcd") - assert buf.writer_index - 4 == len(pickle.dumps(b"abc")) + len(pickle.dumps("abcd")) - print(f"writer_index {buf.writer_index}") - - bytes_io_ = io.BytesIO(buf) - unpickler = pickle.Unpickler(bytes_io_) - assert unpickler.load() == b"abc" - bytes_io_.seek(bytes_io_.tell() + 4) - assert unpickler.load() == "abcd" - print(f"reader_index {buf.reader_index} {bytes_io_.tell()}") - - if pa: - pa_buf = pa.BufferReader(buf) - unpickler = pickle.Unpickler(pa_buf) - assert unpickler.load() == b"abc" - pa_buf.seek(pa_buf.tell() + 4) - assert unpickler.load() == "abcd" - print(f"reader_index {buf.reader_index} {pa_buf.tell()} {buf.reader_index}") - - unpickler = pickle.Unpickler(buf) - assert unpickler.load() == b"abc" - buf.reader_index = buf.reader_index + 4 - assert unpickler.load() == "abcd" - print(f"reader_index {buf.reader_index}") - - @require_pyarrow def test_serialize_arrow(): record_batch = create_record_batch(10000) @@ -454,13 +423,16 @@ def test_register_type(): assert isinstance(fory.deserialize(fory.serialize(A.B.C())), A.B.C) -def test_pickle_fallback(): +def test_np_types(): fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) o1 = [1, True, np.dtype(np.int32)] data1 = fory.serialize(o1) new_o1 = 
fory.deserialize(data1) assert o1 == new_o1 + +def test_pandas_dataframe(): + fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) df = pd.DataFrame({"a": list(range(10))}) df2 = fory.deserialize(fory.serialize(df)) assert df2.equals(df) @@ -545,19 +517,6 @@ def test_duplicate_serialize(): assert ser_de(fory, EnumClass.E4) == EnumClass.E4 -@dataclass(unsafe_hash=True) -class CacheClass1: - f1: int - - -def test_cache_serializer(): - fory = Fory(language=Language.PYTHON, ref_tracking=True) - fory.register_type(CacheClass1, serializer=pyfory.PickleStrongCacheSerializer(fory)) - assert ser_de(fory, CacheClass1(1)) == CacheClass1(1) - fory.register_type(CacheClass1, serializer=pyfory.PickleCacheSerializer(fory)) - assert ser_de(fory, CacheClass1(1)) == CacheClass1(1) - - def test_pandas_range_index(): fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) fory.register_type(pd.RangeIndex, serializer=pyfory.PandasRangeIndexSerializer(fory)) diff --git a/python/pyproject.toml b/python/pyproject.toml index 58b81674e..81f9253db 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -50,7 +50,6 @@ classifiers = [ ] keywords = ["fory", "serialization", "multi-language", "fast", "row-format", "jit", "codegen", "polymorphic", "zero-copy"] dependencies = [ - "cloudpickle", ] [project.optional-dependencies] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
