This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 861a42969 feat(java/python): support enum xlang serialization (#2603)
861a42969 is described below
commit 861a429695d846a35d8823133d9bf59425a5fa59
Author: Shawn Yang <[email protected]>
AuthorDate: Thu Sep 11 16:11:05 2025 +0800
feat(java/python): support enum xlang serialization (#2603)
## Why?
Support enum xlang (cross-language) serialization between Java and Python.
## What does this PR do?
<!-- Describe the details of this PR. -->
## Related issues
#2602
#2585
## Does this PR introduce any user-facing change?
<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fory/issues/new/choose) describing the
need to do so, and update the documentation if necessary.
Delete this section if not applicable.
-->
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
<!--
When the PR has an impact on performance, be sure to attach benchmark
data here. (If you don't know whether the PR will have an impact on
performance, you can submit the PR first; if it does have an impact,
the code reviewer will explain it.)
Delete this section if not applicable.
-->
---
.../apache/fory/serializer/ObjectSerializer.java | 17 +++---
.../src/main/java/org/apache/fory/type/Types.java | 13 +++++
.../java/org/apache/fory/CrossLanguageTest.java | 16 ++++--
python/pyfory/_registry.py | 10 ++--
python/pyfory/_serialization.pyx | 56 +++++++++++---------
python/pyfory/_struct.py | 33 ++++++++++--
python/pyfory/meta/typedef.py | 44 +++++++++++-----
python/pyfory/serializer.py | 60 ++++++++++++++++++++--
python/pyfory/tests/test_cross_language.py | 5 +-
python/pyfory/tests/test_struct.py | 19 +++++++
python/pyfory/type.py | 8 +++
11 files changed, 215 insertions(+), 66 deletions(-)
diff --git
a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java
b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java
index 13a7b5afe..0640edd8f 100644
---
a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java
+++
b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java
@@ -43,6 +43,7 @@ import org.apache.fory.type.DescriptorGrouper;
import org.apache.fory.type.Generics;
import org.apache.fory.type.TypeUtils;
import org.apache.fory.type.Types;
+import org.apache.fory.util.ExceptionUtils;
import org.apache.fory.util.Preconditions;
import org.apache.fory.util.record.RecordInfo;
import org.apache.fory.util.record.RecordUtils;
@@ -353,7 +354,7 @@ public final class ObjectSerializer<T> extends
AbstractObjectSerializer<T> {
}
private static int computeFieldHash(int hash, Fory fory, TypeRef<?> typeRef)
{
- int id;
+ int id = 0;
if (typeRef.isSubtypeOf(List.class)) {
// TODO(chaokunyang) add list element type into schema hash
id = Types.LIST;
@@ -365,21 +366,17 @@ public final class ObjectSerializer<T> extends
AbstractObjectSerializer<T> {
TypeResolver resolver =
fory.isCrossLanguage() ? fory.getXtypeResolver() :
fory.getClassResolver();
Class<?> cls = typeRef.getRawType();
- if (ReflectionUtils.isAbstract(cls) || cls.isInterface()) {
- id = 0;
- } else {
- ClassInfo classInfo = resolver.getClassInfo(typeRef.getRawType());
- int xtypeId = classInfo.getXtypeId();
- if (Types.isStructType(xtypeId & 0xff)) {
+ if (!ReflectionUtils.isAbstract(cls) && !cls.isInterface()) {
+ ClassInfo classInfo = resolver.getClassInfo(cls);
+ int xtypeId = id = classInfo.getXtypeId();
+ if (Types.isNamedType(xtypeId & 0xff)) {
id =
TypeUtils.computeStringHash(
classInfo.decodeNamespace() + classInfo.decodeTypeName());
- } else {
- id = Math.abs(xtypeId);
}
}
} catch (Exception e) {
- id = 0;
+ ExceptionUtils.ignore(e);
}
}
long newHash = ((long) hash) * 31 + id;
diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Types.java
b/java/fory-core/src/main/java/org/apache/fory/type/Types.java
index b7e937ba4..c6300b2be 100644
--- a/java/fory-core/src/main/java/org/apache/fory/type/Types.java
+++ b/java/fory-core/src/main/java/org/apache/fory/type/Types.java
@@ -165,6 +165,19 @@ public class Types {
public static final int UNKNOWN = 63;
// Helper methods
+ public static boolean isNamedType(int value) {
+ assert value < 0xff;
+ switch (value) {
+ case NAMED_STRUCT:
+ case NAMED_COMPATIBLE_STRUCT:
+ case NAMED_ENUM:
+ case NAMED_EXT:
+ return true;
+ default:
+ return false;
+ }
+ }
+
public static boolean isStructType(int value) {
assert value < 0xff;
return value == STRUCT
diff --git
a/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
b/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
index bc74fef00..b47b1de89 100644
--- a/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
+++ b/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
@@ -571,7 +571,7 @@ public class CrossLanguageTest extends ForyTestBase {
System.out.println(dataFile.toAbsolutePath());
Files.deleteIfExists(dataFile);
Files.write(dataFile, serialized);
- dataFile.toFile().deleteOnExit();
+ // dataFile.toFile().deleteOnExit();
ImmutableList<String> command =
ImmutableList.of(
PYTHON_EXECUTABLE, "-m", PYTHON_MODULE, testName,
dataFile.toAbsolutePath().toString());
@@ -823,9 +823,15 @@ public class CrossLanguageTest extends ForyTestBase {
String f3;
}
- @Test
- public void testEnumField() throws java.io.IOException {
- Fory fory =
Fory.builder().withLanguage(Language.XLANG).requireClassRegistration(true).build();
+ @Test(dataProvider = "compatible")
+ public void testEnumField(boolean compatible) throws java.io.IOException {
+ Fory fory =
+ Fory.builder()
+ .withLanguage(Language.XLANG)
+ .withCompatibleMode(
+ compatible ? CompatibleMode.COMPATIBLE :
CompatibleMode.SCHEMA_CONSISTENT)
+ .requireClassRegistration(true)
+ .build();
fory.register(EnumTestClass.class, "test.EnumTestClass");
fory.register(EnumFieldStruct.class, "test.EnumFieldStruct");
@@ -834,6 +840,6 @@ public class CrossLanguageTest extends ForyTestBase {
a.f2 = EnumTestClass.BAR;
a.f3 = "abc";
Assert.assertEquals(xserDe(fory, a), a);
- structRoundBack(fory, a, "test_enum_field");
+ structRoundBack(fory, a, "test_enum_field" + (compatible ? "_compatible" :
""));
}
}
diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py
index 6f4ae76e9..da42f43c2 100644
--- a/python/pyfory/_registry.py
+++ b/python/pyfory/_registry.py
@@ -60,6 +60,7 @@ from pyfory.serializer import (
PickleStrongCacheSerializer,
PickleSerializer,
DataClassSerializer,
+ DataClassStubSerializer,
StatefulSerializer,
ReduceSerializer,
FunctionSerializer,
@@ -127,7 +128,7 @@ else:
self.namespace_bytes = namespace_bytes
self.typename_bytes = typename_bytes
self.dynamic_type = dynamic_type
- self.type_def = None
+ self.type_def = type_def
def __repr__(self):
return f"TypeInfo(cls={self.cls}, type_id={self.type_id},
serializer={self.serializer})"
@@ -533,11 +534,8 @@ class TypeResolver:
# Use FunctionSerializer for function types (including lambdas)
serializer = FunctionSerializer(self.fory, cls)
elif dataclasses.is_dataclass(cls):
- if not self.meta_share:
- serializer = DataClassSerializer(self.fory, cls, xlang=not
self.fory.is_py)
- else:
- # lazy create serializer to handle nested struct fields.
- serializer = None
+ # lazy create serializer to handle nested struct fields.
+ serializer = DataClassStubSerializer(self.fory, cls, xlang=not
self.fory.is_py)
elif issubclass(cls, enum.Enum):
serializer = EnumSerializer(self.fory, cls)
elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not
object.__reduce__) or (
diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx
index ad9ec5033..1320a9f4a 100644
--- a/python/pyfory/_serialization.pyx
+++ b/python/pyfory/_serialization.pyx
@@ -534,10 +534,10 @@ cdef class TypeResolver:
self._c_types_info[<uintptr_t> <PyObject *> cls] = <PyObject *>
type_info
self._populate_typeinfo(type_info)
return type_info
-
+
def is_registered_by_name(self, cls):
return self._resolver.is_registered_by_name(cls)
-
+
def is_registered_by_id(self, cls):
return self._resolver.is_registered_by_id(cls)
@@ -565,11 +565,11 @@ cdef class TypeResolver:
cdef:
int32_t type_id = typeinfo.type_id
int32_t internal_type_id = type_id & 0xFF
-
+
if self.meta_share:
self.write_shared_type_meta(buffer, typeinfo)
return
-
+
buffer.write_varuint32(type_id)
if IsNamespacedType(internal_type_id):
self.metastring_resolver.write_meta_string_bytes(buffer,
typeinfo.namespace_bytes)
@@ -578,7 +578,7 @@ cdef class TypeResolver:
cpdef inline TypeInfo read_typeinfo(self, Buffer buffer):
if self.meta_share:
return self.read_shared_type_meta(buffer)
-
+
cdef:
int32_t type_id = buffer.read_varuint32()
if type_id < 0:
@@ -597,7 +597,7 @@ cdef class TypeResolver:
raise ValueError(f"Unexpected type_id {type_id}")
typeinfo = <TypeInfo> typeinfo_ptr
return typeinfo
-
+
cpdef inline TypeInfo get_typeinfo_by_id(self, int32_t type_id):
if type_id >= self._c_registered_id_to_type_info.size() or type_id < 0
or IsNamespacedType(type_id & 0xFF):
raise ValueError(f"Unexpected type_id {type_id}")
@@ -605,11 +605,14 @@ cdef class TypeResolver:
if typeinfo_ptr == NULL:
raise ValueError(f"Unexpected type_id {type_id}")
typeinfo = <TypeInfo> typeinfo_ptr
- return typeinfo
+ return typeinfo
def get_typeinfo_by_name(self, namespace, typename):
return self._resolver.get_typeinfo_by_name(namespace=namespace,
typename=typename)
+ cpdef _set_typeinfo(self, typeinfo):
+ self._resolver._set_typeinfo(typeinfo)
+
def get_meta_compressor(self):
return self._resolver.get_meta_compressor()
@@ -647,26 +650,26 @@ cdef class MetaContext:
"""
Context for sharing type meta across multiple serialization. Type name,
field name and field
type will be shared between different serialization.
-
+
Note that this context is not thread-safe, you should use it with one Fory
instance.
"""
cdef:
# Types which have sent definitions to peer
# Maps type objects to their assigned IDs
- flat_hash_map[uint64_t, int32_t] _c_type_map
-
+ flat_hash_map[uint64_t, int32_t] _c_type_map
+
# Counter for assigning new IDs
list _writing_type_defs
list _read_type_infos
object fory
object type_resolver
-
+
def __cinit__(self, object fory):
self.fory = fory
self.type_resolver = fory.type_resolver
self._writing_type_defs = []
self._read_type_infos = []
-
+
cpdef inline void write_shared_typeinfo(self, Buffer buffer, typeinfo):
"""Add a type definition to the writing queue."""
type_cls = typeinfo.cls
@@ -680,45 +683,48 @@ cdef class MetaContext:
cdef flat_hash_map[uint64_t, int32_t].iterator it =
self._c_type_map.find(type_addr)
if it != self._c_type_map.end():
buffer.write_varuint32(deref(it).second)
-
+
cdef index = self._c_type_map.size()
buffer.write_varuint32(index)
self._c_type_map[type_addr] = index
type_def = typeinfo.type_def
+ if type_def is None:
+ self.type_resolver._set_typeinfo(typeinfo)
+ type_def = typeinfo.type_def
self._writing_type_defs.append(type_def)
cpdef inline list get_writing_type_defs(self):
"""Get all type definitions that need to be written."""
return self._writing_type_defs
-
+
cpdef inline reset_write(self):
"""Reset write state."""
self._writing_type_defs.clear()
self._c_type_map.clear()
-
+
cpdef inline add_read_typeinfo(self, type_info):
"""Add a type info read from peer."""
self._read_type_infos.append(type_info)
-
+
cpdef inline read_shared_typeinfo(self, Buffer buffer):
"""Read a type info from buffer."""
cdef type_id = buffer.read_varuint32()
if IsTypeShareMeta(type_id & 0xFF):
return self._read_type_infos[buffer.read_varuint32()]
return self.type_resolver.get_typeinfo_by_id(type_id)
-
+
cpdef inline reset_read(self):
"""Reset read state."""
self._read_type_infos.clear()
-
+
cpdef inline reset(self):
"""Reset both read and write state."""
self.reset_write()
self.reset_read()
-
+
def __str__(self):
return self.__repr__()
-
+
def __repr__(self):
return (f"MetaContext("
f"read_infos={self._read_type_infos}, "
@@ -930,13 +936,13 @@ cdef class Fory:
if self.serialization_context.scoped_meta_share_enabled:
type_defs_offset_pos = buffer.writer_index
buffer.write_int32(-1) # Reserve 4 bytes for type definitions
offset
-
+
cdef int32_t start_offset
if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
self.xserialize_ref(buffer, obj)
-
+
# Write type definitions at the end, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
meta_context = self.serialization_context.meta_context
@@ -945,7 +951,7 @@ cdef class Fory:
current_pos = buffer.writer_index
buffer.put_int32(type_defs_offset_pos, current_pos -
type_defs_offset_pos - 4)
self.type_resolver.write_type_defs(buffer)
-
+
if buffer is not self.buffer:
return buffer
else:
@@ -1076,7 +1082,7 @@ cdef class Fory:
"buffers should be null when the serialized stream is "
"produced with buffer_callback null."
)
-
+
# Read type definitions at the start, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
relative_type_defs_offset = buffer.read_int32()
@@ -1089,7 +1095,7 @@ cdef class Fory:
self.type_resolver.read_type_defs(buffer)
# Jump back to continue with object deserialization
buffer.reader_index = current_reader_index
-
+
if not is_target_x_lang:
return self.deserialize_ref(buffer)
return self.xdeserialize_ref(buffer)
diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py
index 032740f7f..89322c45d 100644
--- a/python/pyfory/_struct.py
+++ b/python/pyfory/_struct.py
@@ -65,7 +65,7 @@ basic_types = {
}
-class ComplexTypeVisitor(TypeVisitor):
+class StructFieldSerializerVisitor(TypeVisitor):
def __init__(
self,
fory,
@@ -88,6 +88,8 @@ class ComplexTypeVisitor(TypeVisitor):
return MapSerializer(self.fory, dict, key_serializer, value_serializer)
def visit_customized(self, field_name, type_, types_path=None):
+ if issubclass(type_, enum.Enum):
+ return self.fory.type_resolver.get_serializer(type_)
return None
def visit_other(self, field_name, type_, types_path=None):
@@ -210,7 +212,10 @@ class StructHashVisitor(TypeVisitor):
assert not isinstance(serializer, (PickleSerializer,))
id_ = typeinfo.type_id
assert id_ is not None, serializer
- id_ = abs(id_)
+ if TypeId.is_namespaced_type(typeinfo.type_id):
+ namespace_str = typeinfo.decode_namespace()
+ typename_str = typeinfo.decode_typename()
+ id_ = compute_string_hash(namespace_str + typename_str)
self._hash = self._compute_field_hash(self._hash, id_)
@staticmethod
@@ -254,7 +259,7 @@ class StructTypeIdVisitor(TypeVisitor):
from pyfory.serializer import PickleSerializer # Local import
if is_subclass(type_, enum.Enum):
- return self.fory.type_resolver.get_typeinfo(type_).type_id
+ return [self.fory.type_resolver.get_typeinfo(type_).type_id]
if type_ not in basic_types and not is_py_array_type(type_):
return None, None
typeinfo = self.fory.type_resolver.get_typeinfo(type_)
@@ -262,6 +267,28 @@ class StructTypeIdVisitor(TypeVisitor):
return [typeinfo.type_id]
+class StructTypeVisitor(TypeVisitor):
+ def __init__(self, cls):
+ self.cls = cls
+
+ def visit_list(self, field_name, elem_type, types_path=None):
+ # Infer type recursively for type such as List[Dict[str, str]]
+ elem_types = infer_field("item", elem_type, self,
types_path=types_path)
+ return typing.List, elem_types
+
+ def visit_dict(self, field_name, key_type, value_type, types_path=None):
+ # Infer type recursively for type such as Dict[str, Dict[str, str]]
+ key_types = infer_field("key", key_type, self, types_path=types_path)
+ value_types = infer_field("value", value_type, self,
types_path=types_path)
+ return typing.Dict, key_types, value_types
+
+ def visit_customized(self, field_name, type_, types_path=None):
+ return [type_]
+
+ def visit_other(self, field_name, type_, types_path=None):
+ return [type_]
+
+
def get_field_names(clz, type_hints=None):
if hasattr(clz, "__dict__"):
# Regular object with __dict__
diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py
index 00f2710aa..4edbbfcf7 100644
--- a/python/pyfory/meta/typedef.py
+++ b/python/pyfory/meta/typedef.py
@@ -15,12 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-from typing import List
+import enum
import typing
+from typing import List
from pyfory.type import TypeId
from pyfory._util import Buffer
from pyfory.type import infer_field, is_polymorphic_type
from pyfory.meta.metastring import Encoding
+from pyfory.type import infer_field_types
# Constants from the specification
@@ -56,7 +58,8 @@ class TypeDef:
self.is_compressed = is_compressed
def create_fields_serializer(self, resolver):
- serializers = [field_info.field_type.create_serializer(resolver) for
field_info in self.fields]
+ field_types = infer_field_types(self.cls)
+ serializers = [field_info.field_type.create_serializer(resolver,
field_types.get(field_info.name, None)) for field_info in self.fields]
return serializers
def get_field_names(self):
@@ -143,9 +146,19 @@ class FieldType:
is_monomorphic = not is_polymorphic_type(xtype_id)
return FieldType(xtype_id, is_monomorphic, is_nullable,
is_tracking_ref)
- def create_serializer(self, resolver):
+ def create_serializer(self, resolver, type_):
if self.type_id in [TypeId.EXT, TypeId.STRUCT, TypeId.NAMED_STRUCT,
TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, TypeId.UNKNOWN]:
return None
+ if isinstance(type_, list):
+ type_ = type_[0]
+ if isinstance(type_, type) and issubclass(type_, enum.Enum):
+ typeinfo = resolver.get_typeinfo(type_, create=False)
+ if typeinfo is not None and typeinfo.serializer is not None:
+ return typeinfo.serializer
+ else:
+ from pyfory.serializer import NonExistEnumSerializer
+
+ return NonExistEnumSerializer(resolver.fory)
return resolver.get_typeinfo_by_id(self.type_id).serializer
def __repr__(self):
@@ -164,13 +177,15 @@ class CollectionFieldType(FieldType):
super().__init__(type_id, is_monomorphic, is_nullable, is_tracking_ref)
self.element_type = element_type
- def create_serializer(self, resolver):
+ def create_serializer(self, resolver, type_):
from pyfory.serializer import ListSerializer, SetSerializer
+ elem_type = type_[1] if len(type_) >= 2 else None
+ elem_serializer = self.element_type.create_serializer(resolver,
elem_type)
if self.type_id == TypeId.LIST:
- return ListSerializer(resolver.fory, list,
self.element_type.create_serializer(resolver))
+ return ListSerializer(resolver.fory, list, elem_serializer)
elif self.type_id == TypeId.SET:
- return SetSerializer(resolver.fory, set,
self.element_type.create_serializer(resolver))
+ return SetSerializer(resolver.fory, set, elem_serializer)
else:
raise ValueError(f"Unknown collection type: {self.type_id}")
@@ -189,9 +204,14 @@ class MapFieldType(FieldType):
self.key_type = key_type
self.value_type = value_type
- def create_serializer(self, resolver):
- key_serializer = self.key_type.create_serializer(resolver)
- value_serializer = self.value_type.create_serializer(resolver)
+ def create_serializer(self, resolver, type_):
+ key_type, value_type = None, None
+ if len(type_) >= 2:
+ key_type = type_[1]
+ if len(type_) >= 3:
+ value_type = type_[2]
+ key_serializer = self.key_type.create_serializer(resolver, key_type)
+ value_serializer = self.value_type.create_serializer(resolver,
value_type)
from pyfory.serializer import MapSerializer
return MapSerializer(resolver.fory, dict, key_serializer,
value_serializer)
@@ -207,7 +227,7 @@ class DynamicFieldType(FieldType):
def __init__(self, type_id: int, is_monomorphic: bool, is_nullable: bool,
is_tracking_ref: bool):
super().__init__(type_id, is_monomorphic, is_nullable, is_tracking_ref)
- def create_serializer(self, resolver):
+ def create_serializer(self, resolver, type_):
return None
def __repr__(self):
@@ -229,8 +249,8 @@ def build_field_infos(type_resolver, cls):
field_type = build_field_type(type_resolver, field_name,
field_type_hint, visitor)
field_info = FieldInfo(field_name, field_type, cls.__name__)
field_infos.append(field_info)
-
- serializers = [field_info.field_type.create_serializer(type_resolver) for
field_info in field_infos]
+ field_types = infer_field_types(cls)
+ serializers = [field_info.field_type.create_serializer(type_resolver,
field_types.get(field_info.name, None)) for field_info in field_infos]
field_names, serializers = _sort_fields(type_resolver, field_names,
serializers)
field_infos_map = {field_info.name: field_info for field_info in
field_infos}
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index c19bfc777..27045840e 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -17,6 +17,7 @@
import array
import builtins
+import dataclasses
import itertools
import marshal
import logging
@@ -293,8 +294,8 @@ _ENABLE_FORY_PYTHON_JIT =
os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower
# Moved from L32 to here, after all Serializer base classes and specific
serializers
# like ListSerializer, MapSerializer, PickleSerializer are defined or imported
-# and before DataClassSerializer which uses ComplexTypeVisitor from _struct.
-from pyfory._struct import _get_hash, _sort_fields, ComplexTypeVisitor
+# and before DataClassSerializer which uses StructFieldSerializerVisitor from
_struct.
+from pyfory._struct import _get_hash, _sort_fields,
StructFieldSerializerVisitor
class DataClassSerializer(Serializer):
@@ -309,7 +310,7 @@ class DataClassSerializer(Serializer):
if self._xlang:
self._serializers = serializers or [None] * len(self._field_names)
if serializers is None:
- visitor = ComplexTypeVisitor(fory)
+ visitor = StructFieldSerializerVisitor(fory)
for index, key in enumerate(self._field_names):
serializer = infer_field(key, self._type_hints[key],
visitor, types_path=[])
self._serializers[index] = serializer
@@ -594,6 +595,29 @@ class DataClassSerializer(Serializer):
return obj
+class DataClassStubSerializer(DataClassSerializer):
+ def __init__(self, fory, clz: type, xlang: bool = False):
+ Serializer.__init__(self, fory, clz)
+ self.xlang = xlang
+
+ def write(self, buffer, value):
+ self._replace().write(buffer, value)
+
+ def read(self, buffer):
+ return self._replace().read(buffer)
+
+ def xwrite(self, buffer, value):
+ self._replace().xwrite(buffer, value)
+
+ def xread(self, buffer):
+ return self._replace().xread(buffer)
+
+ def _replace(self):
+ typeinfo = self.fory.type_resolver.get_typeinfo(self.type_)
+ typeinfo.serializer = DataClassSerializer(self.fory, self.type_,
self.xlang)
+ return typeinfo.serializer
+
+
# Use numpy array or python array module.
typecode_dict = (
{
@@ -1262,3 +1286,33 @@ class ComplexObjectSerializer(DataClassSerializer):
stacklevel=2,
)
return DataClassSerializer(fory, clz, xlang=True)
+
+
[email protected]
+class NonExistEnum:
+ value: int = -1
+ name: str = ""
+
+
+class NonExistEnumSerializer(Serializer):
+ def __init__(self, fory):
+ super().__init__(fory, NonExistEnum)
+ self.need_to_write_ref = False
+
+ @classmethod
+ def support_subclass(cls) -> bool:
+ return True
+
+ def write(self, buffer, value):
+ buffer.write_string(value.name)
+
+ def read(self, buffer):
+ name = buffer.read_string()
+ return NonExistEnum(name=name)
+
+ def xwrite(self, buffer, value):
+ buffer.write_varuint32(value.value)
+
+ def xread(self, buffer):
+ value = buffer.read_varuint32()
+ return NonExistEnum(value=value)
diff --git a/python/pyfory/tests/test_cross_language.py
b/python/pyfory/tests/test_cross_language.py
index 5740caeb4..673fd47da 100644
--- a/python/pyfory/tests/test_cross_language.py
+++ b/python/pyfory/tests/test_cross_language.py
@@ -514,7 +514,8 @@ class EnumFieldStruct:
@cross_language_test
def test_enum_field(data_file_path):
- fory = pyfory.Fory(language=pyfory.Language.XLANG, ref_tracking=False)
+ compatible = "compatible" in data_file_path
+ fory = pyfory.Fory(language=pyfory.Language.XLANG, ref_tracking=False,
compatible=compatible)
fory.register_type(EnumTestClass, namespace="test",
typename="EnumTestClass")
fory.register_type(EnumFieldStruct, namespace="test",
typename="EnumFieldStruct")
obj = EnumFieldStruct(f1=EnumTestClass.FOO, f2=EnumTestClass.BAR, f3="abc")
@@ -529,7 +530,7 @@ def test_struct_hash(data_file_path):
read_hash = pyfory.Buffer(data_bytes).read_int32()
fory = pyfory.Fory(language=pyfory.Language.XLANG, ref_tracking=True)
fory.register_type(ComplexObject1, typename="ComplexObject1")
- serializer = fory.type_resolver.get_serializer(ComplexObject1)
+ serializer = fory.type_resolver.get_serializer(ComplexObject1)._replace()
from pyfory._struct import _get_hash
v = _get_hash(fory, serializer._field_names, serializer._type_hints)
diff --git a/python/pyfory/tests/test_struct.py
b/python/pyfory/tests/test_struct.py
index 576f28369..481382409 100644
--- a/python/pyfory/tests/test_struct.py
+++ b/python/pyfory/tests/test_struct.py
@@ -123,6 +123,19 @@ class DataClassObject:
f_any: Any
f_complex: ComplexObject = None
+ @classmethod
+ def create(cls):
+ return cls(
+ f_int=42,
+ f_float=3.14159,
+ f_str="test_codegen",
+ f_bool=True,
+ f_list=[1, 2, 3],
+ f_dict={"key": 1.5},
+ f_any="any_data",
+ f_complex=None,
+ )
+
def test_data_class_serializer_xlang():
fory = Fory(language=Language.XLANG, ref_tracking=True)
@@ -190,6 +203,8 @@ def test_data_class_serializer_xlang_codegen():
fory.register_type(ComplexObject, typename="example.ComplexObject")
fory.register_type(DataClassObject, typename="example.TestDataClassObject")
+ # trigger lazy serializer replace
+ fory.serialize(DataClassObject.create())
# Get the serializer that was created during registration
serializer = fory.type_resolver.get_serializer(DataClassObject)
@@ -305,6 +320,8 @@ def
test_data_class_serializer_xlang_codegen_generated_code():
fory.register_type(ComplexObject, typename="example.ComplexObject")
fory.register_type(DataClassObject, typename="example.TestDataClassObject")
+ # trigger lazy serializer replace
+ fory.serialize(DataClassObject.create())
# Get the serializer that was created during registration
serializer = fory.type_resolver.get_serializer(DataClassObject)
@@ -341,6 +358,8 @@ def test_data_class_serializer_xlang_vs_non_xlang():
fory_xlang.register_type(ComplexObject, typename="example.ComplexObject")
fory_xlang.register_type(DataClassObject,
typename="example.TestDataClassObject")
+ # trigger lazy serializer replace
+ fory_xlang.serialize(DataClassObject.create())
# For Python mode, we can create the serializer directly since it doesn't
require registration
serializer_xlang = fory_xlang.type_resolver.get_serializer(DataClassObject)
serializer_python = DataClassSerializer(fory_python, DataClassObject,
xlang=False)
diff --git a/python/pyfory/type.py b/python/pyfory/type.py
index 13be4dadc..e596c4678 100644
--- a/python/pyfory/type.py
+++ b/python/pyfory/type.py
@@ -410,6 +410,14 @@ class TypeVisitor(ABC):
pass
+def infer_field_types(type_):
+ type_hints = typing.get_type_hints(type_)
+ from pyfory._struct import StructTypeVisitor
+
+ visitor = StructTypeVisitor(type_)
+ return {name: infer_field(name, type_, visitor) for name, type_ in
sorted(type_hints.items())}
+
+
def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
types_path = list(types_path or [])
types_path.append(type_)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]