This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 1305a6e9 refactor(python): Unify DataClassSerializer and ComplexObjectSerializer (#2389)
1305a6e9 is described below

commit 1305a6e9bd6331a78faea66b3c7bf7a53e84697d
Author: Emre Şafak <3928300+esa...@users.noreply.github.com>
AuthorDate: Sun Jul 6 04:58:22 2025 -0400

    refactor(python): Unify DataClassSerializer and ComplexObjectSerializer (#2389)
    
    ## What does this PR do?
    
    - Extended DataClassSerializer to support both Python native and xlang
    serialization modes.
    - ComplexObjectSerializer is now a deprecated alias for
    DataClassSerializer (in xlang mode).
    - Added tests for DataClassSerializer in xlang mode.
    
    ## Related issues
    
    Completes #2157
    
    ## Does this PR introduce any user-facing change?
    
    Not yet: we deprecate `ComplexObjectSerializer` (rather than removing it)
    in order to give users time to migrate.
    
    - [ ] Does this PR introduce any public API change?
    - [ ] Does this PR introduce any binary protocol compatibility change?
    
    ---------
    
    Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
    Co-authored-by: Emre Şafak <esa...@users.noreply.github.com>
---
 python/pyfory/_struct.py           | 92 +++++++++++---------------------------
 python/pyfory/serializer.py        | 86 ++++++++++++++++++++++++++++++-----
 python/pyfory/tests/test_struct.py | 73 ++++++++++++++++++++++++++++++
 3 files changed, 173 insertions(+), 78 deletions(-)

diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py
index 92e33b53..0746b66f 100644
--- a/python/pyfory/_struct.py
+++ b/python/pyfory/_struct.py
@@ -18,16 +18,8 @@
 import datetime
 import enum
 import logging
-import typing
-
-from pyfory.buffer import Buffer
-from pyfory.error import TypeNotCompatibleError
-from pyfory.serializer import (
-    ListSerializer,
-    MapSerializer,
-    PickleSerializer,
-    Serializer,
-)
+
+from pyfory.serializer import Serializer
 from pyfory.type import (
     TypeVisitor,
     infer_field,
@@ -81,11 +73,15 @@ class ComplexTypeVisitor(TypeVisitor):
         self.fory = fory
 
     def visit_list(self, field_name, elem_type, types_path=None):
+        from pyfory.serializer import ListSerializer  # Local import
+
         # Infer type recursively for type such as List[Dict[str, str]]
         elem_serializer = infer_field("item", elem_type, self, 
types_path=types_path)
         return ListSerializer(self.fory, list, elem_serializer)
 
     def visit_dict(self, field_name, key_type, value_type, types_path=None):
+        from pyfory.serializer import MapSerializer  # Local import
+
         # Infer type recursively for type such as Dict[str, Dict[str, str]]
         key_serializer = infer_field("key", key_type, self, 
types_path=types_path)
         value_serializer = infer_field("value", value_type, self, 
types_path=types_path)
@@ -95,6 +91,8 @@ class ComplexTypeVisitor(TypeVisitor):
         return None
 
     def visit_other(self, field_name, type_, types_path=None):
+        from pyfory.serializer import PickleSerializer  # Local import
+
         if is_subclass(type_, enum.Enum):
             return self.fory.type_resolver.get_serializer(type_)
         if type_ not in basic_types and not is_py_array_type(type_):
@@ -174,65 +172,23 @@ def _sort_fields(type_resolver, field_names, serializers):
     return [t[1] for t in all_types], [t[2] for t in all_types]
 
 
-class ComplexObjectSerializer(Serializer):
-    def __init__(self, fory, clz):
-        super().__init__(fory, clz)
-        self._type_hints = typing.get_type_hints(clz)
-        self._field_names = sorted(self._type_hints.keys())
-        self._serializers = [None] * len(self._field_names)
-        visitor = ComplexTypeVisitor(fory)
-        for index, key in enumerate(self._field_names):
-            serializer = infer_field(key, self._type_hints[key], visitor, 
types_path=[])
-            self._serializers[index] = serializer
-        self._serializers, self._field_names = _sort_fields(
-            fory.type_resolver, self._field_names, self._serializers
-        )
+import warnings
 
-        from pyfory import Language
+# DataClassSerializer is imported lazily inside ComplexObjectSerializer.__new__ to break the _struct <-> serializer import cycle.
+# ListSerializer, MapSerializer and PickleSerializer are likewise imported locally where used; only Serializer is imported at the top.
 
-        if self.fory.language == Language.PYTHON:
-            logger.warning(
-                "Type of class %s shouldn't be serialized using cross-language 
"
-                "serializer",
-                clz,
-            )
-        self._hash = 0
-
-    def write(self, buffer, value):
-        return self.xwrite(buffer, value)
-
-    def read(self, buffer):
-        return self.xread(buffer)
-
-    def xwrite(self, buffer: Buffer, value):
-        if self._hash == 0:
-            self._hash = _get_hash(self.fory, self._field_names, 
self._type_hints)
-        buffer.write_int32(self._hash)
-        for index, field_name in enumerate(self._field_names):
-            field_value = getattr(value, field_name)
-            serializer = self._serializers[index]
-            self.fory.xserialize_ref(buffer, field_value, 
serializer=serializer)
-
-    def xread(self, buffer):
-        if self._hash == 0:
-            self._hash = _get_hash(self.fory, self._field_names, 
self._type_hints)
-        hash_ = buffer.read_int32()
-        if hash_ != self._hash:
-            raise TypeNotCompatibleError(
-                f"Hash {hash_} is not consistent with {self._hash} "
-                f"for type {self.type_}",
-            )
-        obj = self.type_.__new__(self.type_)
-        self.fory.ref_resolver.reference(obj)
-        for index, field_name in enumerate(self._field_names):
-            serializer = self._serializers[index]
-            field_value = self.fory.xdeserialize_ref(buffer, 
serializer=serializer)
-            setattr(
-                obj,
-                field_name,
-                field_value,
-            )
-        return obj
+
+class ComplexObjectSerializer(Serializer):
+    def __new__(cls, fory, clz):
+        from pyfory.serializer import DataClassSerializer  # Local import
+
+        warnings.warn(
+            "`ComplexObjectSerializer` is deprecated and will be removed in a 
future version. "
+            "Use `DataClassSerializer(fory, clz, xlang=True)` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return DataClassSerializer(fory, clz, xlang=True)
 
 
 class StructHashVisitor(TypeVisitor):
@@ -265,6 +221,8 @@ class StructHashVisitor(TypeVisitor):
         self._hash = self._compute_field_hash(self._hash, hash_value)
 
     def visit_other(self, field_name, type_, types_path=None):
+        from pyfory.serializer import PickleSerializer  # Local import
+
         typeinfo = self.fory.type_resolver.get_typeinfo(type_, create=False)
         if typeinfo is None:
             id_ = 0
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 6437c62e..fe466ced 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -32,6 +32,7 @@ from pyfory.codegen import (
 from pyfory.error import TypeNotCompatibleError
 from pyfory.lib.collection import WeakIdentityKeyDictionary
 from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
+from pyfory import Language
 
 try:
     import numpy as np
@@ -109,6 +110,7 @@ from pyfory.type import (
     Float32NDArrayType,
     Float64NDArrayType,
     TypeId,
+    infer_field,  # Added infer_field
 )
 
 
@@ -283,22 +285,52 @@ _ENABLE_FORY_PYTHON_JIT = 
os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower
     "1",
 )
 
+# Deferred import: pyfory._struct imports Serializer from this module at load time,
+# so this import must come after Serializer is defined. DataClassSerializer below
+# uses _get_hash, _sort_fields and ComplexTypeVisitor in xlang mode.
+from pyfory._struct import _get_hash, _sort_fields, ComplexTypeVisitor
+
 
 class DataClassSerializer(Serializer):
-    def __init__(self, fory, clz: type):
+    def __init__(self, fory, clz: type, xlang: bool = False):
         super().__init__(fory, clz)
+        self._xlang = xlang
         # This will get superclass type hints too.
         self._type_hints = typing.get_type_hints(clz)
         self._field_names = sorted(self._type_hints.keys())
         self._has_slots = hasattr(clz, "__slots__")
-        # TODO compute hash
-        self._hash = len(self._field_names)
-        self._generated_write_method = self._gen_write_method()
-        self._generated_read_method = self._gen_read_method()
-        if _ENABLE_FORY_PYTHON_JIT:
-            # don't use `__slots__`, which will make instance method readonly
-            self.write = self._gen_write_method()
-            self.read = self._gen_read_method()
+
+        if self._xlang:
+            self._serializers = [None] * len(self._field_names)
+            visitor = ComplexTypeVisitor(fory)
+            for index, key in enumerate(self._field_names):
+                # Changed from self.fory.infer_field to infer_field
+                serializer = infer_field(
+                    key, self._type_hints[key], visitor, types_path=[]
+                )
+                self._serializers[index] = serializer
+            self._serializers, self._field_names = _sort_fields(
+                fory.type_resolver, self._field_names, self._serializers
+            )
+            self._hash = 0  # Will be computed on first xwrite/xread
+            if self.fory.language == Language.PYTHON:
+                import logging  # Import here to avoid circular dependency
+
+                logger = logging.getLogger(__name__)
+                logger.warning(
+                    "Type of class %s shouldn't be serialized using 
cross-language "
+                    "serializer",
+                    clz,
+                )
+        else:
+            # TODO compute hash for non-xlang mode more robustly
+            self._hash = len(self._field_names)
+            self._generated_write_method = self._gen_write_method()
+            self._generated_read_method = self._gen_read_method()
+            if _ENABLE_FORY_PYTHON_JIT:
+                # don't use `__slots__`, which will make instance method 
readonly
+                self.write = self._generated_write_method
+                self.read = self._generated_read_method
 
     def _gen_write_method(self):
         context = {}
@@ -415,10 +447,42 @@ class DataClassSerializer(Serializer):
         return obj
 
     def xwrite(self, buffer: Buffer, value):
-        raise NotImplementedError
+        if not self._xlang:
+            raise TypeError(
+                "xwrite can only be called when DataClassSerializer is in 
xlang mode"
+            )
+        if self._hash == 0:
+            self._hash = _get_hash(self.fory, self._field_names, 
self._type_hints)
+        buffer.write_int32(self._hash)
+        for index, field_name in enumerate(self._field_names):
+            field_value = getattr(value, field_name)
+            serializer = self._serializers[index]
+            self.fory.xserialize_ref(buffer, field_value, 
serializer=serializer)
 
     def xread(self, buffer):
-        raise NotImplementedError
+        if not self._xlang:
+            raise TypeError(
+                "xread can only be called when DataClassSerializer is in xlang 
mode"
+            )
+        if self._hash == 0:
+            self._hash = _get_hash(self.fory, self._field_names, 
self._type_hints)
+        hash_ = buffer.read_int32()
+        if hash_ != self._hash:
+            raise TypeNotCompatibleError(
+                f"Hash {hash_} is not consistent with {self._hash} "
+                f"for type {self.type_}",
+            )
+        obj = self.type_.__new__(self.type_)
+        self.fory.ref_resolver.reference(obj)
+        for index, field_name in enumerate(self._field_names):
+            serializer = self._serializers[index]
+            field_value = self.fory.xdeserialize_ref(buffer, 
serializer=serializer)
+            setattr(
+                obj,
+                field_name,
+                field_value,
+            )
+        return obj
 
 
 # Use numpy array or python array module.
diff --git a/python/pyfory/tests/test_struct.py 
b/python/pyfory/tests/test_struct.py
index 7e65acdb..5be25cb4 100644
--- a/python/pyfory/tests/test_struct.py
+++ b/python/pyfory/tests/test_struct.py
@@ -113,3 +113,76 @@ def test_inheritance():
         type(fory.type_resolver.get_serializer(ChildClass1))
         == pyfory.DataClassSerializer
     )
+
+
+@dataclass
+class TestDataClassObject:
+    f_int: int
+    f_float: float
+    f_str: str
+    f_bool: bool
+    f_list: List[int]
+    f_dict: Dict[str, float]
+    f_any: Any
+    f_complex: ComplexObject = None
+
+
+def test_data_class_serializer_xlang():
+    fory = Fory(language=Language.XLANG, ref_tracking=True)
+    fory.register_type(ComplexObject, typename="example.ComplexObject")
+    fory.register_type(TestDataClassObject, 
typename="example.TestDataClassObject")
+
+    complex_data = ComplexObject(
+        f1="nested_str",
+        f5=100,
+        f8=3.14,
+        f10={10: 1.0, 20: 2.0},
+    )
+    obj_original = TestDataClassObject(
+        f_int=123,
+        f_float=45.67,
+        f_str="hello xlang",
+        f_bool=True,
+        f_list=[1, 2, 3, 4, 5],
+        f_dict={"a": 1.1, "b": 2.2},
+        f_any="any_value",
+        f_complex=complex_data,
+    )
+
+    obj_deserialized = ser_de(fory, obj_original)
+
+    assert obj_deserialized == obj_original
+    assert obj_deserialized.f_int == obj_original.f_int
+    assert obj_deserialized.f_float == obj_original.f_float
+    assert obj_deserialized.f_str == obj_original.f_str
+    assert obj_deserialized.f_bool == obj_original.f_bool
+    assert obj_deserialized.f_list == obj_original.f_list
+    assert obj_deserialized.f_dict == obj_original.f_dict
+    assert obj_deserialized.f_any == obj_original.f_any
+    assert obj_deserialized.f_complex == obj_original.f_complex
+    assert (
+        type(fory.type_resolver.get_serializer(TestDataClassObject))
+        == pyfory.DataClassSerializer
+    )
+    # Ensure it's using xlang mode (indirectly, by checking no JIT methods if 
possible,
+    # or by ensuring it was registered with _register_xtype which now uses 
DataClassSerializer(xlang=True)
+    # For now, the registration path check is implicit via Language.XLANG 
usage.
+    # We can also check if the hash is non-zero if it was computed,
+    # or if _serializers attribute exists.
+    serializer_instance = 
fory.type_resolver.get_serializer(TestDataClassObject)
+    assert hasattr(serializer_instance, "_serializers")  # xlang mode creates 
this
+    assert serializer_instance._xlang is True
+
+    # Test with None for complex field
+    obj_with_none_complex = TestDataClassObject(
+        f_int=789,
+        f_float=12.34,
+        f_str="another string",
+        f_bool=False,
+        f_list=[10, 20],
+        f_dict={"x": 7.7, "y": 8.8},
+        f_any=None,
+        f_complex=None,
+    )
+    obj_deserialized_none = ser_de(fory, obj_with_none_complex)
+    assert obj_deserialized_none == obj_with_none_complex


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@fory.apache.org
For additional commands, e-mail: commits-h...@fory.apache.org

Reply via email to