This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new b6963cbee feat(python): add Union type support for xlang serialization (#3059)
b6963cbee is described below

commit b6963cbee5851314e8c97de1a17b1ff754455886
Author: zhan7236 <[email protected]>
AuthorDate: Thu Dec 18 10:56:52 2025 +0800

    feat(python): add Union type support for xlang serialization (#3059)
    
    ## Why?
    
    Python Fory lacks support for `typing.Union` types, which prevents
    cross-language serialization compatibility with other languages like
    C++, Java, and Rust that already support union/variant types. This
    limitation was reported in issue #3029.
    
    Without Union type support, Python users cannot:
    - Serialize union types to communicate with other languages
    - Use type annotations like `Union[int, str]` in dataclass fields
    - Leverage the full xlang serialization capabilities
    
    ## What does this PR do?
    
    This PR adds Union type support to Python Fory through the following changes:
    
    1. **Added TypeId constants** to `python/pyfory/type.py`:
       - `TypeId.UNION = 38` - for tagged union types
       - `TypeId.NONE = 39` - for empty/unit values
    
    2. **Implemented UnionSerializer** in `python/pyfory/_serializer.py`:
       - Serializes Union types by writing variant index + value
       - Supports both Python mode (no type info) and xlang mode (with type info)
       - Properly dispatches to alternative type serializers
       - Handles Optional types (`Union[T, None]`) by filtering out NoneType
    
    3. **Updated type detection** in `python/pyfory/_registry.py`:
       - Automatically detects `typing.Union` types using `typing.get_origin()`
       - Creates appropriate UnionSerializer instances
       - Handles edge cases (empty unions, single-type unions)
    
    4. **Exported UnionSerializer** from `python/pyfory/serializer.py`:
       - Available in both Cython and pure Python modes
    
    5. **Added comprehensive tests** in `python/pyfory/tests/test_union.py`:
       - Basic types (int, str, float)
       - Collection types (list, dict)
       - Optional types
       - Nested unions in dataclasses
       - Cross-language serialization
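
The type-detection step in item 3 can be sketched with the standard `typing` helpers alone. This is a minimal illustration, not the code from the commit: `classify_union` and its string labels are hypothetical names chosen for the example, and only the NoneType filtering mirrors what the PR describes.

```python
import typing
from typing import Optional, Union


def classify_union(tp):
    """Return (kind, alternatives) for a type annotation.

    kind is one of "not_union", "none_only", "optional", "union".
    These labels are illustrative and are not part of pyfory.
    """
    if typing.get_origin(tp) is not Union:
        return "not_union", [tp]
    # Drop NoneType, mirroring how the commit filters it out of the alternatives.
    alternatives = [arg for arg in typing.get_args(tp) if arg is not type(None)]
    if not alternatives:
        return "none_only", []
    if len(alternatives) == 1:
        # Optional[T] reduces to the single remaining type T.
        return "optional", alternatives
    return "union", alternatives
```

For instance, `classify_union(Optional[int])` reports a single alternative `int`, while `classify_union(Union[int, str, None])` reports two. Note that on Python 3.10 through 3.13 the `int | str` spelling reports `types.UnionType` as its origin, so a check against `typing.Union` alone would probably not match that form.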
    
    The implementation follows the same binary protocol as C++ variant
    serialization:
    1. Write variant index (varuint32)
    2. In xlang mode, write type info for the active alternative
    3. Write the value using the alternative's serializer
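
The framing in the three steps above can be sketched in plain Python. This is an illustration under stated assumptions, not Fory's buffer implementation: the varuint32 byte layout shown here (7 data bits per byte, low bits first, high bit as continuation flag) is assumed, the xlang type-info step is omitted, and the payload is treated as opaque bytes that an alternative's serializer has already produced. All function names are hypothetical.

```python
def write_varuint32(out: bytearray, value: int) -> None:
    """Append value using 7 data bits per byte, low bits first (assumed layout)."""
    if not 0 <= value < 2**32:
        raise ValueError("value out of range for uint32")
    while value >= 0x80:
        out.append((value & 0x7F) | 0x80)  # set the continuation bit
        value >>= 7
    out.append(value)


def read_varuint32(data: bytes, pos: int = 0):
    """Return (value, next_pos) decoded from data starting at pos."""
    result = 0
    shift = 0
    while True:
        byte = data[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def encode_union(index: int, payload: bytes) -> bytes:
    """Frame a union value as the variant index followed by the encoded payload."""
    out = bytearray()
    write_varuint32(out, index)
    out += payload
    return bytes(out)


def decode_union(data: bytes, alternative_count: int):
    """Return (index, payload), rejecting an index outside the known alternatives."""
    index, pos = read_varuint32(data)
    if index >= alternative_count:
        raise ValueError(f"Union index out of bounds: {index}")
    return index, data[pos:]
```

With at most 127 alternatives the index costs a single byte, which is why the variant tag adds little overhead in this layout.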
    
    ## Related issues
    
    - Fixes #3029
    
    ## Does this PR introduce any user-facing change?
    
    - [x] Does this PR introduce any public API change?
      - **Yes**: Adds support for `typing.Union` types in serialization. Users
        can now use Union types in their code and they will be automatically
        serialized/deserialized.
      - This is backward compatible - existing code continues to work.
    
    - [x] Does this PR introduce any binary protocol compatibility change?
      - **No**: The implementation uses the existing TypeId.UNION (38) and
        TypeId.NONE (39) defined in the specification.
      - The binary protocol follows the same format as C++/Java/Rust
        implementations.
      - Existing serialized data is not affected.
    
    ## Benchmark
    
    This PR does not have a performance impact on existing functionality:
    - Union type handling only activates when Union types are actually used
    - The implementation uses efficient type checking via `isinstance()`
    - Variant index is encoded using varuint32 (compact encoding)
    - No changes to existing serializers or hot paths
    
    For Union types specifically, the overhead is minimal:
    - One varuint32 write/read for the variant index
    - One type dispatch (same as normal polymorphic serialization)
    - No additional allocations or copies
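
The `isinstance()` dispatch mentioned above picks the first alternative that matches, so the order of the alternatives matters when one type is a subclass of another. The sketch below is a standalone illustration of that first-match rule; `pick_alternative` is a hypothetical helper and not a function in pyfory.

```python
def pick_alternative(value, alternative_types):
    """Return the index of the first alternative that value is an instance of."""
    for index, alt_type in enumerate(alternative_types):
        if isinstance(value, alt_type):
            return index
    raise TypeError(f"{value!r} matches no alternative in {alternative_types}")
```

Because `bool` is a subclass of `int`, `pick_alternative(True, [int, bool])` selects index 0, the `int` alternative. Listing the more specific type first, as in `[bool, int]`, is one way to avoid that.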
---
 python/pyfory/_registry.py        |  19 ++++
 python/pyfory/_serializer.py      |  98 ++++++++++++++++++
 python/pyfory/serializer.py       |   2 +
 python/pyfory/tests/test_union.py | 202 ++++++++++++++++++++++++++++++++++++++
 python/pyfory/type.py             |   7 ++
 5 files changed, 328 insertions(+)

diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py
index f37c17830..5efa3a58b 100644
--- a/python/pyfory/_registry.py
+++ b/python/pyfory/_registry.py
@@ -23,6 +23,7 @@ import functools
 import logging
 import pickle
 import types
+import typing
 from typing import TypeVar, Union
 from enum import Enum
 
@@ -67,6 +68,7 @@ from pyfory.serializer import (
     UnsupportedSerializer,
     NativeFuncMethodSerializer,
     PickleBufferSerializer,
+    UnionSerializer,
 )
 from pyfory.meta.metastring import MetaStringEncoder, MetaStringDecoder
 from pyfory.meta.meta_compressor import DeflaterMetaCompressor
@@ -540,6 +542,23 @@ class TypeResolver:
         return typeinfo
 
     def _create_serializer(self, cls):
+        # Check if it's a Union type first
+        origin = typing.get_origin(cls) if hasattr(typing, "get_origin") else getattr(cls, "__origin__", None)
+        if origin is typing.Union:
+            # Extract alternative types from Union
+            args = typing.get_args(cls) if hasattr(typing, "get_args") else getattr(cls, "__args__", ())
+            # Filter out NoneType as it's handled separately via ref tracking
+            alternative_types = [arg for arg in args if arg is not type(None)]
+            if len(alternative_types) == 0:
+                # Union with only None is equivalent to NoneType
+                return NoneSerializer(self.fory)
+            elif len(alternative_types) == 1:
+                # Optional[T] should use the serializer for T
+                return self.get_serializer(alternative_types[0])
+            else:
+                # Real union with multiple alternatives
+                return UnionSerializer(self.fory, cls, alternative_types)
+
         for clz in cls.__mro__:
             type_info = self._types_info.get(clz)
             if type_info and type_info.serializer and type_info.serializer.support_subclass():
diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py
index 17c9d9ed2..2ce347595 100644
--- a/python/pyfory/_serializer.py
+++ b/python/pyfory/_serializer.py
@@ -749,3 +749,101 @@ class SliceSerializer(Serializer):
 
     def xread(self, buffer):
         raise NotImplementedError
+
+
+class UnionSerializer(Serializer):
+    """
+    Serializer for typing.Union types.
+
+    Serializes a Union by storing:
+    1. The index of the active alternative (as varuint32)
+    2. The value of the active alternative
+
+    This allows the deserializer to determine which alternative to use
+    and forward to the appropriate serializer.
+    """
+
+    __slots__ = ("alternative_types", "alternative_serializers", "type_resolver")
+
+    def __init__(self, fory, type_, alternative_types):
+        super().__init__(fory, type_)
+        self.alternative_types = alternative_types
+        self.type_resolver = fory.type_resolver
+        self.alternative_serializers = []
+        for alt_type in alternative_types:
+            serializer = fory.type_resolver.get_serializer(alt_type)
+            self.alternative_serializers.append((alt_type, serializer))
+
+    def write(self, buffer, value):
+        # Find which alternative type matches the value
+        active_index = None
+        active_serializer = None
+
+        for i, (alt_type, serializer) in enumerate(self.alternative_serializers):
+            if isinstance(value, alt_type):
+                active_index = i
+                active_serializer = serializer
+                break
+
+        if active_index is None:
+            raise TypeError(f"Value {value} of type {type(value)} doesn't match any alternative in Union{self.alternative_types}")
+
+        # Write the active variant index
+        buffer.write_varuint32(active_index)
+
+        # Write the alternative's value (no type info in Python mode)
+        active_serializer.write(buffer, value)
+
+    def read(self, buffer):
+        # Read the stored variant index
+        stored_index = buffer.read_varuint32()
+
+        # Validate index is within bounds
+        if stored_index >= len(self.alternative_serializers):
+            raise ValueError(f"Union index out of bounds: {stored_index} (max: {len(self.alternative_serializers) - 1})")
+
+        # Dispatch to the appropriate alternative's serializer
+        _, serializer = self.alternative_serializers[stored_index]
+        return serializer.read(buffer)
+
+    def xwrite(self, buffer, value):
+        # Find which alternative type matches the value
+        active_index = None
+        active_serializer = None
+        active_type = None
+
+        for i, (alt_type, serializer) in enumerate(self.alternative_serializers):
+            if isinstance(value, alt_type):
+                active_index = i
+                active_serializer = serializer
+                active_type = alt_type
+                break
+
+        if active_index is None:
+            raise TypeError(f"Value {value} of type {type(value)} doesn't match any alternative in Union{self.alternative_types}")
+
+        # Write the active variant index
+        buffer.write_varuint32(active_index)
+
+        # In xlang mode, write type info for the alternative
+        # Get the typeinfo for the alternative type and write it
+        typeinfo = self.type_resolver.get_typeinfo(active_type)
+        self.type_resolver.write_typeinfo(buffer, typeinfo)
+
+        # Write the alternative's value data
+        active_serializer.xwrite(buffer, value)
+
+    def xread(self, buffer):
+        # Read the stored variant index
+        stored_index = buffer.read_varuint32()
+
+        # Validate index is within bounds
+        if stored_index >= len(self.alternative_serializers):
+            raise ValueError(f"Union index out of bounds: {stored_index} (max: {len(self.alternative_serializers) - 1})")
+
+        # In xlang mode, read type info for the alternative
+        typeinfo = self.type_resolver.read_typeinfo(buffer)
+
+        # Dispatch to the appropriate alternative's serializer
+        # Use typeinfo's serializer which may be more specific than what we registered
+        return typeinfo.serializer.xread(buffer)
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 5c94ba502..1cbf9fc26 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -78,6 +78,7 @@ if ENABLE_FORY_CYTHON_SERIALIZATION:
         EnumSerializer,
         SliceSerializer,
     )
+    from pyfory._serializer import UnionSerializer  # noqa: F401
 else:
     from pyfory._serializer import (  # noqa: F401 # pylint: disable=unused-import
         Serializer,
@@ -100,6 +101,7 @@ else:
         MapSerializer,
         EnumSerializer,
         SliceSerializer,
+        UnionSerializer,
     )
 
 from pyfory.type import (
diff --git a/python/pyfory/tests/test_union.py b/python/pyfory/tests/test_union.py
new file mode 100644
index 000000000..1bfd3745a
--- /dev/null
+++ b/python/pyfory/tests/test_union.py
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import dataclasses
+from typing import Union
+
+from pyfory import Fory
+
+
+def test_union_basic_types():
+    """Test Union with basic types like int and str"""
+    fory = Fory()
+
+    # Test with int value
+    value_int: Union[int, str] = 42
+    serialized = fory.serialize(value_int)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == 42
+    assert type(deserialized) is int
+
+    # Test with str value
+    value_str: Union[int, str] = "hello"
+    serialized = fory.serialize(value_str)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == "hello"
+    assert type(deserialized) is str
+
+
+def test_union_multiple_types():
+    """Test Union with more than two types"""
+    fory = Fory()
+
+    # Test with int
+    value1: Union[int, str, float] = 123
+    serialized = fory.serialize(value1)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == 123
+    assert type(deserialized) is int
+
+    # Test with str
+    value2: Union[int, str, float] = "test"
+    serialized = fory.serialize(value2)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == "test"
+    assert type(deserialized) is str
+
+    # Test with float
+    value3: Union[int, str, float] = 3.14
+    serialized = fory.serialize(value3)
+    deserialized = fory.deserialize(serialized)
+    assert abs(deserialized - 3.14) < 0.0001
+    assert type(deserialized) is float
+
+
+def test_union_with_collections():
+    """Test Union with collection types"""
+    fory = Fory()
+
+    # Test with list
+    value_list: Union[list, dict] = [1, 2, 3]
+    serialized = fory.serialize(value_list)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == [1, 2, 3]
+    assert type(deserialized) is list
+
+    # Test with dict
+    value_dict: Union[list, dict] = {"a": 1, "b": 2}
+    serialized = fory.serialize(value_dict)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == {"a": 1, "b": 2}
+    assert type(deserialized) is dict
+
+
+def test_union_with_optional():
+    """Test Union with Optional (Union[T, None])"""
+    fory = Fory(ref_tracking=True)
+
+    # Test with non-None value
+    value: Union[int, None] = 42
+    serialized = fory.serialize(value)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == 42
+
+    # Test with None value
+    value_none: Union[int, None] = None
+    serialized = fory.serialize(value_none)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized is None
+
+
+def test_union_with_dataclass():
+    """Test Union with dataclass types"""
+
+    @dataclasses.dataclass
+    class Person:
+        name: str
+        age: int
+
+    @dataclasses.dataclass
+    class Company:
+        name: str
+        employees: int
+
+    fory = Fory()
+    fory.register(Person)
+    fory.register(Company)
+
+    # Test with Person
+    person = Person("Alice", 30)
+    value: Union[Person, Company] = person
+    serialized = fory.serialize(value)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == person
+    assert type(deserialized) is Person
+
+    # Test with Company
+    company = Company("TechCorp", 100)
+    value2: Union[Person, Company] = company
+    serialized = fory.serialize(value2)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == company
+    assert type(deserialized) is Company
+
+
+def test_union_nested_in_dataclass():
+    """Test Union type as a field in a dataclass"""
+
+    @dataclasses.dataclass
+    class Container:
+        value: Union[int, str]
+        name: str
+
+    fory = Fory()
+    fory.register(Container)
+
+    # Test with int value
+    obj1 = Container(value=42, name="test1")
+    serialized = fory.serialize(obj1)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized.value == 42
+    assert deserialized.name == "test1"
+    assert type(deserialized.value) is int
+
+    # Test with str value
+    obj2 = Container(value="hello", name="test2")
+    serialized = fory.serialize(obj2)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized.value == "hello"
+    assert deserialized.name == "test2"
+    assert type(deserialized.value) is str
+
+
+def test_union_with_bytes():
+    """Test Union with bytes type"""
+    fory = Fory()
+
+    # Test with bytes
+    value_bytes: Union[bytes, str] = b"hello"
+    serialized = fory.serialize(value_bytes)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == b"hello"
+    assert type(deserialized) is bytes
+
+    # Test with str
+    value_str: Union[bytes, str] = "world"
+    serialized = fory.serialize(value_str)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == "world"
+    assert type(deserialized) is str
+
+
+def test_union_cross_language():
+    """Test Union with cross-language serialization"""
+    fory = Fory(language="xlang")
+
+    # Test with int value
+    value_int: Union[int, str] = 42
+    serialized = fory.serialize(value_int)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == 42
+    assert type(deserialized) is int
+
+    # Test with str value
+    value_str: Union[int, str] = "test"
+    serialized = fory.serialize(value_str)
+    deserialized = fory.deserialize(serialized)
+    assert deserialized == "test"
+    assert type(deserialized) is str
diff --git a/python/pyfory/type.py b/python/pyfory/type.py
index 6df38111b..6faf62c15 100644
--- a/python/pyfory/type.py
+++ b/python/pyfory/type.py
@@ -223,6 +223,10 @@ class TypeId:
     FLOAT32_ARRAY = 36
     # one dimensional float64 array.
     FLOAT64_ARRAY = 37
+    # a tagged union type that can hold one of several alternative types.
+    UNION = 38
+    # represents an empty/unit value with no data (e.g., for empty union alternatives).
+    NONE = 39
 
     # Bound value for range checks (types with id >= BOUND are not internal types).
     BOUND = 64
@@ -467,6 +471,9 @@ def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
         elif origin is dict or origin == typing.Dict:
             key_type, value_type = args
             return visitor.visit_dict(field_name, key_type, value_type, types_path=types_path)
+        elif origin is typing.Union:
+            # Union types are treated as "other" types and handled by UnionSerializer
+            return visitor.visit_other(field_name, type_, types_path=types_path)
         else:
             raise TypeError(f"Collection types should be {list, dict} instead of {type_}")
     else:

