This is an automated email from the ASF dual-hosted git repository. chaokunyang pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push: new 1305a6e9 refactor(python): Unify DataClassSerializer and ComplexObjectSerializer (#2389) 1305a6e9 is described below commit 1305a6e9bd6331a78faea66b3c7bf7a53e84697d Author: Emre Şafak <3928300+esa...@users.noreply.github.com> AuthorDate: Sun Jul 6 04:58:22 2025 -0400 refactor(python): Unify DataClassSerializer and ComplexObjectSerializer (#2389) ## What does this PR do? - Extended DataClassSerializer to support both Python native and xlang serialization modes. - ComplexObjectSerializer is now a deprecated alias for DataClassSerializer (in xlang mode). - Added tests for DataClassSerializer in xlang mode. ## Related issues Completes #2157 ## Does this PR introduce any user-facing change? Not yet: we deprecate `ComplexObjectSerializer` in order to give users time to migrate. - [ ] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: Emre Şafak <esa...@users.noreply.github.com> --- python/pyfory/_struct.py | 92 +++++++++++--------------------------- python/pyfory/serializer.py | 86 ++++++++++++++++++++++++++++++----- python/pyfory/tests/test_struct.py | 73 ++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 78 deletions(-) diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py index 92e33b53..0746b66f 100644 --- a/python/pyfory/_struct.py +++ b/python/pyfory/_struct.py @@ -18,16 +18,8 @@ import datetime import enum import logging -import typing - -from pyfory.buffer import Buffer -from pyfory.error import TypeNotCompatibleError -from pyfory.serializer import ( - ListSerializer, - MapSerializer, - PickleSerializer, - Serializer, -) + +from pyfory.serializer import Serializer from pyfory.type import ( TypeVisitor, infer_field, @@ -81,11 +73,15 @@ class ComplexTypeVisitor(TypeVisitor): 
self.fory = fory def visit_list(self, field_name, elem_type, types_path=None): + from pyfory.serializer import ListSerializer # Local import + # Infer type recursively for type such as List[Dict[str, str]] elem_serializer = infer_field("item", elem_type, self, types_path=types_path) return ListSerializer(self.fory, list, elem_serializer) def visit_dict(self, field_name, key_type, value_type, types_path=None): + from pyfory.serializer import MapSerializer # Local import + # Infer type recursively for type such as Dict[str, Dict[str, str]] key_serializer = infer_field("key", key_type, self, types_path=types_path) value_serializer = infer_field("value", value_type, self, types_path=types_path) @@ -95,6 +91,8 @@ class ComplexTypeVisitor(TypeVisitor): return None def visit_other(self, field_name, type_, types_path=None): + from pyfory.serializer import PickleSerializer # Local import + if is_subclass(type_, enum.Enum): return self.fory.type_resolver.get_serializer(type_) if type_ not in basic_types and not is_py_array_type(type_): @@ -174,65 +172,23 @@ def _sort_fields(type_resolver, field_names, serializers): return [t[1] for t in all_types], [t[2] for t in all_types] -class ComplexObjectSerializer(Serializer): - def __init__(self, fory, clz): - super().__init__(fory, clz) - self._type_hints = typing.get_type_hints(clz) - self._field_names = sorted(self._type_hints.keys()) - self._serializers = [None] * len(self._field_names) - visitor = ComplexTypeVisitor(fory) - for index, key in enumerate(self._field_names): - serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) - self._serializers[index] = serializer - self._serializers, self._field_names = _sort_fields( - fory.type_resolver, self._field_names, self._serializers - ) +import warnings - from pyfory import Language +# Removed DataClassSerializer from here to break the cycle for the alias target. +# Other serializers like ListSerializer, MapSerializer, Serializer are still imported at the top. 
- if self.fory.language == Language.PYTHON: - logger.warning( - "Type of class %s shouldn't be serialized using cross-language " - "serializer", - clz, - ) - self._hash = 0 - - def write(self, buffer, value): - return self.xwrite(buffer, value) - - def read(self, buffer): - return self.xread(buffer) - - def xwrite(self, buffer: Buffer, value): - if self._hash == 0: - self._hash = _get_hash(self.fory, self._field_names, self._type_hints) - buffer.write_int32(self._hash) - for index, field_name in enumerate(self._field_names): - field_value = getattr(value, field_name) - serializer = self._serializers[index] - self.fory.xserialize_ref(buffer, field_value, serializer=serializer) - - def xread(self, buffer): - if self._hash == 0: - self._hash = _get_hash(self.fory, self._field_names, self._type_hints) - hash_ = buffer.read_int32() - if hash_ != self._hash: - raise TypeNotCompatibleError( - f"Hash {hash_} is not consistent with {self._hash} " - f"for type {self.type_}", - ) - obj = self.type_.__new__(self.type_) - self.fory.ref_resolver.reference(obj) - for index, field_name in enumerate(self._field_names): - serializer = self._serializers[index] - field_value = self.fory.xdeserialize_ref(buffer, serializer=serializer) - setattr( - obj, - field_name, - field_value, - ) - return obj + +class ComplexObjectSerializer(Serializer): + def __new__(cls, fory, clz): + from pyfory.serializer import DataClassSerializer # Local import + + warnings.warn( + "`ComplexObjectSerializer` is deprecated and will be removed in a future version. 
" + "Use `DataClassSerializer(fory, clz, xlang=True)` instead.", + DeprecationWarning, + stacklevel=2, + ) + return DataClassSerializer(fory, clz, xlang=True) class StructHashVisitor(TypeVisitor): @@ -265,6 +221,8 @@ class StructHashVisitor(TypeVisitor): self._hash = self._compute_field_hash(self._hash, hash_value) def visit_other(self, field_name, type_, types_path=None): + from pyfory.serializer import PickleSerializer # Local import + typeinfo = self.fory.type_resolver.get_typeinfo(type_, create=False) if typeinfo is None: id_ = 0 diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 6437c62e..fe466ced 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -32,6 +32,7 @@ from pyfory.codegen import ( from pyfory.error import TypeNotCompatibleError from pyfory.lib.collection import WeakIdentityKeyDictionary from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG +from pyfory import Language try: import numpy as np @@ -109,6 +110,7 @@ from pyfory.type import ( Float32NDArrayType, Float64NDArrayType, TypeId, + infer_field, # Added infer_field ) @@ -283,22 +285,52 @@ _ENABLE_FORY_PYTHON_JIT = os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower "1", ) +# Moved from L32 to here, after all Serializer base classes and specific serializers +# like ListSerializer, MapSerializer, PickleSerializer are defined or imported +# and before DataClassSerializer which uses ComplexTypeVisitor from _struct. +from pyfory._struct import _get_hash, _sort_fields, ComplexTypeVisitor + class DataClassSerializer(Serializer): - def __init__(self, fory, clz: type): + def __init__(self, fory, clz: type, xlang: bool = False): super().__init__(fory, clz) + self._xlang = xlang # This will get superclass type hints too. 
self._type_hints = typing.get_type_hints(clz) self._field_names = sorted(self._type_hints.keys()) self._has_slots = hasattr(clz, "__slots__") - # TODO compute hash - self._hash = len(self._field_names) - self._generated_write_method = self._gen_write_method() - self._generated_read_method = self._gen_read_method() - if _ENABLE_FORY_PYTHON_JIT: - # don't use `__slots__`, which will make instance method readonly - self.write = self._gen_write_method() - self.read = self._gen_read_method() + + if self._xlang: + self._serializers = [None] * len(self._field_names) + visitor = ComplexTypeVisitor(fory) + for index, key in enumerate(self._field_names): + # Changed from self.fory.infer_field to infer_field + serializer = infer_field( + key, self._type_hints[key], visitor, types_path=[] + ) + self._serializers[index] = serializer + self._serializers, self._field_names = _sort_fields( + fory.type_resolver, self._field_names, self._serializers + ) + self._hash = 0 # Will be computed on first xwrite/xread + if self.fory.language == Language.PYTHON: + import logging # Import here to avoid circular dependency + + logger = logging.getLogger(__name__) + logger.warning( + "Type of class %s shouldn't be serialized using cross-language " + "serializer", + clz, + ) + else: + # TODO compute hash for non-xlang mode more robustly + self._hash = len(self._field_names) + self._generated_write_method = self._gen_write_method() + self._generated_read_method = self._gen_read_method() + if _ENABLE_FORY_PYTHON_JIT: + # don't use `__slots__`, which will make instance method readonly + self.write = self._generated_write_method + self.read = self._generated_read_method def _gen_write_method(self): context = {} @@ -415,10 +447,42 @@ class DataClassSerializer(Serializer): return obj def xwrite(self, buffer: Buffer, value): - raise NotImplementedError + if not self._xlang: + raise TypeError( + "xwrite can only be called when DataClassSerializer is in xlang mode" + ) + if self._hash == 0: + self._hash 
= _get_hash(self.fory, self._field_names, self._type_hints) + buffer.write_int32(self._hash) + for index, field_name in enumerate(self._field_names): + field_value = getattr(value, field_name) + serializer = self._serializers[index] + self.fory.xserialize_ref(buffer, field_value, serializer=serializer) def xread(self, buffer): - raise NotImplementedError + if not self._xlang: + raise TypeError( + "xread can only be called when DataClassSerializer is in xlang mode" + ) + if self._hash == 0: + self._hash = _get_hash(self.fory, self._field_names, self._type_hints) + hash_ = buffer.read_int32() + if hash_ != self._hash: + raise TypeNotCompatibleError( + f"Hash {hash_} is not consistent with {self._hash} " + f"for type {self.type_}", + ) + obj = self.type_.__new__(self.type_) + self.fory.ref_resolver.reference(obj) + for index, field_name in enumerate(self._field_names): + serializer = self._serializers[index] + field_value = self.fory.xdeserialize_ref(buffer, serializer=serializer) + setattr( + obj, + field_name, + field_value, + ) + return obj # Use numpy array or python array module. 
diff --git a/python/pyfory/tests/test_struct.py b/python/pyfory/tests/test_struct.py index 7e65acdb..5be25cb4 100644 --- a/python/pyfory/tests/test_struct.py +++ b/python/pyfory/tests/test_struct.py @@ -113,3 +113,76 @@ def test_inheritance(): type(fory.type_resolver.get_serializer(ChildClass1)) == pyfory.DataClassSerializer ) + + +@dataclass +class TestDataClassObject: + f_int: int + f_float: float + f_str: str + f_bool: bool + f_list: List[int] + f_dict: Dict[str, float] + f_any: Any + f_complex: ComplexObject = None + + +def test_data_class_serializer_xlang(): + fory = Fory(language=Language.XLANG, ref_tracking=True) + fory.register_type(ComplexObject, typename="example.ComplexObject") + fory.register_type(TestDataClassObject, typename="example.TestDataClassObject") + + complex_data = ComplexObject( + f1="nested_str", + f5=100, + f8=3.14, + f10={10: 1.0, 20: 2.0}, + ) + obj_original = TestDataClassObject( + f_int=123, + f_float=45.67, + f_str="hello xlang", + f_bool=True, + f_list=[1, 2, 3, 4, 5], + f_dict={"a": 1.1, "b": 2.2}, + f_any="any_value", + f_complex=complex_data, + ) + + obj_deserialized = ser_de(fory, obj_original) + + assert obj_deserialized == obj_original + assert obj_deserialized.f_int == obj_original.f_int + assert obj_deserialized.f_float == obj_original.f_float + assert obj_deserialized.f_str == obj_original.f_str + assert obj_deserialized.f_bool == obj_original.f_bool + assert obj_deserialized.f_list == obj_original.f_list + assert obj_deserialized.f_dict == obj_original.f_dict + assert obj_deserialized.f_any == obj_original.f_any + assert obj_deserialized.f_complex == obj_original.f_complex + assert ( + type(fory.type_resolver.get_serializer(TestDataClassObject)) + == pyfory.DataClassSerializer + ) + # Ensure it's using xlang mode (indirectly, by checking no JIT methods if possible, + # or by ensuring it was registered with _register_xtype which now uses DataClassSerializer(xlang=True) + # For now, the registration path check is implicit via 
Language.XLANG usage. + # We can also check if the hash is non-zero if it was computed, + # or if _serializers attribute exists. + serializer_instance = fory.type_resolver.get_serializer(TestDataClassObject) + assert hasattr(serializer_instance, "_serializers") # xlang mode creates this + assert serializer_instance._xlang is True + + # Test with None for complex field + obj_with_none_complex = TestDataClassObject( + f_int=789, + f_float=12.34, + f_str="another string", + f_bool=False, + f_list=[10, 20], + f_dict={"x": 7.7, "y": 8.8}, + f_any=None, + f_complex=None, + ) + obj_deserialized_none = ser_de(fory, obj_with_none_complex) + assert obj_deserialized_none == obj_with_none_complex --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@fory.apache.org For additional commands, e-mail: commits-h...@fory.apache.org