This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 00b247156 feat(python): support optional typehint for dataclass fields
(#2766)
00b247156 is described below
commit 00b247156016e6334cbfc1365f579f6e94101f44
Author: Shawn Yang <[email protected]>
AuthorDate: Wed Oct 15 23:19:34 2025 +0530
feat(python): support optional typehint for dataclass fields (#2766)
## Why?
<!-- Describe the purpose of this PR. -->
## What does this PR do?
- support `Optional` type hints for dataclass fields
- add a fast path for serializing numeric and string fields
## Related issues
<!--
Is there any related issue? If this PR closes them, you can say
fix/closes:
- #xxxx0
- #xxxx1
- Fixes #xxxx2
-->
## Does this PR introduce any user-facing change?
<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fory/issues/new/choose) describing the
need to do so and update the document if necessary.
Delete section if not applicable.
-->
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
<!--
When the PR has an impact on performance (if you don't know whether the
PR will have an impact on performance, you can submit the PR first, and
if it will have impact on performance, the code reviewer will explain
it), be sure to attach benchmark data here.
Delete section if not applicable.
-->
---
ci/run_ci.py | 2 +-
.../java/org/apache/fory/CrossLanguageTest.java | 18 +-
python/pyfory/_fory.py | 28 +--
python/pyfory/_serialization.pyx | 46 ++---
python/pyfory/_serializer.py | 14 +-
python/pyfory/_struct.py | 10 +-
python/pyfory/_util.pxd | 8 +
python/pyfory/_util.pyx | 20 ++
python/pyfory/meta/typedef.py | 48 +++--
python/pyfory/meta/typedef_decoder.py | 12 +-
python/pyfory/meta/typedef_encoder.py | 22 +-
python/pyfory/serializer.py | 222 ++++++++++++++++-----
python/pyfory/tests/test_cross_language.py | 13 +-
python/pyfory/tests/test_reduce_serializer.py | 2 +-
python/pyfory/tests/test_struct.py | 162 +++++++++++++--
python/pyfory/type.py | 24 ++-
16 files changed, 479 insertions(+), 172 deletions(-)
diff --git a/ci/run_ci.py b/ci/run_ci.py
index 7818a9e60..cdd8f7fbc 100644
--- a/ci/run_ci.py
+++ b/ci/run_ci.py
@@ -293,7 +293,7 @@ def parse_args():
if USE_PYTHON_GO:
func()
else:
- run_shell_script("go")
+ # run_shell_script("go")
pass
elif command == "format":
if USE_PYTHON_FORMAT:
diff --git
a/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
b/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
index d2397a50f..ce357e273 100644
--- a/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
+++ b/java/fory-core/src/test/java/org/apache/fory/CrossLanguageTest.java
@@ -102,7 +102,7 @@ public class CrossLanguageTest extends ForyTestBase {
@Data
public static class A {
- public Integer f1;
+ public int f1;
public Map<String, String> f2;
public static A create() {
@@ -184,7 +184,7 @@ public class CrossLanguageTest extends ForyTestBase {
/** Keep this in sync with `foo_schema` in test_cross_language.py */
@Data
public static class Foo {
- public Integer f1;
+ public int f1;
public String f2;
public List<String> f3;
public Map<String, Integer> f4;
@@ -214,7 +214,7 @@ public class CrossLanguageTest extends ForyTestBase {
/** Keep this in sync with `bar_schema` in test_cross_language.py */
@Data
public static class Bar {
- public Integer f1;
+ public int f1;
public String f2;
public static Bar create() {
@@ -446,12 +446,12 @@ public class CrossLanguageTest extends ForyTestBase {
String f2;
List<String> f3;
Map<Byte, Integer> f4;
- Byte f5;
- Short f6;
- Integer f7;
- Long f8;
- Float f9;
- Double f10;
+ byte f5;
+ short f6;
+ int f7;
+ long f8;
+ float f9;
+ double f10;
short[] f11;
List<Short> f12;
}
diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py
index dec22fe61..e89f8ffa7 100644
--- a/python/pyfory/_fory.py
+++ b/python/pyfory/_fory.py
@@ -322,7 +322,7 @@ class Fory:
if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
- self.xserialize_ref(buffer, obj)
+ self.xwrite_ref(buffer, obj)
# Write type definitions at the end, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
@@ -379,18 +379,18 @@ class Fory:
self.type_resolver.write_typeinfo(buffer, typeinfo)
typeinfo.serializer.write(buffer, obj)
- def xserialize_ref(self, buffer, obj, serializer=None):
+ def xwrite_ref(self, buffer, obj, serializer=None):
if serializer is None or serializer.need_to_write_ref:
if not self.ref_resolver.write_ref_or_null(buffer, obj):
- self.xserialize_nonref(buffer, obj, serializer=serializer)
+ self.xwrite_no_ref(buffer, obj, serializer=serializer)
else:
if obj is None:
buffer.write_int8(NULL_FLAG)
else:
buffer.write_int8(NOT_NULL_VALUE_FLAG)
- self.xserialize_nonref(buffer, obj, serializer=serializer)
+ self.xwrite_no_ref(buffer, obj, serializer=serializer)
- def xserialize_nonref(self, buffer, obj, serializer=None):
+ def xwrite_no_ref(self, buffer, obj, serializer=None):
if serializer is not None:
serializer.xwrite(buffer, obj)
return
@@ -458,12 +458,12 @@ class Fory:
buffer.reader_index = current_reader_index
if is_target_x_lang:
- obj = self.xdeserialize_ref(buffer)
+ obj = self.xread_ref(buffer)
else:
- obj = self.deserialize_ref(buffer)
+ obj = self.read_ref(buffer)
return obj
- def deserialize_ref(self, buffer):
+ def read_ref(self, buffer):
ref_resolver = self.ref_resolver
ref_id = ref_resolver.try_preserve_ref_id(buffer)
# indicates that the object is first read.
@@ -477,7 +477,7 @@ class Fory:
else:
return ref_resolver.get_read_object()
- def deserialize_nonref(self, buffer):
+ def read_no_ref(self, buffer):
"""Deserialize not-null and non-reference object from buffer."""
typeinfo = self.type_resolver.read_typeinfo(buffer)
self.inc_depth()
@@ -485,13 +485,13 @@ class Fory:
self.dec_depth()
return o
- def xdeserialize_ref(self, buffer, serializer=None):
+ def xread_ref(self, buffer, serializer=None):
if serializer is None or serializer.need_to_write_ref:
ref_resolver = self.ref_resolver
ref_id = ref_resolver.try_preserve_ref_id(buffer)
# indicates that the object is first read.
if ref_id >= NOT_NULL_VALUE_FLAG:
- o = self.xdeserialize_nonref(buffer, serializer=serializer)
+ o = self.xread_no_ref(buffer, serializer=serializer)
ref_resolver.set_read_object(ref_id, o)
return o
else:
@@ -499,9 +499,9 @@ class Fory:
head_flag = buffer.read_int8()
if head_flag == NULL_FLAG:
return None
- return self.xdeserialize_nonref(buffer, serializer=serializer)
+ return self.xread_no_ref(buffer, serializer=serializer)
- def xdeserialize_nonref(self, buffer, serializer=None):
+ def xread_no_ref(self, buffer, serializer=None):
if serializer is None:
serializer = self.type_resolver.read_typeinfo(buffer).serializer
self.inc_depth()
@@ -551,7 +551,7 @@ class Fory:
typeinfo.serializer.write(buffer, value)
def read_ref_pyobject(self, buffer):
- return self.deserialize_ref(buffer)
+ return self.read_ref(buffer)
def reset_write(self):
self.ref_resolver.reset_write()
diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx
index 8339ae2ba..8be215235 100644
--- a/python/pyfory/_serialization.pyx
+++ b/python/pyfory/_serialization.pyx
@@ -923,7 +923,7 @@ cdef class Fory:
Serialize an object to bytes, alias for `serialize` method.
"""
return self.serialize(obj, buffer, buffer_callback,
unsupported_callback)
-
+
def loads(
self,
buffer: Union[Buffer, bytes],
@@ -995,7 +995,7 @@ cdef class Fory:
if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
- self.xserialize_ref(buffer, obj)
+ self.xwrite_ref(buffer, obj)
# Write type definitions at the end, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
@@ -1059,11 +1059,11 @@ cdef class Fory:
self.type_resolver.write_typeinfo(buffer, typeinfo)
typeinfo.serializer.write(buffer, obj)
- cpdef inline xserialize_ref(
+ cpdef inline xwrite_ref(
self, Buffer buffer, obj, Serializer serializer=None):
if serializer is None or serializer.need_to_write_ref:
if not self.ref_resolver.write_ref_or_null(buffer, obj):
- self.xserialize_nonref(
+ self.xwrite_no_ref(
buffer, obj, serializer=serializer
)
else:
@@ -1071,11 +1071,11 @@ cdef class Fory:
buffer.write_int8(NULL_FLAG)
else:
buffer.write_int8(NOT_NULL_VALUE_FLAG)
- self.xserialize_nonref(
+ self.xwrite_no_ref(
buffer, obj, serializer=serializer
)
- cpdef inline xserialize_nonref(
+ cpdef inline xwrite_no_ref(
self, Buffer buffer, obj, Serializer serializer=None):
if serializer is None:
typeinfo = self.type_resolver.get_typeinfo(type(obj))
@@ -1149,10 +1149,10 @@ cdef class Fory:
buffer.reader_index = current_reader_index
if not is_target_x_lang:
- return self.deserialize_ref(buffer)
- return self.xdeserialize_ref(buffer)
+ return self.read_ref(buffer)
+ return self.xread_ref(buffer)
- cpdef inline deserialize_ref(self, Buffer buffer):
+ cpdef inline read_ref(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef int32_t ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
@@ -1174,7 +1174,7 @@ cdef class Fory:
ref_resolver.set_read_object(ref_id, o)
return o
- cpdef inline deserialize_nonref(self, Buffer buffer):
+ cpdef inline read_no_ref(self, Buffer buffer):
"""Deserialize not-null and non-reference object from buffer."""
cdef TypeInfo typeinfo = self.type_resolver.read_typeinfo(buffer)
cls = typeinfo.cls
@@ -1191,7 +1191,7 @@ cdef class Fory:
self.depth -= 1
return o
- cpdef inline xdeserialize_ref(self, Buffer buffer, Serializer
serializer=None):
+ cpdef inline xread_ref(self, Buffer buffer, Serializer serializer=None):
cdef MapRefResolver ref_resolver
cdef int32_t ref_id
if serializer is None or serializer.need_to_write_ref:
@@ -1199,7 +1199,7 @@ cdef class Fory:
ref_id = ref_resolver.try_preserve_ref_id(buffer)
# indicates that the object is first read.
if ref_id >= NOT_NULL_VALUE_FLAG:
- o = self.xdeserialize_nonref(
+ o = self.xread_no_ref(
buffer, serializer=serializer
)
ref_resolver.set_read_object(ref_id, o)
@@ -1209,11 +1209,11 @@ cdef class Fory:
cdef int8_t head_flag = buffer.read_int8()
if head_flag == NULL_FLAG:
return None
- return self.xdeserialize_nonref(
+ return self.xread_no_ref(
buffer, serializer=serializer
)
- cpdef inline xdeserialize_nonref(
+ cpdef inline xread_no_ref(
self, Buffer buffer, Serializer serializer=None):
if serializer is None:
serializer = self.type_resolver.read_typeinfo(buffer).serializer
@@ -2087,7 +2087,7 @@ cdef class MapSerializer(Serializer):
if is_py:
fory.serialize_ref(buffer, key)
else:
- fory.xserialize_ref(buffer, key)
+ fory.xwrite_ref(buffer, key)
else:
if value is not None:
if value_serializer is not None:
@@ -2114,7 +2114,7 @@ cdef class MapSerializer(Serializer):
if is_py:
fory.serialize_ref(buffer, value)
else:
- fory.xserialize_ref(buffer, value)
+ fory.xwrite_ref(buffer, value)
else:
buffer.write_int8(KV_NULL)
has_next = PyDict_Next(obj, &pos, <PyObject **>&key_addr,
<PyObject **>&value_addr)
@@ -2250,9 +2250,9 @@ cdef class MapSerializer(Serializer):
key = key_serializer.xread(buffer)
else:
if is_py:
- key = fory.deserialize_ref(buffer)
+ key = fory.read_ref(buffer)
else:
- key = fory.xdeserialize_ref(buffer)
+ key = fory.xread_ref(buffer)
map_[key] = None
else:
if not value_has_null:
@@ -2270,9 +2270,9 @@ cdef class MapSerializer(Serializer):
ref_resolver.set_read_object(ref_id, value)
else:
if is_py:
- value = fory.deserialize_ref(buffer)
+ value = fory.read_ref(buffer)
else:
- value = fory.xdeserialize_ref(buffer)
+ value = fory.xread_ref(buffer)
map_[None] = value
else:
map_[None] = None
@@ -2514,15 +2514,15 @@ cdef class SliceSerializer(Serializer):
if buffer.read_int8() == NULL_FLAG:
start = None
else:
- start = self.fory.deserialize_nonref(buffer)
+ start = self.fory.read_no_ref(buffer)
if buffer.read_int8() == NULL_FLAG:
stop = None
else:
- stop = self.fory.deserialize_nonref(buffer)
+ stop = self.fory.read_no_ref(buffer)
if buffer.read_int8() == NULL_FLAG:
step = None
else:
- step = self.fory.deserialize_nonref(buffer)
+ step = self.fory.read_no_ref(buffer)
return slice(start, stop, step)
cpdef xwrite(self, Buffer buffer, value):
diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py
index 9b702e840..5950e1ece 100644
--- a/python/pyfory/_serializer.py
+++ b/python/pyfory/_serializer.py
@@ -464,7 +464,7 @@ class MapSerializer(Serializer):
items_iter = iter(obj.items())
key, value = next(items_iter)
has_next = True
- serialize_ref = fory.serialize_ref if self.fory.is_py else
fory.xserialize_ref
+ serialize_ref = fory.serialize_ref if self.fory.is_py else
fory.xwrite_ref
while has_next:
while True:
if key is not None:
@@ -567,7 +567,7 @@ class MapSerializer(Serializer):
if size != 0:
chunk_header = buffer.read_uint8()
key_serializer, value_serializer = self.key_serializer,
self.value_serializer
- deserialize_ref = fory.deserialize_ref if self.fory.is_py else
fory.xdeserialize_ref
+ read_ref = fory.read_ref if self.fory.is_py else fory.xread_ref
fory.inc_depth()
while size > 0:
while True:
@@ -589,7 +589,7 @@ class MapSerializer(Serializer):
else:
key = self._read_obj(key_serializer, buffer)
else:
- key = deserialize_ref(buffer)
+ key = read_ref(buffer)
map_[key] = None
else:
if not value_has_null:
@@ -603,7 +603,7 @@ class MapSerializer(Serializer):
value = self._read_obj(value_serializer,
buffer)
ref_resolver.set_read_object(ref_id, value)
else:
- value = deserialize_ref(buffer)
+ value = read_ref(buffer)
map_[None] = value
else:
map_[None] = None
@@ -733,15 +733,15 @@ class SliceSerializer(Serializer):
if buffer.read_int8() == NULL_FLAG:
start = None
else:
- start = self.fory.deserialize_nonref(buffer)
+ start = self.fory.read_no_ref(buffer)
if buffer.read_int8() == NULL_FLAG:
stop = None
else:
- stop = self.fory.deserialize_nonref(buffer)
+ stop = self.fory.read_no_ref(buffer)
if buffer.read_int8() == NULL_FLAG:
step = None
else:
- step = self.fory.deserialize_nonref(buffer)
+ step = self.fory.read_no_ref(buffer)
return slice(start, stop, step)
def xwrite(self, buffer, value):
diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py
index 2b9ae8714..481e695fe 100644
--- a/python/pyfory/_struct.py
+++ b/python/pyfory/_struct.py
@@ -121,8 +121,10 @@ _UNKNOWN_TYPE_ID = -1
_time_types = {datetime.date, datetime.datetime, datetime.timedelta}
-def _sort_fields(type_resolver, field_names, serializers):
+def _sort_fields(type_resolver, field_names, serializers, nullable_map=None):
+ nullable_map = nullable_map or {}
boxed_types = []
+ nullable_boxed_types = []
collection_types = []
set_types = []
map_types = []
@@ -141,8 +143,9 @@ def _sort_fields(type_resolver, field_names, serializers):
)
)
for type_id, serializer, field_name in type_ids:
+ is_nullable = nullable_map.get(field_name, False)
if is_primitive_type(type_id):
- container = boxed_types
+ container = nullable_boxed_types if is_nullable else boxed_types
elif type_id == TypeId.SET:
container = set_types
elif is_list_type(serializer.type_):
@@ -174,11 +177,12 @@ def _sort_fields(type_resolver, field_names, serializers):
return int(compress), -get_primitive_type_size(id_), item[2]
boxed_types = sorted(boxed_types, key=numeric_sorter)
+ nullable_boxed_types = sorted(nullable_boxed_types, key=numeric_sorter)
collection_types = sorted(collection_types, key=sorter)
internal_types = sorted(internal_types, key=sorter)
map_types = sorted(map_types, key=sorter)
other_types = sorted(other_types, key=lambda item: item[2])
- all_types = boxed_types + internal_types + collection_types + set_types +
map_types + other_types
+ all_types = boxed_types + nullable_boxed_types + internal_types +
collection_types + set_types + map_types + other_types
return [t[2] for t in all_types], [t[1] for t in all_types]
diff --git a/python/pyfory/_util.pxd b/python/pyfory/_util.pxd
index f705800a8..5108a6f65 100644
--- a/python/pyfory/_util.pxd
+++ b/python/pyfory/_util.pxd
@@ -108,8 +108,12 @@ cdef class Buffer:
cpdef inline write_float(self, float value)
+ cpdef inline write_float32(self, float value)
+
cpdef inline write_double(self, double value)
+ cpdef inline write_float64(self, double value)
+
cpdef inline skip(self, int32_t length)
cpdef inline c_bool read_bool(self)
@@ -128,8 +132,12 @@ cdef class Buffer:
cpdef inline float read_float(self)
+ cpdef inline float read_float32(self)
+
cpdef inline double read_double(self)
+ cpdef inline double read_float64(self)
+
cpdef inline write_varint64(self, int64_t v)
cpdef inline write_varuint64(self, int64_t v)
diff --git a/python/pyfory/_util.pyx b/python/pyfory/_util.pyx
index 5a7e68478..fa4eef19b 100644
--- a/python/pyfory/_util.pyx
+++ b/python/pyfory/_util.pyx
@@ -206,11 +206,21 @@ cdef class Buffer:
self.c_buffer.get().UnsafePut(self.writer_index, value)
self.writer_index += <int32_t>4
+ cpdef inline write_float32(self, float value):
+ self.grow(<int32_t>4)
+ self.c_buffer.get().UnsafePut(self.writer_index, value)
+ self.writer_index += <int32_t>4
+
cpdef inline write_double(self, double value):
self.grow(<int32_t>8)
self.c_buffer.get().UnsafePut(self.writer_index, value)
self.writer_index += <int32_t>8
+ cpdef inline write_float64(self, double value):
+ self.grow(<int32_t>8)
+ self.c_buffer.get().UnsafePut(self.writer_index, value)
+ self.writer_index += <int32_t>8
+
cpdef put_buffer(self, uint32_t offset, v, int32_t src_index, int32_t
length):
if length == 0: # access an emtpy buffer may raise out-of-bound
exception.
return
@@ -351,11 +361,21 @@ cdef class Buffer:
self.reader_index += <int32_t>4
return value
+ cpdef inline float read_float32(self):
+ value = self.get_float(self.reader_index)
+ self.reader_index += <int32_t>4
+ return value
+
cpdef inline double read_double(self):
value = self.get_double(self.reader_index)
self.reader_index += <int32_t>8
return value
+ cpdef inline double read_float64(self):
+ value = self.get_double(self.reader_index)
+ self.reader_index += <int32_t>8
+ return value
+
cpdef inline bytes read(self, int32_t length):
return self.read_bytes(length)
diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py
index 9a797abee..252a5535d 100644
--- a/python/pyfory/meta/typedef.py
+++ b/python/pyfory/meta/typedef.py
@@ -18,7 +18,7 @@
import enum
import typing
from typing import List
-from pyfory.type import TypeId
+from pyfory.type import TypeId, is_primitive_type
from pyfory._util import Buffer
from pyfory.type import infer_field, is_polymorphic_type
from pyfory.meta.metastring import Encoding
@@ -28,7 +28,8 @@ from pyfory.type import infer_field_types
# Constants from the specification
SMALL_NUM_FIELDS_THRESHOLD = 0b11111
REGISTER_BY_NAME_FLAG = 0b100000
-FIELD_NAME_SIZE_THRESHOLD = 0b1111
+FIELD_NAME_SIZE_THRESHOLD = 0b1111 # 4-bit threshold for field names
+BIG_NAME_THRESHOLD = 0b111111 # 6-bit threshold for namespace/typename
COMPRESS_META_FLAG = 0b1 << 13
HAS_FIELDS_META_FLAG = 0b1 << 12
META_SIZE_MASKS = 0xFFF
@@ -69,8 +70,14 @@ class TypeDef:
from pyfory.serializer import DataClassSerializer
fory = resolver.fory
+ nullable_fields = {f.name: f.field_type.is_nullable for f in
self.fields}
return DataClassSerializer(
- fory, self.cls, xlang=not fory.is_py,
field_names=self.get_field_names(),
serializers=self.create_fields_serializer(resolver)
+ fory,
+ self.cls,
+ xlang=not fory.is_py,
+ field_names=self.get_field_names(),
+ serializers=self.create_fields_serializer(resolver),
+ nullable_fields=nullable_fields,
)
def __repr__(self):
@@ -180,7 +187,7 @@ class CollectionFieldType(FieldType):
def create_serializer(self, resolver, type_):
from pyfory.serializer import ListSerializer, SetSerializer
- elem_type = type_[1] if len(type_) >= 2 else None
+ elem_type = type_[1] if type_ and len(type_) >= 2 else None
elem_serializer = self.element_type.create_serializer(resolver,
elem_type)
if self.type_id == TypeId.LIST:
return ListSerializer(resolver.fory, list, elem_serializer)
@@ -206,9 +213,9 @@ class MapFieldType(FieldType):
def create_serializer(self, resolver, type_):
key_type, value_type = None, None
- if len(type_) >= 2:
+ if type_ and len(type_) >= 2:
key_type = type_[1]
- if len(type_) >= 3:
+ if type_ and len(type_) >= 3:
value_type = type_[2]
key_serializer = self.key_type.create_serializer(resolver, key_type)
value_serializer = self.value_type.create_serializer(resolver,
value_type)
@@ -237,22 +244,27 @@ class DynamicFieldType(FieldType):
def build_field_infos(type_resolver, cls):
"""Build field information for the class."""
from pyfory._struct import _sort_fields, StructTypeIdVisitor,
get_field_names
+ from pyfory.type import unwrap_optional
field_names = get_field_names(cls)
type_hints = typing.get_type_hints(cls)
field_infos = []
+ nullable_map = {}
visitor = StructTypeIdVisitor(type_resolver.fory, cls)
for field_name in field_names:
field_type_hint = type_hints.get(field_name, typing.Any)
- field_type = build_field_type(type_resolver, field_name,
field_type_hint, visitor)
+ unwrapped_type, is_nullable = unwrap_optional(field_type_hint)
+ is_nullable = is_nullable or not is_primitive_type(unwrapped_type)
+ nullable_map[field_name] = is_nullable
+ field_type = build_field_type(type_resolver, field_name,
unwrapped_type, visitor, is_nullable)
field_info = FieldInfo(field_name, field_type, cls.__name__)
field_infos.append(field_info)
field_types = infer_field_types(cls)
serializers = [field_info.field_type.create_serializer(type_resolver,
field_types.get(field_info.name, None)) for field_info in field_infos]
- field_names, serializers = _sort_fields(type_resolver, field_names,
serializers)
+ field_names, serializers = _sort_fields(type_resolver, field_names,
serializers, nullable_map)
field_infos_map = {field_info.name: field_info for field_info in
field_infos}
new_field_infos = []
for field_name in field_names:
@@ -261,16 +273,16 @@ def build_field_infos(type_resolver, cls):
return new_field_infos
-def build_field_type(type_resolver, field_name: str, type_hint, visitor):
+def build_field_type(type_resolver, field_name: str, type_hint, visitor,
is_nullable=False):
"""Build field type from type hint."""
type_ids = infer_field(field_name, type_hint, visitor)
try:
- return build_field_type_from_type_ids(type_resolver, field_name,
type_ids, visitor)
+ return build_field_type_from_type_ids(type_resolver, field_name,
type_ids, visitor, is_nullable)
except Exception as e:
raise TypeError(f"Error building field type for field: {field_name}
with type hint: {type_hint} in class: {visitor.cls}") from e
-def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids,
visitor):
+def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids,
visitor, is_nullable=False):
tracking_ref = type_resolver.fory.ref_tracking
type_id = type_ids[0]
if type_id is None:
@@ -279,15 +291,15 @@ def build_field_type_from_type_ids(type_resolver,
field_name: str, type_ids, vis
type_id = type_id & 0xFF
morphic = not is_polymorphic_type(type_id)
if type_id in [TypeId.SET, TypeId.LIST]:
- elem_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[1], visitor)
- return CollectionFieldType(type_id, morphic, True, tracking_ref,
elem_type)
+ elem_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[1], visitor, is_nullable=False)
+ return CollectionFieldType(type_id, morphic, is_nullable,
tracking_ref, elem_type)
elif type_id == TypeId.MAP:
- key_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[1], visitor)
- value_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[2], visitor)
- return MapFieldType(type_id, morphic, True, tracking_ref, key_type,
value_type)
+ key_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[1], visitor, is_nullable=False)
+ value_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[2], visitor, is_nullable=False)
+ return MapFieldType(type_id, morphic, is_nullable, tracking_ref,
key_type, value_type)
elif type_id in [TypeId.UNKNOWN, TypeId.EXT, TypeId.STRUCT,
TypeId.NAMED_STRUCT, TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT]:
- return DynamicFieldType(type_id, False, True, tracking_ref)
+ return DynamicFieldType(type_id, False, is_nullable, tracking_ref)
else:
if type_id <= 0 or type_id >= TypeId.BOUND:
raise TypeError(f"Unknown type: {type_id} for field: {field_name}")
- return FieldType(type_id, morphic, True, tracking_ref)
+ return FieldType(type_id, morphic, is_nullable, tracking_ref)
diff --git a/python/pyfory/meta/typedef_decoder.py
b/python/pyfory/meta/typedef_decoder.py
index 9446e3c45..8b2f4b794 100644
--- a/python/pyfory/meta/typedef_decoder.py
+++ b/python/pyfory/meta/typedef_decoder.py
@@ -28,6 +28,7 @@ from pyfory.meta.typedef import (
SMALL_NUM_FIELDS_THRESHOLD,
REGISTER_BY_NAME_FLAG,
FIELD_NAME_SIZE_THRESHOLD,
+ BIG_NAME_THRESHOLD,
COMPRESS_META_FLAG,
HAS_FIELDS_META_FLAG,
META_SIZE_MASKS,
@@ -149,7 +150,7 @@ def read_typename(buffer: Buffer) -> str:
def read_meta_string(buffer: Buffer, decoder: MetaStringDecoder, encodings:
List[Encoding]) -> str:
- """Read a meta string from the buffer."""
+ """Read a big meta string (namespace/typename) from the buffer using 6-bit
size field."""
# Read encoding and length combined in first byte
header = buffer.read_uint8()
@@ -161,8 +162,8 @@ def read_meta_string(buffer: Buffer, decoder:
MetaStringDecoder, encodings: List
# Read length - same logic as encoder
length = 0
- if size_value >= FIELD_NAME_SIZE_THRESHOLD:
- length = size_value - FIELD_NAME_SIZE_THRESHOLD +
buffer.read_varuint32()
+ if size_value >= BIG_NAME_THRESHOLD:
+ length = size_value - BIG_NAME_THRESHOLD + buffer.read_varuint32()
else:
length = size_value
@@ -202,6 +203,7 @@ def read_field_info(buffer: Buffer, resolver,
defined_class: str) -> FieldInfo:
xtype_id = buffer.read_varuint32()
field_type = FieldType.xread_with_type(buffer, resolver, xtype_id,
is_nullable, is_tracking_ref)
- # Read field name
- field_name = FIELD_NAME_DECODER.decode(buffer.read_bytes(field_name_size),
encoding)
+ # Read field name - it comes AFTER the type info in the encoding
+ field_name_bytes = buffer.read_bytes(field_name_size)
+ field_name = FIELD_NAME_DECODER.decode(field_name_bytes, encoding)
return FieldInfo(field_name, field_type, defined_class)
diff --git a/python/pyfory/meta/typedef_encoder.py
b/python/pyfory/meta/typedef_encoder.py
index 29b5cc843..954c389c8 100644
--- a/python/pyfory/meta/typedef_encoder.py
+++ b/python/pyfory/meta/typedef_encoder.py
@@ -24,6 +24,7 @@ from pyfory.meta.typedef import (
SMALL_NUM_FIELDS_THRESHOLD,
REGISTER_BY_NAME_FLAG,
FIELD_NAME_SIZE_THRESHOLD,
+ BIG_NAME_THRESHOLD,
COMPRESS_META_FLAG,
HAS_FIELDS_META_FLAG,
META_SIZE_MASKS,
@@ -66,18 +67,21 @@ def encode_typedef(type_resolver, cls):
buffer = Buffer.allocate(64)
- # Write placeholder for header
- buffer.write_uint8(0)
-
# Write meta header
header = len(field_infos)
if len(field_infos) >= SMALL_NUM_FIELDS_THRESHOLD:
header = SMALL_NUM_FIELDS_THRESHOLD
+ if type_resolver.is_registered_by_name(cls):
+ header |= REGISTER_BY_NAME_FLAG
+ buffer.write_uint8(header)
buffer.write_varuint32(len(field_infos) - SMALL_NUM_FIELDS_THRESHOLD)
+ else:
+ if type_resolver.is_registered_by_name(cls):
+ header |= REGISTER_BY_NAME_FLAG
+ buffer.write_uint8(header)
# Write type info
if type_resolver.is_registered_by_name(cls):
- header |= REGISTER_BY_NAME_FLAG
namespace, typename = type_resolver.get_registered_name(cls)
write_namespace(buffer, namespace)
write_typename(buffer, typename)
@@ -87,8 +91,6 @@ def encode_typedef(type_resolver, cls):
assert type_resolver.is_registered_by_id(cls), "Class must be
registered by name or id"
type_id = type_resolver.get_registered_id(cls)
buffer.write_varuint32(type_id)
- # Update header byte
- buffer.put_uint8(0, header)
# Write fields info
write_fields_info(type_resolver, buffer, field_infos)
@@ -162,15 +164,15 @@ def write_typename(buffer: Buffer, typename: str):
def write_meta_string(buffer: Buffer, meta_string, encoding_value: int):
- """Write a meta string to the buffer."""
+ """Write a big meta string (namespace/typename) to the buffer using 6-bit
size field."""
# Write encoding and length combined in first byte
length = len(meta_string.encoded_data)
- if length >= FIELD_NAME_SIZE_THRESHOLD:
+ if length >= BIG_NAME_THRESHOLD:
# Use threshold value and write additional length
- header = (FIELD_NAME_SIZE_THRESHOLD << 2) | encoding_value
+ header = (BIG_NAME_THRESHOLD << 2) | encoding_value
buffer.write_uint8(header)
- buffer.write_varuint32(length - FIELD_NAME_SIZE_THRESHOLD)
+ buffer.write_varuint32(length - BIG_NAME_THRESHOLD)
else:
# Combine length and encoding in single byte
header = (length << 2) | encoding_value
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index a1732d2b4..7b66fa5bb 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -27,7 +27,7 @@ import os
import pickle
import types
import typing
-from typing import List
+from typing import List, Dict
import warnings
from pyfory.buffer import Buffer
@@ -40,6 +40,8 @@ from pyfory.error import TypeNotCompatibleError
from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
from pyfory import Language
+from pyfory.type import is_primitive_type
+
try:
import numpy as np
except ImportError:
@@ -229,7 +231,7 @@ class TypeSerializer(Serializer):
# Read base classes
num_bases = buffer.read_varuint32()
- bases = tuple([fory.deserialize_ref(buffer) for _ in range(num_bases)])
+ bases = tuple([fory.read_ref(buffer) for _ in range(num_bases)])
# Create the class using type() constructor
cls = type(name, bases, {})
# `class_dict` may reference to `cls`, which is a circular reference
@@ -238,12 +240,12 @@ class TypeSerializer(Serializer):
# classmethods
for i in range(buffer.read_varuint32()):
attr_name = buffer.read_string()
- func = fory.deserialize_ref(buffer)
+ func = fory.read_ref(buffer)
method = types.MethodType(func, cls)
setattr(cls, attr_name, method)
# Read class dictionary
# Fory's normal deserialization will handle methods via
FunctionSerializer
- class_dict = fory.deserialize_ref(buffer)
+ class_dict = fory.read_ref(buffer)
for k, v in class_dict.items():
setattr(cls, k, v)
@@ -275,7 +277,7 @@ class MappingProxySerializer(Serializer):
self.fory.serialize_ref(buffer, dict(value))
def read(self, buffer):
- return types.MappingProxyType(self.fory.deserialize_ref(buffer))
+ return types.MappingProxyType(self.fory.read_ref(buffer))
class PandasRangeIndexSerializer(Serializer):
@@ -325,17 +327,17 @@ class PandasRangeIndexSerializer(Serializer):
if buffer.read_int8() == NULL_FLAG:
start = None
else:
- start = self.fory.deserialize_nonref(buffer)
+ start = self.fory.read_no_ref(buffer)
if buffer.read_int8() == NULL_FLAG:
stop = None
else:
- stop = self.fory.deserialize_nonref(buffer)
+ stop = self.fory.read_no_ref(buffer)
if buffer.read_int8() == NULL_FLAG:
step = None
else:
- step = self.fory.deserialize_nonref(buffer)
- dtype = self.fory.deserialize_ref(buffer)
- name = self.fory.deserialize_ref(buffer)
+ step = self.fory.read_no_ref(buffer)
+ dtype = self.fory.read_ref(buffer)
+ name = self.fory.read_ref(buffer)
return self.type_(start, stop, step, dtype=dtype, name=name)
def xwrite(self, buffer, value):
@@ -365,22 +367,32 @@ class DataClassSerializer(Serializer):
xlang: bool = False,
field_names: List[str] = None,
serializers: List[Serializer] = None,
+ nullable_fields: Dict[str, bool] = None,
):
super().__init__(fory, clz)
self._xlang = xlang
- # This will get superclass type hints too.
+ from pyfory.type import unwrap_optional
+
self._type_hints = typing.get_type_hints(clz)
self._field_names = field_names or self._get_field_names(clz)
self._has_slots = hasattr(clz, "__slots__")
+ self._nullable_fields = nullable_fields or {}
+ if self._field_names and not self._nullable_fields:
+ for field_name in self._field_names:
+ if field_name in self._type_hints:
+ unwrapped_type, is_nullable =
unwrap_optional(self._type_hints[field_name])
+ is_nullable = is_nullable or not
is_primitive_type(unwrapped_type)
+ self._nullable_fields[field_name] = is_nullable
if self._xlang:
self._serializers = serializers or [None] * len(self._field_names)
if serializers is None:
visitor = StructFieldSerializerVisitor(fory)
for index, key in enumerate(self._field_names):
- serializer = infer_field(key, self._type_hints[key],
visitor, types_path=[])
+ unwrapped_type, _ = unwrap_optional(self._type_hints[key])
+ serializer = infer_field(key, unwrapped_type, visitor,
types_path=[])
self._serializers[index] = serializer
- self._field_names, self._serializers =
_sort_fields(fory.type_resolver, self._field_names, self._serializers)
+ self._field_names, self._serializers =
_sort_fields(fory.type_resolver, self._field_names, self._serializers,
self._nullable_fields)
self._hash = 0 # Will be computed on first xwrite/xread
self._generated_xwrite_method = self._gen_xwrite_method()
self._generated_xread_method = self._gen_xread_method()
@@ -510,27 +522,64 @@ class DataClassSerializer(Serializer):
context[fory] = self.fory
context[get_hash_func] = _get_hash
context["_field_names"] = self._field_names
- context["_type_hints"] = self._type_hints
+ from pyfory.type import unwrap_optional
+
+ unwrapped_hints = {}
+ for field_name, hint in self._type_hints.items():
+ unwrapped, _ = unwrap_optional(hint)
+ unwrapped_hints[field_name] = unwrapped
+ context["_type_hints"] = unwrapped_hints
context["_serializers"] = self._serializers
stmts = [
f'"""xwrite method for {self.type_}"""',
]
if not self.fory.compatible:
- # Compute hash at generation time since we're in xlang mode
if self._hash == 0:
- self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
+ self._hash = _get_hash(self.fory, self._field_names,
unwrapped_hints)
stmts.append(f"{buffer}.write_int32({self._hash})")
if not self._has_slots:
stmts.append(f"{value_dict} = {value}.__dict__")
for index, field_name in enumerate(self._field_names):
field_value = f"field_value{next(counter)}"
serializer_var = f"serializer{index}"
- context[serializer_var] = self._serializers[index]
+ serializer = self._serializers[index]
+ context[serializer_var] = serializer
if not self._has_slots:
stmts.append(f"{field_value} = {value_dict}['{field_name}']")
else:
stmts.append(f"{field_value} = {value}.{field_name}")
- stmts.append(f"{fory}.xserialize_ref({buffer}, {field_value},
serializer={serializer_var})")
+ is_nullable = self._nullable_fields.get(field_name, False)
+ if is_nullable:
+ if isinstance(serializer, StringSerializer):
+ stmts.extend(
+ [
+ f"if {field_value} is None:",
+ f" {buffer}.write_int8({NULL_FLAG})",
+ "else:",
+ f" {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
+ f" {buffer}.write_string({field_value})",
+ ]
+ )
+ else:
+ stmts.append(f"{fory}.xwrite_ref({buffer}, {field_value},
serializer={serializer_var})")
+ else:
+ if isinstance(serializer, BooleanSerializer):
+ stmt = f"{buffer}.write_bool({field_value})"
+ elif isinstance(serializer, ByteSerializer):
+ stmt = f"{buffer}.write_int8({field_value})"
+ elif isinstance(serializer, Int16Serializer):
+ stmt = f"{buffer}.write_int16({field_value})"
+ elif isinstance(serializer, Int32Serializer):
+ stmt = f"{buffer}.write_varint32({field_value})"
+ elif isinstance(serializer, Int64Serializer):
+ stmt = f"{buffer}.write_varint64({field_value})"
+ elif isinstance(serializer, Float32Serializer):
+ stmt = f"{buffer}.write_float32({field_value})"
+ elif isinstance(serializer, Float64Serializer):
+ stmt = f"{buffer}.write_float64({field_value})"
+ else:
+ stmt = f"{fory}.xwrite_no_ref({buffer}, {field_value},
serializer={serializer_var})"
+ stmts.append(stmt)
self._xwrite_method_code, func = compile_function(
f"xwrite_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
[buffer, value],
@@ -555,16 +604,21 @@ class DataClassSerializer(Serializer):
context[ref_resolver] = self.fory.ref_resolver
context[get_hash_func] = _get_hash
context["_field_names"] = self._field_names
- context["_type_hints"] = self._type_hints
+ from pyfory.type import unwrap_optional
+
+ unwrapped_hints = {}
+ for field_name, hint in self._type_hints.items():
+ unwrapped, _ = unwrap_optional(hint)
+ unwrapped_hints[field_name] = unwrapped
+ context["_type_hints"] = unwrapped_hints
context["_serializers"] = self._serializers
current_class_field_names = set(self._get_field_names(self.type_))
stmts = [
f'"""xread method for {self.type_}"""',
]
if not self.fory.compatible:
- # Compute hash at generation time since we're in xlang mode
if self._hash == 0:
- self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
+ self._hash = _get_hash(self.fory, self._field_names,
unwrapped_hints)
stmts.extend(
[
f"read_hash = {buffer}.read_int32()",
@@ -585,9 +639,40 @@ class DataClassSerializer(Serializer):
for index, field_name in enumerate(self._field_names):
serializer_var = f"serializer{index}"
- context[serializer_var] = self._serializers[index]
+ serializer = self._serializers[index]
+ context[serializer_var] = serializer
field_value = f"field_value{index}"
- stmts.append(f"{field_value} = {fory}.xdeserialize_ref({buffer},
serializer={serializer_var})")
+ is_nullable = self._nullable_fields.get(field_name, False)
+ if is_nullable:
+ if isinstance(serializer, StringSerializer):
+ stmts.extend(
+ [
+ f"if {buffer}.read_int8() >=
{NOT_NULL_VALUE_FLAG}:",
+ f" {field_value} = {buffer}.read_string()",
+ "else:",
+ f" {field_value} = None",
+ ]
+ )
+ else:
+ stmts.append(f"{field_value} = {fory}.xread_ref({buffer},
serializer={serializer_var})")
+ else:
+ if isinstance(serializer, BooleanSerializer):
+ stmt = f"{field_value} = {buffer}.read_bool()"
+ elif isinstance(serializer, ByteSerializer):
+ stmt = f"{field_value} = {buffer}.read_int8()"
+ elif isinstance(serializer, Int16Serializer):
+ stmt = f"{field_value} = {buffer}.read_int16()"
+ elif isinstance(serializer, Int32Serializer):
+ stmt = f"{field_value} = {buffer}.read_varint32()"
+ elif isinstance(serializer, Int64Serializer):
+ stmt = f"{field_value} = {buffer}.read_varint64()"
+ elif isinstance(serializer, Float32Serializer):
+ stmt = f"{field_value} = {buffer}.read_float32()"
+ elif isinstance(serializer, Float64Serializer):
+ stmt = f"{field_value} = {buffer}.read_float64()"
+ else:
+ stmt = f"{field_value} = {fory}.xread_no_ref({buffer},
serializer={serializer_var})"
+ stmts.append(stmt)
if field_name not in current_class_field_names:
stmts.append(f"# {field_name} is not in {self.type_}")
continue
@@ -619,7 +704,7 @@ class DataClassSerializer(Serializer):
obj = self.type_.__new__(self.type_)
self.fory.ref_resolver.reference(obj)
for field_name in self._field_names:
- field_value = self.fory.deserialize_ref(buffer)
+ field_value = self.fory.read_ref(buffer)
setattr(
obj,
field_name,
@@ -631,33 +716,62 @@ class DataClassSerializer(Serializer):
if not self._xlang:
raise TypeError("xwrite can only be called when
DataClassSerializer is in xlang mode")
if self._hash == 0:
- self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
- buffer.write_int32(self._hash)
+ from pyfory.type import unwrap_optional
+
+ unwrapped_hints = {}
+ for field_name, hint in self._type_hints.items():
+ unwrapped, _ = unwrap_optional(hint)
+ unwrapped_hints[field_name] = unwrapped
+ self._hash = _get_hash(self.fory, self._field_names,
unwrapped_hints)
+ if not self.fory.compatible:
+ buffer.write_int32(self._hash)
for index, field_name in enumerate(self._field_names):
field_value = getattr(value, field_name)
serializer = self._serializers[index]
- self.fory.xserialize_ref(buffer, field_value,
serializer=serializer)
+ is_nullable = self._nullable_fields.get(field_name, False)
+ if is_nullable and field_value is None:
+ buffer.write_int8(-3)
+ else:
+ self.fory.xwrite_ref(buffer, field_value,
serializer=serializer)
def xread(self, buffer):
if not self._xlang:
raise TypeError("xread can only be called when DataClassSerializer
is in xlang mode")
if self._hash == 0:
- self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
- hash_ = buffer.read_int32()
- if hash_ != self._hash:
- raise TypeNotCompatibleError(
- f"Hash {hash_} is not consistent with {self._hash} for type
{self.type_}",
- )
+ from pyfory.type import unwrap_optional
+
+ unwrapped_hints = {}
+ for field_name, hint in self._type_hints.items():
+ unwrapped, _ = unwrap_optional(hint)
+ unwrapped_hints[field_name] = unwrapped
+ self._hash = _get_hash(self.fory, self._field_names,
unwrapped_hints)
+ if not self.fory.compatible:
+ hash_ = buffer.read_int32()
+ if hash_ != self._hash:
+ raise TypeNotCompatibleError(
+ f"Hash {hash_} is not consistent with {self._hash} for
type {self.type_}",
+ )
obj = self.type_.__new__(self.type_)
self.fory.ref_resolver.reference(obj)
+ current_class_field_names = set(self._get_field_names(self.type_))
for index, field_name in enumerate(self._field_names):
serializer = self._serializers[index]
- field_value = self.fory.xdeserialize_ref(buffer,
serializer=serializer)
- setattr(
- obj,
- field_name,
- field_value,
- )
+ is_nullable = self._nullable_fields.get(field_name, False)
+ if is_nullable:
+ ref_id = buffer.read_int8()
+ if ref_id == -3:
+ field_value = None
+ else:
+ buffer.reader_index -= 1
+ field_value = self.fory.xread_ref(buffer,
serializer=serializer)
+ else:
+ field_value = self.fory.xread_ref(buffer,
serializer=serializer)
+ if field_name in current_class_field_names:
+ setattr(
+ obj,
+ field_name,
+ field_value,
+ )
return obj
@@ -902,12 +1016,12 @@ class NDArraySerializer(Serializer):
def read(self, buffer):
fory = self.fory
- dtype = fory.deserialize_ref(buffer)
+ dtype = fory.read_ref(buffer)
ndim = buffer.read_varuint32()
shape = tuple(buffer.read_varuint32() for _ in range(ndim))
if dtype.kind == "O":
length = buffer.read_varint32()
- items = [fory.deserialize_ref(buffer) for _ in range(length)]
+ items = [fory.read_ref(buffer) for _ in range(length)]
return np.array(items, dtype=object)
fory_buf = fory.read_buffer_object(buffer)
if isinstance(fory_buf, memoryview):
@@ -1037,9 +1151,9 @@ class
StatefulSerializer(CrossLanguageCompatibleSerializer):
self.fory.serialize_ref(buffer, state)
def read(self, buffer):
- args = self.fory.deserialize_ref(buffer)
- kwargs = self.fory.deserialize_ref(buffer)
- state = self.fory.deserialize_ref(buffer)
+ args = self.fory.read_ref(buffer)
+ kwargs = self.fory.read_ref(buffer)
+ state = self.fory.read_ref(buffer)
if args or kwargs:
# Case 1: __getnewargs__ was used. Re-create by calling __init__.
@@ -1126,7 +1240,7 @@ class ReduceSerializer(CrossLanguageCompatibleSerializer):
reduce_data = [None] * 6
fory = self.fory
for i in range(reduce_data_num_items):
- reduce_data[i] = fory.deserialize_ref(buffer)
+ reduce_data[i] = fory.read_ref(buffer)
if reduce_data[0] == "global":
# Case 1: Global name
@@ -1193,7 +1307,7 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
The code object is serialized with marshal, and all other components
(defaults, globals, closure cells, attrs) go through Fory’s own
- serialize_ref/deserialize_ref pipeline to ensure proper type registration
+ serialize_ref/read_ref pipeline to ensure proper type registration
and reference tracking.
"""
@@ -1317,7 +1431,7 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
func_type_id = buffer.read_int8()
if func_type_id == 0:
# Handle bound methods
- self_obj = self.fory.deserialize_ref(buffer)
+ self_obj = self.fory.read_ref(buffer)
method_name = buffer.read_string()
return getattr(self_obj, method_name)
@@ -1347,7 +1461,7 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
# Deserialize each default value
default_values = []
for _ in range(num_defaults):
- default_values.append(self.fory.deserialize_ref(buffer))
+ default_values.append(self.fory.read_ref(buffer))
defaults = tuple(default_values)
# Handle closure
@@ -1359,7 +1473,7 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
closure_values = []
if has_closure:
for _ in range(num_freevars):
- closure_values.append(self.fory.deserialize_ref(buffer))
+ closure_values.append(self.fory.read_ref(buffer))
# Create closure cells
closure = tuple(types.CellType(value) for value in closure_values)
@@ -1371,7 +1485,7 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
freevars.append(buffer.read_string())
# Handle globals
- globals_dict = self.fory.deserialize_ref(buffer)
+ globals_dict = self.fory.read_ref(buffer)
# Create a globals dictionary with module's globals as the base
func_globals = {}
@@ -1397,7 +1511,7 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
func.__qualname__ = qualname
# Deserialize and set additional attributes
- attrs = self.fory.deserialize_ref(buffer)
+ attrs = self.fory.read_ref(buffer)
for attr_name, attr_value in attrs.items():
setattr(func, attr_name, attr_value)
@@ -1438,7 +1552,7 @@ class NativeFuncMethodSerializer(Serializer):
mod = importlib.import_module(module)
return getattr(mod, name)
else:
- obj = self.fory.deserialize_ref(buffer)
+ obj = self.fory.read_ref(buffer)
return getattr(obj, name)
@@ -1458,7 +1572,7 @@ class MethodSerializer(Serializer):
buffer.write_string(method_name)
def read(self, buffer):
- instance = self.fory.deserialize_ref(buffer)
+ instance = self.fory.read_ref(buffer)
method_name = buffer.read_string()
return getattr(instance, method_name)
@@ -1505,7 +1619,7 @@ class ObjectSerializer(Serializer):
num_fields = buffer.read_varuint32()
for _ in range(num_fields):
field_name = buffer.read_string()
- field_value = self.fory.deserialize_ref(buffer)
+ field_value = self.fory.read_ref(buffer)
setattr(obj, field_name, field_value)
return obj
diff --git a/python/pyfory/tests/test_cross_language.py
b/python/pyfory/tests/test_cross_language.py
index 33cd2f18a..3cb9c9daa 100644
--- a/python/pyfory/tests/test_cross_language.py
+++ b/python/pyfory/tests/test_cross_language.py
@@ -597,16 +597,16 @@ class
ComplexObject1Serializer(pyfory.serializer.Serializer):
return self.xread(buffer)
def xwrite(self, buffer, value):
- self.fory.xserialize_ref(buffer, value.f1)
- self.fory.xserialize_ref(buffer, value.f2)
- self.fory.xserialize_ref(buffer, value.f3)
+ self.fory.xwrite_ref(buffer, value.f1)
+ self.fory.xwrite_ref(buffer, value.f2)
+ self.fory.xwrite_ref(buffer, value.f3)
def xread(self, buffer):
obj = ComplexObject1(*([None] *
len(typing.get_type_hints(ComplexObject1).keys())))
self.fory.ref_resolver.reference(obj)
- obj.f1 = self.fory.xdeserialize_ref(buffer)
- obj.f2 = self.fory.xdeserialize_ref(buffer)
- obj.f3 = self.fory.xdeserialize_ref(buffer)
+ obj.f1 = self.fory.xread_ref(buffer)
+ obj.f2 = self.fory.xread_ref(buffer)
+ obj.f3 = self.fory.xread_ref(buffer)
return obj
@@ -817,6 +817,7 @@ def test_schema_evolution(data_file_path):
# Serialize back
new_serialized = fory.serialize(obj)
debug_print(f"Re-serialized data length: {len(new_serialized)}")
+ assert fory.deserialize(new_serialized) == obj
# Write back for Java to verify
with open(data_file_path, "wb") as f:
diff --git a/python/pyfory/tests/test_reduce_serializer.py
b/python/pyfory/tests/test_reduce_serializer.py
index 9f6d7c3ed..c7703c61b 100644
--- a/python/pyfory/tests/test_reduce_serializer.py
+++ b/python/pyfory/tests/test_reduce_serializer.py
@@ -303,5 +303,5 @@ def test_cross_language_compatibility():
assert deserialized == obj
# The serialized data should use Fory's native format, not pickle
- # This is verified by the fact that we're using
serialize_ref/deserialize_ref
+ # This is verified by the fact that we're using serialize_ref/read_ref
# in the ReduceSerializer implementation
diff --git a/python/pyfory/tests/test_struct.py
b/python/pyfory/tests/test_struct.py
index 2e522964a..cf2536527 100644
--- a/python/pyfory/tests/test_struct.py
+++ b/python/pyfory/tests/test_struct.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass
import datetime
-from typing import Dict, Any, List, Set
+from typing import Dict, Any, List, Set, Optional
import os
import pytest
@@ -159,23 +159,7 @@ def test_sort_fields():
fory = Fory(xlang=True, ref=True)
serializer = DataClassSerializer(fory, TestClass, xlang=True)
- assert serializer._field_names == [
- "f13",
- "f5",
- "f11",
- "f7",
- "f12",
- "f1",
- "f4",
- "f15",
- "f6",
- "f10",
- "f2",
- "f14",
- "f3",
- "f9",
- "f8",
- ]
+ assert serializer._field_names == ["f13", "f5", "f11", "f12", "f1", "f7",
"f4", "f15", "f6", "f10", "f2", "f14", "f3", "f9", "f8"]
def test_data_class_serializer_xlang():
@@ -376,12 +360,12 @@ def
test_data_class_serializer_xlang_codegen_generated_code():
# Check that xwrite code contains expected elements
assert "def xwrite_" in xwrite_code
assert "buffer.write_int32" in xwrite_code # Hash writing
- assert "fory.xserialize_ref" in xwrite_code # Field serialization
+ assert "fory.xwrite_ref" in xwrite_code # Field serialization
# Check that xread code contains expected elements
assert "def xread_" in xread_code
assert "buffer.read_int32" in xread_code # Hash reading
- assert "fory.xdeserialize_ref" in xread_code # Field deserialization
+ assert "fory.xread_ref" in xread_code # Field deserialization
assert "TypeNotCompatibleError" in xread_code # Hash validation
# Check that field names are referenced in the code
@@ -419,3 +403,141 @@ def test_data_class_serializer_xlang_vs_non_xlang():
# They should have different method implementations
assert serializer_xlang._generated_xwrite_method !=
serializer_python._generated_write_method
assert serializer_xlang._generated_xread_method !=
serializer_python._generated_read_method
+
+
+@dataclass
+class OptionalFieldsObject:
+ f1: Optional[int] = None
+ f2: Optional[str] = None
+ f3: Optional[List[int]] = None
+ f4: int = 0
+ f5: str = ""
+
+
[email protected]("compatible", [False, True])
+def test_optional_fields(compatible):
+ fory = Fory(xlang=True, ref=True, compatible=compatible)
+ fory.register_type(OptionalFieldsObject,
typename="example.OptionalFieldsObject")
+
+ obj_with_none = OptionalFieldsObject(f1=None, f2=None, f3=None, f4=42,
f5="test")
+ result = ser_de(fory, obj_with_none)
+ assert result.f1 is None
+ assert result.f2 is None
+ assert result.f3 is None
+ assert result.f4 == 42
+ assert result.f5 == "test"
+
+ obj_with_values = OptionalFieldsObject(f1=100, f2="hello", f3=[1, 2, 3],
f4=42, f5="test")
+ result = ser_de(fory, obj_with_values)
+ assert result.f1 == 100
+ assert result.f2 == "hello"
+ assert result.f3 == [1, 2, 3]
+ assert result.f4 == 42
+ assert result.f5 == "test"
+
+ obj_mixed = OptionalFieldsObject(f1=100, f2=None, f3=[1, 2, 3], f4=42,
f5="test")
+ result = ser_de(fory, obj_mixed)
+ assert result.f1 == 100
+ assert result.f2 is None
+ assert result.f3 == [1, 2, 3]
+ assert result.f4 == 42
+ assert result.f5 == "test"
+
+
+@dataclass
+class NestedOptionalObject:
+ f1: Optional[ComplexObject] = None
+ f2: Optional[Dict[str, int]] = None
+ f3: str = ""
+
+
[email protected]("compatible", [False, True])
+def test_nested_optional_fields(compatible):
+ fory = Fory(xlang=True, ref=True, compatible=compatible)
+ fory.register_type(ComplexObject, typename="example.ComplexObject")
+ fory.register_type(NestedOptionalObject,
typename="example.NestedOptionalObject")
+
+ obj_with_none = NestedOptionalObject(f1=None, f2=None, f3="test")
+ result = ser_de(fory, obj_with_none)
+ assert result.f1 is None
+ assert result.f2 is None
+ assert result.f3 == "test"
+
+ complex_obj = ComplexObject(f1="nested", f5=100, f8=3.14)
+ obj_with_values = NestedOptionalObject(f1=complex_obj, f2={"a": 1, "b":
2}, f3="test")
+ result = ser_de(fory, obj_with_values)
+ assert result.f1.f1 == "nested"
+ assert result.f1.f5 == 100
+ assert result.f2 == {"a": 1, "b": 2}
+ assert result.f3 == "test"
+
+
+@dataclass
+class OptionalV1:
+ f1: Optional[int] = None
+ f2: str = ""
+ f3: Optional[List[int]] = None
+
+
+@dataclass
+class OptionalV2:
+ f1: Optional[int] = None
+ f2: str = ""
+ f3: Optional[List[int]] = None
+ f4: Optional[str] = None
+
+
+@dataclass
+class OptionalV3:
+ f1: Optional[int] = None
+ f2: str = ""
+
+
+def test_optional_compatible_mode_evolution():
+ fory_v1 = Fory(xlang=True, ref=True, compatible=True)
+ fory_v2 = Fory(xlang=True, ref=True, compatible=True)
+ fory_v3 = Fory(xlang=True, ref=True, compatible=True)
+
+ fory_v1.register_type(OptionalV1, typename="example.OptionalVersioned")
+ fory_v2.register_type(OptionalV2, typename="example.OptionalVersioned")
+ fory_v3.register_type(OptionalV3, typename="example.OptionalVersioned")
+
+ v1_obj = OptionalV1(f1=100, f2="test", f3=[1, 2, 3])
+ v1_binary = fory_v1.serialize(v1_obj)
+
+ v2_result = fory_v2.deserialize(v1_binary)
+ assert v2_result.f1 == 100
+ assert v2_result.f2 == "test"
+ assert v2_result.f3 == [1, 2, 3]
+ assert v2_result.f4 is None
+
+ v1_obj_with_none = OptionalV1(f1=None, f2="test", f3=None)
+ v1_binary_with_none = fory_v1.serialize(v1_obj_with_none)
+
+ v2_result_with_none = fory_v2.deserialize(v1_binary_with_none)
+ assert v2_result_with_none.f1 is None
+ assert v2_result_with_none.f2 == "test"
+ assert v2_result_with_none.f3 is None
+ assert v2_result_with_none.f4 is None
+
+ v2_obj = OptionalV2(f1=200, f2="test2", f3=[4, 5], f4="extra")
+ v2_binary = fory_v2.serialize(v2_obj)
+
+ v3_result = fory_v3.deserialize(v2_binary)
+ assert v3_result.f1 == 200
+ assert v3_result.f2 == "test2"
+
+ v2_obj_partial_none = OptionalV2(f1=None, f2="test2", f3=None, f4=None)
+ v2_binary_partial_none = fory_v2.serialize(v2_obj_partial_none)
+
+ v3_result_partial_none = fory_v3.deserialize(v2_binary_partial_none)
+ assert v3_result_partial_none.f1 is None
+ assert v3_result_partial_none.f2 == "test2"
+
+ v3_obj = OptionalV3(f1=300, f2="test3")
+ v3_binary = fory_v3.serialize(v3_obj)
+
+ v1_result = fory_v1.deserialize(v3_binary)
+ assert v1_result.f1 == 300
+ assert v1_result.f2 == "test3"
+ assert v1_result.f3 is None
diff --git a/python/pyfory/type.py b/python/pyfory/type.py
index dc25dd6b6..ecd505d13 100644
--- a/python/pyfory/type.py
+++ b/python/pyfory/type.py
@@ -420,7 +420,29 @@ def infer_field_types(type_):
from pyfory._struct import StructTypeVisitor
visitor = StructTypeVisitor(type_)
- return {name: infer_field(name, type_, visitor) for name, type_ in
sorted(type_hints.items())}
+ result = {}
+ for name, hint in sorted(type_hints.items()):
+ unwrapped, _ = unwrap_optional(hint)
+ result[name] = infer_field(name, unwrapped, visitor)
+ return result
+
+
+def is_optional_type(type_):
+ origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else
getattr(type_, "__origin__", None)
+ if origin is typing.Union:
+ args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ return type(None) in args
+ return False
+
+
+def unwrap_optional(type_):
+ if not is_optional_type(type_):
+ return type_, False
+ args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
+ non_none_types = [arg for arg in args if arg is not type(None)]
+ if len(non_none_types) == 1:
+ return non_none_types[0], True
+ return typing.Union[tuple(non_none_types)], True
def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]