This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new a44f24ace feat(python): fast cython struct serializer (#3443)
a44f24ace is described below
commit a44f24aceaee3c8a5b834b0658eec499b3db4f1d
Author: Shawn Yang <[email protected]>
AuthorDate: Tue Mar 3 00:43:33 2026 +0800
feat(python): fast cython struct serializer (#3443)
## Why?
`pyfory` dataclass serialization hot paths are still not fast enough because of the overhead of slow Python method calls.
## What does this PR do?
- Adds a Cython `DataClassSerializer` implementation
(`python/pyfory/struct.pxi`) and wires it into `serialization.pyx`,
while keeping Python fallback wiring in `struct.py`.
- Expands C++ primitive fastpath helpers to support more primitive type
IDs (signed/unsigned/fixed/var/tagged numerics, float32, bool coercion)
and exposes them for struct field fastpath use.
- Carries per-field ref metadata from `TypeDef` into serializer creation
(`python/pyfory/meta/typedef.py`) and updates struct serializer type
detection in `registry.py` for both Python and Cython serializer
classes.
- Optimizes map type-info lookup in `MapSerializer` with `FlatIntMap`
caches keyed by class pointer.
- Updates Python benchmarks to focus on `struct` and `slots_struct`
cases and improves msgpack dataclass conversion/restore handling for
benchmark parity.
- Refreshes tests in `test_struct.py`, `test_ref_tracking.py`, and
`xlang_test_main.py` for bool coercion, numeric/ref edge cases, schema
evolution defaults, and enum evolution behavior.
- Adds an extra Python CI job that runs with
`ENABLE_FORY_CYTHON_SERIALIZATION=1`.
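For orientation, the class-pointer-keyed type-info cache added to `MapSerializer` can be sketched in pure Python (illustrative names only; the real code uses a C++ `FlatIntMap[uint64_t, PyObject*]` indexed by the class object's address):

```python
# Illustrative sketch, not the pyfory API: memoize the slow resolver
# lookup per key/value class, keyed by the class object's identity
# (id(cls) stands in for the uint64_t PyObject* address used in Cython).
class TypeInfoCache:
    def __init__(self, resolver):
        self._resolver = resolver   # slow path: full type-info lookup
        self._cache = {}            # id(cls) -> cached type info

    def get(self, cls):
        key = id(cls)
        info = self._cache.get(key)
        if info is None:
            info = self._resolver(cls)  # cache miss: run resolver once
            self._cache[key] = info
        return info


calls = []

def slow_resolver(cls):
    calls.append(cls)
    return f"typeinfo:{cls.__name__}"

cache = TypeInfoCache(slow_resolver)
assert cache.get(int) == "typeinfo:int"
assert cache.get(int) == "typeinfo:int"  # served from cache
assert len(calls) == 1                   # resolver ran only once
```

Because map entries in typical payloads repeat the same key/value classes, the resolver lookup amortizes to a single integer-keyed probe per class.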
## Related issues
#1017
#3441
## Does this PR introduce any user-facing change?
- Improves dataclass serialization behavior/performance in Python Cython
mode and refines schema-evolution defaults for missing required fields.
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
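The bool coercion mentioned above can be pictured with a small Python sketch (hypothetical helper name; the C++ fastpath checks `Py_True`/`Py_False` exactly and otherwise falls back to `PyObject_IsTrue`):

```python
# Sketch of the BOOL fastpath coercion: exact True/False take the fast
# branch, any other object is coerced through its truthiness, mirroring
# the PyObject_IsTrue fallback in pyfory.cc.
def encode_bool(value):
    if value is True:
        return 1
    if value is False:
        return 0
    return 1 if value else 0  # PyObject_IsTrue equivalent

assert encode_bool(True) == 1
assert encode_bool(False) == 0
assert encode_bool(0) == 0     # non-bool falsy object coerces to 0
assert encode_bool("x") == 1   # non-bool truthy object coerces to 1
```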
## Benchmark
```
================================================================================
SPEEDUP (Fory vs Pickle)
================================================================================
Benchmark Fory Pickle Speedup
--------------------------------------------------------------------------------
struct 2.32 us 4.75 us 2.05x
================================================================================
SPEEDUP (Fory vs Msgpack)
================================================================================
Benchmark Fory Msgpack Speedup
--------------------------------------------------------------------------------
struct 2.32 us 12.27 us 5.28x
```
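The msgpack benchmark parity handling follows the shape below: convert a dataclass to a plain dict via `fields()` (which also works for slots classes, where `vars(obj)` is unavailable), then rebuild it from a live instance used as a shape template. This standalone sketch skips msgpack itself to stay dependency-free; the function names mirror the helpers in the patch but are simplified.

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass
class Point:
    x: int = 0
    y: int = 0


@dataclass
class Wrap:
    p: Point = None
    name: str = ""


def to_plain(obj):
    # Simplified mirror of make_msgpack_compatible for the dataclass
    # case: recurse over declared fields instead of vars(obj).
    if is_dataclass(obj):
        return {f.name: to_plain(getattr(obj, f.name)) for f in fields(obj)}
    return obj


def restore(value, template):
    # Simplified mirror of _restore_dataclass_from_template: rebuild the
    # dataclass from a dict, using a live instance to decide where to
    # recurse into nested dataclasses.
    if not is_dataclass(template) or not isinstance(value, dict):
        return value
    kwargs = {}
    for f in fields(template):
        field_value = value.get(f.name)
        template_value = getattr(template, f.name, None)
        if is_dataclass(template_value):
            kwargs[f.name] = restore(field_value, template_value)
        else:
            kwargs[f.name] = field_value
    return type(template)(**kwargs)


p = Point(1, 2)
assert to_plain(p) == {"x": 1, "y": 2}
assert restore(to_plain(p), p) == p

w = Wrap(Point(3, 4), "w")
assert restore(to_plain(w), w) == w  # nested dataclass round-trips
```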
---
.github/workflows/ci.yml | 5 +
AGENTS.md | 3 +-
benchmarks/python/fory_benchmark.py | 108 +++-
python/pyfory/collection.pxi | 23 +-
python/pyfory/cpp/pyfory.cc | 197 +++++-
python/pyfory/cpp/pyfory.h | 3 +
python/pyfory/includes/libserialization.pxd | 2 +
python/pyfory/includes/libutil.pxd | 14 +
python/pyfory/meta/typedef.py | 3 +
python/pyfory/registry.py | 38 +-
python/pyfory/serialization.pyx | 7 +-
python/pyfory/struct.pxi | 463 +++++++++++++++
python/pyfory/struct.py | 887 ++++++++--------------------
python/pyfory/tests/test_meta_share.py | 2 +-
python/pyfory/tests/test_ref_tracking.py | 136 +++--
python/pyfory/tests/test_struct.py | 317 ++++++----
python/pyfory/tests/xlang_test_main.py | 4 +-
17 files changed, 1344 insertions(+), 868 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a5e6f47ec..168db47a5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -741,6 +741,11 @@ jobs:
- name: Run Python CI
shell: bash
run: python ./ci/run_ci.py python
+ - name: Run Python CI (ENABLE_FORY_CYTHON_SERIALIZATION=1)
+ shell: bash
+ env:
+ ENABLE_FORY_CYTHON_SERIALIZATION: "1"
+ run: python ./ci/run_ci.py python
python_xlang:
name: Python Xlang Test
diff --git a/AGENTS.md b/AGENTS.md
index 27900c2a9..d0d1a01e3 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -558,7 +558,7 @@ Fory python has two implementations for the protocol:
- **Python mode**: Pure python implementation based on `xlang serialization format`, used for debugging and testing only. This mode can be enabled by setting `ENABLE_FORY_CYTHON_SERIALIZATION=0` environment variable.
- **Cython mode**: Cython based implementation based on `xlang serialization format`, which is used by default and has better performance than pure python. This mode can be enabled by setting `ENABLE_FORY_CYTHON_SERIALIZATION=1` environment variable.
- **Python mode** and **Cython mode** reused some code from each other to reduce code duplication.
-- **Debug Struct Serialization**: set `ENABLE_FORY_PYTHON_JIT=0` when debug struct fields serialization error, this mode is more easy to debug and add logs. Even struct serialization itself has no bug, by enable this mode and adding debug logs, we can narrow the bug scope more easily.
+- **Debug Struct Serialization**: set `ENABLE_FORY_DEBUG_OUTPUT=1` to enable detailed struct serialization/deserialization logs while debugging protocol behavior.
Code structure:
@@ -621,6 +621,7 @@ Fory rust provides macro-based serialization and deserialization. Fory rust cons
### Performance Guidelines
- **Performance First**: Never introduce code that reduces performance without explicit justification
+- **Benchmark Required After Perf Optimizations**: For every code change expected to improve performance, run the relevant benchmark immediately after applying the change and report the measured results (command + before/after numbers) in your response/PR.
- **Zero-Copy**: Leverage zero-copy techniques when possible
- **JIT Compilation**: Consider JIT compilation opportunities
- **Memory Layout**: Optimize for cache-friendly memory access patterns
diff --git a/benchmarks/python/fory_benchmark.py b/benchmarks/python/fory_benchmark.py
index 7ea9ffeb1..4301fc565 100644
--- a/benchmarks/python/fory_benchmark.py
+++ b/benchmarks/python/fory_benchmark.py
@@ -31,7 +31,7 @@ Benchmark Options:
--benchmarks BENCHMARK_LIST
Comma-separated list of benchmarks to run. Default: all
Available: dict, large_dict, dict_group, tuple, large_tuple,
- large_float_tuple, large_boolean_tuple, list, large_list, complex
+ large_float_tuple, large_boolean_tuple, list, large_list, struct, slots_struct
--serializers SERIALIZER_LIST
Comma-separated list of serializers to benchmark. Default: all
@@ -67,7 +67,7 @@ Examples:
python fory_benchmark.py --operation deserialize
# Run specific benchmarks with both serializers
- python fory_benchmark.py --benchmarks dict,large_dict,complex
+ python fory_benchmark.py --benchmarks dict,large_dict,struct,slots_struct
# Compare only Fory performance
python fory_benchmark.py --serializers fory
@@ -90,7 +90,7 @@ Examples:
import argparse
import array
-from dataclasses import dataclass, is_dataclass
+from dataclasses import dataclass, fields, is_dataclass
import datetime
import pickle
import random
@@ -196,7 +196,7 @@ DICT_GROUP = [mutate_dict(DICT, random_source) for _ in range(3)]
@dataclass
-class ComplexObject1:
+class Struct1:
f1: Any = None
f2: str = None
f3: List[str] = None
@@ -212,13 +212,44 @@ class ComplexObject1:
@dataclass
-class ComplexObject2:
+class Struct2:
f1: Any
f2: Dict[pyfory.int8, pyfory.int32]
-COMPLEX_OBJECT = ComplexObject1(
- f1=ComplexObject2(f1=True, f2={-1: 2}),
[email protected]
+@dataclass
+class SlotsStruct:
+ f1: Any = None
+ f2: str = None
+ f3: List[str] = None
+ f4: Dict[pyfory.int8, pyfory.int32] = None
+ f5: pyfory.int8 = None
+ f6: pyfory.int16 = None
+ f7: pyfory.int32 = None
+ f8: pyfory.int64 = None
+ f9: pyfory.float32 = None
+ f10: pyfory.float64 = None
+ f11: pyfory.int16_array = None
+ f12: List[pyfory.int16] = None
+
+
+STRUCT_OBJECT = Struct1(
+ f1=Struct2(f1=True, f2={-1: 2}),
+ f2="abc",
+ f3=["abc", "abc"],
+ f4={1: 2},
+ f5=2**7 - 1,
+ f6=2**15 - 1,
+ f7=2**31 - 1,
+ f8=2**63 - 1,
+ f9=1.0 / 2,
+ f10=1 / 3.0,
+ f11=array.array("h", [-1, 4]),
+ f12=[-1, 4],
+)
+SLOTS_STRUCT_OBJECT = SlotsStruct(
+ f1=Struct2(f1=True, f2={-1: 2}),
f2="abc",
f3=["abc", "abc"],
f4={1: 2},
@@ -238,8 +269,9 @@ fory_without_ref = pyfory.Fory(ref=False)
# Register all custom types on both instances
for fory_instance in (fory_with_ref, fory_without_ref):
- fory_instance.register_type(ComplexObject1)
- fory_instance.register_type(ComplexObject2)
+ fory_instance.register_type(Struct1)
+ fory_instance.register_type(Struct2)
+ fory_instance.register_type(SlotsStruct)
def fory_roundtrip(ref, obj):
@@ -276,21 +308,40 @@ def msgpack_roundtrip(obj):
msgpack.loads(binary, raw=False, strict_map_key=False)
+def msgpack_roundtrip_dataclass(obj):
+ payload = make_msgpack_compatible(obj)
+ binary = msgpack.dumps(payload, use_bin_type=True)
+ restored = msgpack.loads(binary, raw=False, strict_map_key=False)
+ _restore_dataclass_from_template(restored, obj)
+
+
def msgpack_serialize(obj):
msgpack.dumps(obj, use_bin_type=True)
+def msgpack_serialize_dataclass(obj):
+ payload = make_msgpack_compatible(obj)
+ msgpack.dumps(payload, use_bin_type=True)
+
+
def msgpack_deserialize(binary):
msgpack.loads(binary, raw=False, strict_map_key=False)
+def msgpack_deserialize_dataclass(binary, dataclass_template):
+ restored = msgpack.loads(binary, raw=False, strict_map_key=False)
+ _restore_dataclass_from_template(restored, dataclass_template)
+
+
def make_msgpack_compatible(obj):
if isinstance(obj, datetime.date):
return obj.isoformat()
if isinstance(obj, array.array):
return obj.tolist()
if is_dataclass(obj):
- return {k: make_msgpack_compatible(v) for k, v in vars(obj).items()}
+ return {
+ f.name: make_msgpack_compatible(getattr(obj, f.name)) for f in fields(obj)
+ }
if isinstance(obj, dict):
return {
make_msgpack_compatible(k): make_msgpack_compatible(v)
@@ -303,6 +354,25 @@ def make_msgpack_compatible(obj):
return obj
+def _restore_dataclass_from_template(value, template):
+ if not is_dataclass(template):
+ return value
+ if not isinstance(value, dict):
+ return value
+
+ kwargs = {}
+ for field in template.__dataclass_fields__.values():
+ field_value = value.get(field.name)
+ template_value = getattr(template, field.name, None)
+ if is_dataclass(template_value):
+ kwargs[field.name] = _restore_dataclass_from_template(
+ field_value, template_value
+ )
+ else:
+ kwargs[field.name] = field_value
+ return type(template)(**kwargs)
+
+
def build_fory_benchmark_case(operation: str, ref: bool, obj):
if operation == "serialize":
return fory_serialize, (ref, obj)
@@ -322,9 +392,18 @@ def build_pickle_benchmark_case(operation: str, obj):
def build_msgpack_benchmark_case(operation: str, obj):
if operation == "serialize":
+ if is_dataclass(obj):
+ return msgpack_serialize_dataclass, (obj,)
return msgpack_serialize, (obj,)
if operation == "deserialize":
+ if is_dataclass(obj):
+ return msgpack_deserialize_dataclass, (
+ msgpack.dumps(make_msgpack_compatible(obj), use_bin_type=True),
+ obj,
+ )
return msgpack_deserialize, (msgpack.dumps(obj, use_bin_type=True),)
+ if is_dataclass(obj):
+ return msgpack_roundtrip_dataclass, (obj,)
return msgpack_roundtrip, (obj,)
@@ -356,7 +435,7 @@ def benchmark_args():
default="all",
help="Comma-separated list of benchmarks to run. Available: dict, large_dict, "
"dict_group, tuple, large_tuple, large_float_tuple, large_boolean_tuple, "
- "list, large_list, complex. Default: all",
+ "list, large_list, struct, slots_struct. Default: all",
)
parser.add_argument(
"--serializers",
@@ -450,7 +529,8 @@ def micro_benchmark():
"large_boolean_tuple": LARGE_BOOLEAN_TUPLE,
"list": LIST,
"large_list": LARGE_LIST,
- "complex": COMPLEX_OBJECT,
+ "struct": STRUCT_OBJECT,
+ "slots_struct": SLOTS_STRUCT_OBJECT,
}
# Determine which benchmarks to run
@@ -487,7 +567,9 @@ def micro_benchmark():
msgpack_data = {}
if "msgpack" in selected_serializers:
msgpack_data = {
- benchmark_name: make_msgpack_compatible(data)
+ benchmark_name: (
+ data if is_dataclass(data) else make_msgpack_compatible(data)
+ )
for benchmark_name, data in benchmark_data.items()
}
diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi
index 649ba294f..48bd3efca 100644
--- a/python/pyfory/collection.pxi
+++ b/python/pyfory/collection.pxi
@@ -693,6 +693,7 @@ cdef int32_t NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF =KEY_HAS_NULL | VALUE_DECL_TY
cdef int32_t NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE
# Value is null, key type is declared type, and ref tracking for key is enabled.
cdef int32_t NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_KEY_REF
+ctypedef PyObject* PyObjectPtr
@cython.final
@@ -703,6 +704,8 @@ cdef class MapSerializer(Serializer):
cdef Serializer value_serializer
cdef int8_t key_tracking_ref
cdef int8_t value_tracking_ref
+ cdef FlatIntMap[uint64_t, PyObjectPtr] _key_typeinfo_cache
+ cdef FlatIntMap[uint64_t, PyObjectPtr] _value_typeinfo_cache
def __init__(
self,
@@ -718,6 +721,8 @@ cdef class MapSerializer(Serializer):
self.ref_resolver = fory.ref_resolver
self.key_serializer = key_serializer
self.value_serializer = value_serializer
+ self._key_typeinfo_cache = FlatIntMap[uint64_t, PyObjectPtr](4)
+ self._value_typeinfo_cache = FlatIntMap[uint64_t, PyObjectPtr](4)
self.key_tracking_ref = 0
self.value_tracking_ref = 0
if key_serializer is not None:
@@ -743,6 +748,8 @@ cdef class MapSerializer(Serializer):
cdef Serializer key_serializer = self.key_serializer
cdef Serializer value_serializer = self.value_serializer
cdef type key_cls, value_cls, key_serializer_type, value_serializer_type
+ cdef uint64_t key_cls_addr, value_cls_addr
+ cdef PyObjectPtr key_typeinfo_ptr, value_typeinfo_ptr
cdef TypeInfo key_type_info, value_type_info
cdef int32_t chunk_size_offset, chunk_header, chunk_size
cdef c_bool key_write_ref, value_write_ref
@@ -799,13 +806,25 @@ cdef class MapSerializer(Serializer):
if key_serializer is not None:
chunk_header |= KEY_DECL_TYPE
else:
- key_type_info = self.type_resolver.get_type_info(key_cls)
+ key_cls_addr = <uint64_t><uintptr_t><PyObject *> key_cls
+ key_typeinfo_ptr = self._key_typeinfo_cache[key_cls_addr]
+ if key_typeinfo_ptr == NULL:
+ key_type_info = self.type_resolver.get_type_info(key_cls)
+ self._key_typeinfo_cache[key_cls_addr] = <PyObject *> key_type_info
+ else:
+ key_type_info = <TypeInfo> key_typeinfo_ptr
type_resolver.write_type_info(buffer, key_type_info)
key_serializer = key_type_info.serializer
if value_serializer is not None:
chunk_header |= VALUE_DECL_TYPE
else:
- value_type_info = self.type_resolver.get_type_info(value_cls)
+ value_cls_addr = <uint64_t><uintptr_t><PyObject *> value_cls
+ value_typeinfo_ptr = self._value_typeinfo_cache[value_cls_addr]
+ if value_typeinfo_ptr == NULL:
+ value_type_info = self.type_resolver.get_type_info(value_cls)
+ self._value_typeinfo_cache[value_cls_addr] = <PyObject *> value_type_info
+ else:
+ value_type_info = <TypeInfo> value_typeinfo_ptr
type_resolver.write_type_info(buffer, value_type_info)
value_serializer = value_type_info.serializer
if self.key_serializer is not None:
diff --git a/python/pyfory/cpp/pyfory.cc b/python/pyfory/cpp/pyfory.cc
index f5f15d426..62251b7f6 100644
--- a/python/pyfory/cpp/pyfory.cc
+++ b/python/pyfory/cpp/pyfory.cc
@@ -485,6 +485,24 @@ static bool py_long_to_integral_range(PyObject *value, const char *type_name,
return true;
}
+template <typename T>
+static bool py_long_to_unsigned_range(PyObject *value, const char *type_name,
+ T *out) {
+ const unsigned long long converted = PyLong_AsUnsignedLongLong(value);
+ if (converted == static_cast<unsigned long long>(-1) &&
+ PyErr_Occurred() != nullptr) {
+ return false;
+ }
+ constexpr unsigned long long k_max =
+ static_cast<unsigned long long>(std::numeric_limits<T>::max());
+ if (converted > k_max) {
+ PyErr_Format(PyExc_OverflowError, "integer out of range for %s", type_name);
+ return false;
+ }
+ *out = static_cast<T>(converted);
+ return true;
+}
+
static int write_python_string(Buffer *buffer, PyObject *value) {
if (FORY_PREDICT_FALSE(!PyUnicode_Check(value))) {
PyErr_Format(PyExc_TypeError, "expected str, got %.200s",
@@ -608,26 +626,82 @@ static int write_primitive_item(Buffer *buffer, PyObject *value,
switch (static_cast<TypeId>(type_id)) {
case TypeId::STRING:
return write_python_string(buffer, value);
- case TypeId::VARINT64: {
+ case TypeId::VARINT64:
+ case TypeId::INT64: {
int64_t v = 0;
if (FORY_PREDICT_FALSE(!py_long_to_int64(value, &v))) {
return -1;
}
- buffer->write_var_int64(v);
+ if (static_cast<TypeId>(type_id) == TypeId::INT64) {
+ buffer->write_int64(v);
+ } else {
+ buffer->write_var_int64(v);
+ }
return 0;
}
- case TypeId::VARINT32: {
+ case TypeId::VARINT32:
+ case TypeId::INT32: {
int32_t v = 0;
if (FORY_PREDICT_FALSE(
!py_long_to_integral_range<int32_t>(value, "int32", &v))) {
return -1;
}
- buffer->write_var_int32(v);
+ if (static_cast<TypeId>(type_id) == TypeId::INT32) {
+ buffer->write_int32(v);
+ } else {
+ buffer->write_var_int32(v);
+ }
+ return 0;
+ }
+ case TypeId::VAR_UINT64:
+ case TypeId::UINT64:
+ case TypeId::TAGGED_UINT64: {
+ uint64_t v = 0;
+ if (FORY_PREDICT_FALSE(
+ !py_long_to_unsigned_range<uint64_t>(value, "uint64", &v))) {
+ return -1;
+ }
+ if (static_cast<TypeId>(type_id) == TypeId::VAR_UINT64) {
+ buffer->write_var_uint64(v);
+ } else if (static_cast<TypeId>(type_id) == TypeId::TAGGED_UINT64) {
+ buffer->write_tagged_uint64(v);
+ } else {
+ buffer->write_int64(static_cast<int64_t>(v));
+ }
+ return 0;
+ }
+ case TypeId::VAR_UINT32:
+ case TypeId::UINT32: {
+ uint32_t v = 0;
+ if (FORY_PREDICT_FALSE(
+ !py_long_to_unsigned_range<uint32_t>(value, "uint32", &v))) {
+ return -1;
+ }
+ if (static_cast<TypeId>(type_id) == TypeId::VAR_UINT32) {
+ buffer->write_var_uint32(v);
+ } else {
+ buffer->write_uint32(v);
+ }
return 0;
}
case TypeId::BOOL:
- buffer->write_int8(value == Py_True ? 1 : 0);
+ if (value == Py_True) {
+ buffer->write_int8(1);
+ return 0;
+ }
+ if (value == Py_False) {
+ buffer->write_int8(0);
+ return 0;
+ }
+ {
+ const int truthy = PyObject_IsTrue(value);
+ if (FORY_PREDICT_FALSE(truthy < 0)) {
+ return -1;
+ }
+ buffer->write_int8(truthy ? 1 : 0);
+ }
return 0;
+ case TypeId::FLOAT32:
case TypeId::FLOAT64: {
double v;
if (PyFloat_CheckExact(value)) {
@@ -638,7 +712,20 @@ static int write_primitive_item(Buffer *buffer, PyObject *value,
return -1;
}
}
- buffer->write_double(v);
+ if (static_cast<TypeId>(type_id) == TypeId::FLOAT32) {
+ buffer->write_float(static_cast<float>(v));
+ } else {
+ buffer->write_double(v);
+ }
+ return 0;
+ }
+ case TypeId::UINT8: {
+ uint8_t v = 0;
+ if (FORY_PREDICT_FALSE(
+ !py_long_to_unsigned_range<uint8_t>(value, "uint8", &v))) {
+ return -1;
+ }
+ buffer->write_uint8(v);
return 0;
}
case TypeId::INT8: {
@@ -650,6 +737,15 @@ static int write_primitive_item(Buffer *buffer, PyObject *value,
buffer->write_int8(v);
return 0;
}
+ case TypeId::UINT16: {
+ uint16_t v = 0;
+ if (FORY_PREDICT_FALSE(
+ !py_long_to_unsigned_range<uint16_t>(value, "uint16", &v))) {
+ return -1;
+ }
+ buffer->write_uint16(v);
+ return 0;
+ }
case TypeId::INT16: {
int16_t v = 0;
if (FORY_PREDICT_FALSE(
@@ -659,13 +755,13 @@ static int write_primitive_item(Buffer *buffer, PyObject *value,
buffer->write_int16(v);
return 0;
}
- case TypeId::INT32: {
- int32_t v = 0;
+ case TypeId::TAGGED_INT64: {
+ int64_t v = 0;
if (FORY_PREDICT_FALSE(
- !py_long_to_integral_range<int32_t>(value, "int32", &v))) {
+ !py_long_to_integral_range<int64_t>(value, "int64", &v))) {
return -1;
}
- buffer->write_int32(v);
+ buffer->write_tagged_int64(v);
return 0;
}
default:
@@ -858,22 +954,56 @@ static PyObject *read_primitive_item(Buffer *buffer, uint8_t type_id) {
switch (static_cast<TypeId>(type_id)) {
case TypeId::STRING:
return read_python_string(buffer);
- case TypeId::VARINT64: {
- const int64_t v = buffer->read_var_int64(error);
+ case TypeId::VARINT64:
+ case TypeId::INT64: {
+ const int64_t v = static_cast<TypeId>(type_id) == TypeId::INT64
+ ? buffer->read_int64(error)
+ : buffer->read_var_int64(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
set_buffer_error(error);
return nullptr;
}
return PyLong_FromLongLong(v);
}
- case TypeId::VARINT32: {
- const int32_t v = buffer->read_var_int32(error);
+ case TypeId::VARINT32:
+ case TypeId::INT32: {
+ const int32_t v = static_cast<TypeId>(type_id) == TypeId::INT32
+ ? buffer->read_int32(error)
+ : buffer->read_var_int32(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
set_buffer_error(error);
return nullptr;
}
return PyLong_FromLong(v);
}
+ case TypeId::VAR_UINT64:
+ case TypeId::UINT64:
+ case TypeId::TAGGED_UINT64: {
+ uint64_t v = 0;
+ if (static_cast<TypeId>(type_id) == TypeId::VAR_UINT64) {
+ v = buffer->read_var_uint64(error);
+ } else if (static_cast<TypeId>(type_id) == TypeId::TAGGED_UINT64) {
+ v = buffer->read_tagged_uint64(error);
+ } else {
+ v = buffer->read_uint64(error);
+ }
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ set_buffer_error(error);
+ return nullptr;
+ }
+ return PyLong_FromUnsignedLongLong(v);
+ }
+ case TypeId::VAR_UINT32:
+ case TypeId::UINT32: {
+ const uint32_t v = static_cast<TypeId>(type_id) == TypeId::UINT32
+ ? buffer->read_uint32(error)
+ : buffer->read_var_uint32(error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ set_buffer_error(error);
+ return nullptr;
+ }
+ return PyLong_FromUnsignedLong(v);
+ }
case TypeId::BOOL: {
const uint8_t v = buffer->read_uint8(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
@@ -882,6 +1012,14 @@ static PyObject *read_primitive_item(Buffer *buffer, uint8_t type_id) {
}
return PyBool_FromLong(v != 0);
}
+ case TypeId::FLOAT32: {
+ const float v = buffer->read_float(error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ set_buffer_error(error);
+ return nullptr;
+ }
+ return PyFloat_FromDouble(static_cast<double>(v));
+ }
case TypeId::FLOAT64: {
const double v = buffer->read_double(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
@@ -890,6 +1028,14 @@ static PyObject *read_primitive_item(Buffer *buffer, uint8_t type_id) {
}
return PyFloat_FromDouble(v);
}
+ case TypeId::UINT8: {
+ const uint8_t v = buffer->read_uint8(error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ set_buffer_error(error);
+ return nullptr;
+ }
+ return PyLong_FromUnsignedLong(v);
+ }
case TypeId::INT8: {
const int8_t v = buffer->read_int8(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
@@ -898,6 +1044,14 @@ static PyObject *read_primitive_item(Buffer *buffer, uint8_t type_id) {
}
return PyLong_FromLong(v);
}
+ case TypeId::UINT16: {
+ const uint16_t v = buffer->read_uint16(error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ set_buffer_error(error);
+ return nullptr;
+ }
+ return PyLong_FromUnsignedLong(v);
+ }
case TypeId::INT16: {
const int16_t v = buffer->read_int16(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
@@ -906,13 +1060,13 @@ static PyObject *read_primitive_item(Buffer *buffer, uint8_t type_id) {
}
return PyLong_FromLong(v);
}
- case TypeId::INT32: {
- const int32_t v = buffer->read_int32(error);
+ case TypeId::TAGGED_INT64: {
+ const int64_t v = buffer->read_tagged_int64(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
set_buffer_error(error);
return nullptr;
}
- return PyLong_FromLong(v);
+ return PyLong_FromLongLong(v);
}
default:
PyErr_Format(PyExc_ValueError, "unsupported primitive fastpath type id: %u",
@@ -1235,6 +1389,15 @@ int Fory_PyPrimitiveCollectionReadFromBuffer(PyObject *collection,
return 0;
}
+int Fory_PyWriteBasicFieldToBuffer(PyObject *value, Buffer *buffer,
+ uint8_t type_id) {
+ return write_primitive_item(buffer, value, type_id);
+}
+
+PyObject *Fory_PyReadBasicFieldFromBuffer(Buffer *buffer, uint8_t type_id) {
+ return read_primitive_item(buffer, type_id);
+}
+
int Fory_PyCreateBufferFromStream(PyObject *stream, uint32_t buffer_size,
Buffer **out, std::string *error_message) {
if (stream == nullptr) {
diff --git a/python/pyfory/cpp/pyfory.h b/python/pyfory/cpp/pyfory.h
index 4c96257a6..1003733cc 100644
--- a/python/pyfory/cpp/pyfory.h
+++ b/python/pyfory/cpp/pyfory.h
@@ -65,6 +65,9 @@ int Fory_PyPrimitiveCollectionWriteToBuffer(PyObject *collection,
int Fory_PyPrimitiveCollectionReadFromBuffer(PyObject *collection,
Buffer *buffer, Py_ssize_t size,
uint8_t type_id);
+int Fory_PyWriteBasicFieldToBuffer(PyObject *value, Buffer *buffer,
+ uint8_t type_id);
+PyObject *Fory_PyReadBasicFieldFromBuffer(Buffer *buffer, uint8_t type_id);
int Fory_PyCreateBufferFromStream(PyObject *stream, uint32_t buffer_size,
Buffer **out, std::string *error_message);
} // namespace fory
diff --git a/python/pyfory/includes/libserialization.pxd b/python/pyfory/includes/libserialization.pxd
index a6fa0b8e3..40e0b3248 100644
--- a/python/pyfory/includes/libserialization.pxd
+++ b/python/pyfory/includes/libserialization.pxd
@@ -97,3 +97,5 @@ cdef extern from "fory/python/pyfory.h" namespace "fory":
cdef c_bool Fory_CanUsePrimitiveCollectionFastpath(uint8_t type_id)
int Fory_PyPrimitiveCollectionWriteToBuffer(object collection, CBuffer *buffer, uint8_t type_id) except -1
int Fory_PyPrimitiveCollectionReadFromBuffer(object collection, CBuffer *buffer, int64_t size, uint8_t type_id) except -1
+ int Fory_PyWriteBasicFieldToBuffer(object value, CBuffer *buffer, uint8_t type_id) except -1
+ object Fory_PyReadBasicFieldFromBuffer(CBuffer *buffer, uint8_t type_id)
diff --git a/python/pyfory/includes/libutil.pxd b/python/pyfory/includes/libutil.pxd
index 50373e19c..aeaa60a7d 100644
--- a/python/pyfory/includes/libutil.pxd
+++ b/python/pyfory/includes/libutil.pxd
@@ -227,3 +227,17 @@ cdef extern from "fory/util/bit_util.h" namespace "fory::util" nogil:
cdef extern from "fory/util/string_util.h" namespace "fory" nogil:
c_bool utf16_has_surrogate_pairs(uint16_t* data, size_t size)
+
+
+cdef extern from "fory/util/flat_int_map.h" namespace "fory::util" nogil:
+ cdef cppclass FlatIntMap[K, V]:
+ cppclass Entry:
+ K key
+ V value
+
+ FlatIntMap() except +
+ FlatIntMap(size_t initial_capacity) except +
+ V& operator[](K key)
+ Entry* find(K key)
+ size_t size() const
+ void clear()
diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py
index 50b48f260..69680eaf1 100644
--- a/python/pyfory/meta/typedef.py
+++ b/python/pyfory/meta/typedef.py
@@ -168,9 +168,11 @@ class TypeDef:
nullable_fields[resolved_name] = field_info.field_type.is_nullable
dynamic_fields = {}
+ ref_fields = {}
for i, field_info in enumerate(self.fields):
resolved_name = field_names[i]
type_id = field_info.field_type.type_id
+ ref_fields[resolved_name] = field_info.field_type.is_tracking_ref
if is_polymorphic_type(type_id):
dynamic_fields[resolved_name] = True
@@ -181,6 +183,7 @@ class TypeDef:
serializers=self.create_fields_serializer(resolver, field_names),
nullable_fields=nullable_fields,
dynamic_fields=dynamic_fields,
+ ref_fields=ref_fields,
)
def __repr__(self):
diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py
index 6328dd4a5..dcae0f42c 100644
--- a/python/pyfory/registry.py
+++ b/python/pyfory/registry.py
@@ -126,6 +126,27 @@ logger = logging.getLogger(__name__)
namespace_decoder = MetaStringDecoder(".", "_")
typename_decoder = MetaStringDecoder("$", "_")
+_NO_REF_NUMERIC_TYPE_IDS = frozenset(
+ {
+ TypeId.INT8,
+ TypeId.INT16,
+ TypeId.INT32,
+ TypeId.VARINT32,
+ TypeId.INT64,
+ TypeId.VARINT64,
+ TypeId.TAGGED_INT64,
+ TypeId.UINT8,
+ TypeId.UINT16,
+ TypeId.UINT32,
+ TypeId.VAR_UINT32,
+ TypeId.UINT64,
+ TypeId.VAR_UINT64,
+ TypeId.TAGGED_UINT64,
+ TypeId.FLOAT32,
+ TypeId.FLOAT64,
+ }
+)
+
if ENABLE_FORY_CYTHON_SERIALIZATION:
from pyfory.serialization import TypeInfo
else:
@@ -538,6 +559,8 @@ class TypeResolver:
if should_create_serializer:
serializer = self._create_serializer(cls)
+ if serializer is not None and type_id in _NO_REF_NUMERIC_TYPE_IDS:
+ serializer.need_to_write_ref = False
if typename is None:
typeinfo = TypeInfo(cls, type_id, user_type_id, serializer, None,
None, dynamic_type)
@@ -636,9 +659,18 @@ class TypeResolver:
elif self._internal_py_serializer_map.get(type(serializer)) is not None:
type_id = self._internal_py_serializer_map.get(type(serializer))[1]
if not self.require_registration:
- from pyfory.struct import DataClassSerializer
-
- if isinstance(serializer, DataClassSerializer):
+ from pyfory import struct as struct_module
+
+ data_class_types = tuple(
+ cls
+ for cls in (
+ getattr(struct_module, "DataClassSerializer", None),
+ getattr(struct_module, "DataClassStubSerializer", None),
+ getattr(struct_module, "PythonDataClassSerializer", None),
+ )
+ if cls is not None
+ )
+ if data_class_types and isinstance(serializer, data_class_types):
type_id = TypeId.NAMED_STRUCT
if type_id is None:
raise TypeUnregisteredError(f"{cls} must be registered using
`fory.register_type` API")
diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx
index b53a5d459..9b8c73098 100644
--- a/python/pyfory/serialization.pyx
+++ b/python/pyfory/serialization.pyx
@@ -39,7 +39,9 @@ from pyfory.includes.libserialization cimport \
Fory_IsInternalTypeId,
Fory_CanUsePrimitiveCollectionFastpath,
Fory_PyPrimitiveCollectionWriteToBuffer,
- Fory_PyPrimitiveCollectionReadFromBuffer)
+ Fory_PyPrimitiveCollectionReadFromBuffer,
+ Fory_PyWriteBasicFieldToBuffer,
+ Fory_PyReadBasicFieldFromBuffer)
from libc.stdint cimport int8_t, int16_t, int32_t, int64_t, uint64_t
from libc.stdint cimport *
@@ -56,7 +58,7 @@ from libcpp cimport bool as c_bool
from libcpp.utility cimport pair
from cython.operator cimport dereference as deref
from pyfory.includes.libabsl cimport flat_hash_map
-from pyfory.includes.libutil cimport CBuffer
+from pyfory.includes.libutil cimport CBuffer, FlatIntMap
from pyfory.meta.metastring import MetaStringDecoder
try:
@@ -1869,3 +1871,4 @@ cdef class SliceSerializer(Serializer):
include "primitive.pxi"
include "collection.pxi"
+include "struct.pxi"
diff --git a/python/pyfory/struct.pxi b/python/pyfory/struct.pxi
new file mode 100644
index 000000000..63a3871cc
--- /dev/null
+++ b/python/pyfory/struct.pxi
@@ -0,0 +1,463 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import dataclasses
+import typing
+
+from cpython.unicode cimport PyUnicode_InternFromString
+
+
+cdef uint8_t _BASIC_FIELD_UNSUPPORTED = 0xFF
+
+
+cdef struct FieldRuntimeInfo:
+ uint8_t basic_type_id
+ uint8_t is_nullable
+ uint8_t track_ref
+ uint8_t is_dynamic
+ uint8_t field_exists
+ PyObject *field_name
+ PyObject *serializer
+
+
+@cython.final
+cdef class DataClassSerializer(Serializer):
+ cdef public object _type_hints
+ cdef public bint _has_slots
+ cdef public bint _fields_from_typedef
+ cdef public object _field_names
+ cdef public object _serializers
+ cdef public object _nullable_fields
+ cdef public object _ref_fields
+ cdef public object _dynamic_fields
+ cdef public object _field_infos
+ cdef public object _field_metas
+ cdef public object _unwrapped_hints
+ cdef public int32_t _hash
+ cdef public tuple _field_name_interned
+ cdef tuple _serializer_owner
+ cdef public object _default_values_factory
+ cdef object _missing_field_defaults
+ cdef vector[FieldRuntimeInfo] _field_runtime_infos
+
+ def __init__(
+ self,
+ fory,
+ clz: type,
+ field_names: list = None,
+ serializers: list = None,
+ nullable_fields: dict = None,
+ dynamic_fields: dict = None,
+ ref_fields: dict = None,
+ ):
+ super().__init__(fory, clz)
+
+ from pyfory.lib.mmh3 import hash_buffer
+ from pyfory.struct import (
+ _extract_field_infos,
+ build_default_values_factory,
+ compute_struct_fingerprint,
+ compute_struct_meta,
+ StructFieldSerializerVisitor,
+ )
+ from pyfory.type_util import get_type_hints, unwrap_optional, infer_field
+ from pyfory.types import TypeId, is_primitive_type
+
+ self._type_hints = get_type_hints(clz)
+ self._has_slots = hasattr(clz, "__slots__")
+
+ self._fields_from_typedef = field_names is not None and serializers is not None
+ if self._fields_from_typedef:
+ self._field_names = list(field_names)
+ self._serializers = list(serializers)
+ self._nullable_fields = dict(nullable_fields) if nullable_fields is not None else {}
+ self._ref_fields = dict(ref_fields) if ref_fields is not None else {}
+ self._dynamic_fields = dict(dynamic_fields) if dynamic_fields is not None else {}
+ self._field_infos = []
+ self._field_metas = {}
+ else:
+ self._field_infos, self._field_metas = _extract_field_infos(fory, clz, self._type_hints)
+
+ if self._field_infos:
+ self._field_names = [fi.name for fi in self._field_infos]
+ self._serializers = [fi.serializer for fi in self._field_infos]
+ self._nullable_fields = {fi.name: fi.nullable for fi in self._field_infos}
+ self._ref_fields = {fi.name: fi.runtime_ref_tracking for fi in self._field_infos}
+ self._dynamic_fields = {fi.name: fi.dynamic for fi in self._field_infos}
+ else:
+ self._field_names = self._get_field_names(clz)
+ self._nullable_fields = dict(nullable_fields) if nullable_fields is not None else {}
+ self._ref_fields = {}
+ self._dynamic_fields = {}
+
+ if self._field_names and not self._nullable_fields:
+ for field_name in self._field_names:
+ if field_name in self._type_hints:
+ unwrapped_type, is_optional = unwrap_optional(self._type_hints[field_name])
+ self._nullable_fields[field_name] = is_optional or not is_primitive_type(unwrapped_type)
+
+ if serializers is None:
+ self._serializers = [None] * len(self._field_names)
+ visitor = StructFieldSerializerVisitor(fory)
+ for index, key in enumerate(self._field_names):
+ unwrapped_type, _ = unwrap_optional(self._type_hints.get(key, typing.Any))
+ self._serializers[index] = infer_field(key, unwrapped_type, visitor, types_path=[])
+ else:
+ self._serializers = list(serializers)
+
+ self._unwrapped_hints = self._compute_unwrapped_hints()
+
+ if self._fields_from_typedef:
+ hash_str = compute_struct_fingerprint(
+ fory.type_resolver,
+ self._field_names,
+ self._serializers,
+ self._nullable_fields,
+ self._field_infos,
+ )
+ hash_bytes = hash_str.encode("utf-8")
+ if len(hash_bytes) == 0:
+ self._hash = 47
+ else:
+ full_hash = hash_buffer(hash_bytes, seed=47)[0]
+ type_hash_32 = full_hash & 0xFFFFFFFF
+ if full_hash & 0x80000000:
+ type_hash_32 -= 0x100000000
+ self._hash = type_hash_32
+ else:
+ self._hash, self._field_names, self._serializers = compute_struct_meta(
+ fory.type_resolver,
+ self._field_names,
+ self._serializers,
+ self._nullable_fields,
+ self._field_infos,
+ )
+
+ self._field_name_interned = tuple(self._intern_field_name(name) for name in self._field_names)
+ self._serializer_owner = tuple(self._serializers)
+ if dataclasses.is_dataclass(clz):
+ self._default_values_factory = build_default_values_factory(self.fory, self._type_hints, dataclasses.fields(clz))
+ else:
+ self._default_values_factory = {}
+ self._build_fastpath_metadata()
+ self._build_missing_field_defaults()
+
+ cdef object _intern_field_name(self, str name):
+ cdef bytes encoded = name.encode("utf-8")
+ cdef const char *ptr = encoded
+ cdef object interned = PyUnicode_InternFromString(ptr)
+ if interned is None:
+ raise MemoryError("failed to intern field name")
+ return interned
+
+ cdef list _get_field_names(self, object clz):
+ if hasattr(clz, "__dict__"):
+ if dataclasses.is_dataclass(clz):
+ return [field.name for field in dataclasses.fields(clz)]
+ return sorted(self._type_hints.keys())
+ if hasattr(clz, "__slots__"):
+ slots = clz.__slots__
+ if type(slots) is str:
+ return [slots]
+ return sorted(slots)
+ return []
+
+ cdef dict _compute_unwrapped_hints(self):
+ from pyfory.type_util import unwrap_optional
+
+ return {field_name: unwrap_optional(hint)[0] for field_name, hint in self._type_hints.items()}
+
+ cdef inline uint8_t _resolve_basic_type_id(self, Serializer serializer, bint is_dynamic):
+ cdef uint8_t type_id
+ if is_dynamic or serializer is None:
+ return _BASIC_FIELD_UNSUPPORTED
+ type_id = <uint8_t>self.fory.type_resolver.get_type_info(serializer.type_).type_id
+ if type_id == <uint8_t>TypeId.BOOL:
+ return type_id
+ if type_id == <uint8_t>TypeId.INT8:
+ return type_id
+ if type_id == <uint8_t>TypeId.INT16:
+ return type_id
+ if type_id == <uint8_t>TypeId.INT32:
+ return type_id
+ if type_id == <uint8_t>TypeId.VARINT32:
+ return type_id
+ if type_id == <uint8_t>TypeId.INT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.VARINT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.TAGGED_INT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.UINT8:
+ return type_id
+ if type_id == <uint8_t>TypeId.UINT16:
+ return type_id
+ if type_id == <uint8_t>TypeId.UINT32:
+ return type_id
+ if type_id == <uint8_t>TypeId.VAR_UINT32:
+ return type_id
+ if type_id == <uint8_t>TypeId.UINT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.VAR_UINT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.TAGGED_UINT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.FLOAT32:
+ return type_id
+ if type_id == <uint8_t>TypeId.FLOAT64:
+ return type_id
+ if type_id == <uint8_t>TypeId.STRING:
+ return type_id
+ return _BASIC_FIELD_UNSUPPORTED
+
+ cdef void _build_fastpath_metadata(self):
+ cdef Py_ssize_t i
+ cdef object field_name
+ cdef object serializer
+ cdef set current_fields
+ cdef bint is_dynamic
+ cdef bint is_nullable
+ cdef bint is_tracking_ref
+ cdef FieldRuntimeInfo runtime_info
+
+ self._field_runtime_infos.clear()
+
+ current_fields = set(self._get_field_names(self.type_))
+ self._field_runtime_infos.reserve(len(self._field_names))
+
+ for i in range(len(self._field_names)):
+ field_name = self._field_names[i]
+ serializer = self._serializer_owner[i]
+ is_nullable = bool(self._nullable_fields.get(field_name, False))
+ is_tracking_ref = bool(self._ref_fields.get(field_name, False))
+ is_dynamic = bool(self._dynamic_fields.get(field_name, False))
+
+ runtime_info.basic_type_id = self._resolve_basic_type_id(serializer, is_dynamic)
+ runtime_info.is_nullable = 1 if is_nullable else 0
+ runtime_info.track_ref = 1 if is_tracking_ref else 0
+ runtime_info.is_dynamic = 1 if is_dynamic else 0
+ runtime_info.field_exists = 1 if field_name in current_fields else 0
+ runtime_info.field_name = <PyObject *> self._field_name_interned[i]
+ runtime_info.serializer = <PyObject *> serializer
+ self._field_runtime_infos.push_back(runtime_info)
+
+ cdef void _build_missing_field_defaults(self):
+ cdef object read_field_names
+ cdef object current_class_field_names
+ cdef object missing_fields
+ cdef list defaults
+ cdef object field_name
+ cdef object default_factory
+
+ self._missing_field_defaults = ()
+ if not self.fory.compatible or not self._default_values_factory:
+ return
+
+ read_field_names = set(self._field_names)
+ current_class_field_names = set(self._get_field_names(self.type_))
+ missing_fields = current_class_field_names - read_field_names
+ if not missing_fields:
+ return
+
+ defaults = []
+ for field_name, default_factory in self._default_values_factory.items():
+ if field_name not in missing_fields:
+ continue
+ defaults.append((self._intern_field_name(field_name), default_factory))
+ self._missing_field_defaults = tuple(defaults)
+
+ cpdef inline write(self, Buffer buffer, value):
+ if not self.fory.compatible:
+ buffer.write_int32(self._hash)
+ if self._has_slots:
+ self._write_slots(buffer, value)
+ else:
+ self._write_dict(buffer, value)
+
+ cdef inline void _write_dict(self, Buffer buffer, object value):
+ cdef dict value_dict = value.__dict__
+ cdef Py_ssize_t i
+ cdef Py_ssize_t field_count = self._field_runtime_infos.size()
+ cdef object field_value
+ cdef object field_name
+ cdef FieldRuntimeInfo *field_info
+
+ if self.fory.compatible:
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_name = <object> field_info.field_name
+ field_value = value_dict.get(field_name)
+ self._write_field_value(buffer, field_info, field_value)
+ else:
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_name = <object> field_info.field_name
+ field_value = value_dict[field_name]
+ self._write_field_value(buffer, field_info, field_value)
+
+ cdef inline void _write_slots(self, Buffer buffer, object value):
+ cdef Py_ssize_t i
+ cdef Py_ssize_t field_count = self._field_runtime_infos.size()
+ cdef object field_name
+ cdef object field_value
+ cdef FieldRuntimeInfo *field_info
+
+ if self.fory.compatible:
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_name = <object> field_info.field_name
+ field_value = getattr(value, field_name, None)
+ self._write_field_value(buffer, field_info, field_value)
+ else:
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_name = <object> field_info.field_name
+ field_value = getattr(value, field_name)
+ self._write_field_value(buffer, field_info, field_value)
+
+ cdef inline void _write_field_value(self, Buffer buffer, FieldRuntimeInfo *field_info, object field_value):
+ cdef uint8_t type_id = field_info.basic_type_id
+ cdef bint is_nullable = field_info.is_nullable != 0
+ cdef bint is_tracking_ref = field_info.track_ref != 0
+ cdef bint is_dynamic = field_info.is_dynamic != 0
+ cdef Serializer serializer
+
+ if type_id != _BASIC_FIELD_UNSUPPORTED:
+ if is_nullable:
+ if field_value is None:
+ buffer.write_int8(NULL_FLAG)
+ else:
+ buffer.write_int8(NOT_NULL_VALUE_FLAG)
+ Fory_PyWriteBasicFieldToBuffer(field_value, &buffer.c_buffer, type_id)
+ else:
+ Fory_PyWriteBasicFieldToBuffer(field_value, &buffer.c_buffer, type_id)
+ return
+
+ serializer = <object> field_info.serializer
+ if is_tracking_ref:
+ if is_dynamic:
+ self.fory.write_ref(buffer, field_value)
+ else:
+ self.fory.write_ref(buffer, field_value, serializer=serializer)
+ else:
+ if is_nullable:
+ if field_value is None:
+ buffer.write_int8(NULL_FLAG)
+ return
+ buffer.write_int8(NOT_NULL_VALUE_FLAG)
+ if is_dynamic:
+ self.fory.write_no_ref(buffer, field_value)
+ else:
+ self.fory.write_no_ref(buffer, field_value, serializer=serializer)
+
+ cpdef inline read(self, Buffer buffer):
+ cdef object obj
+
+ if not self.fory.strict:
+ self.fory.policy.authorize_instantiation(self.type_)
+
+ if not self.fory.compatible:
+ read_hash = buffer.read_int32()
+ if read_hash != self._hash:
+ from pyfory.error import TypeNotCompatibleError
+
+ raise TypeNotCompatibleError(f"Hash {read_hash} is not consistent with {self._hash} for type {self.type_}")
+
+ obj = self.type_.__new__(self.type_)
+ self.fory.ref_resolver.reference(obj)
+
+ if self._has_slots:
+ self._read_slots(buffer, obj)
+ else:
+ self._read_dict(buffer, obj)
+
+ if self._missing_field_defaults:
+ if self._has_slots:
+ self._apply_missing_defaults_slots(obj)
+ else:
+ self._apply_missing_defaults_dict(obj.__dict__)
+ return obj
+
+ cdef inline void _read_dict(self, Buffer buffer, object obj):
+ cdef dict obj_dict = obj.__dict__
+ cdef Py_ssize_t i
+ cdef Py_ssize_t field_count = self._field_runtime_infos.size()
+ cdef object field_value
+ cdef object field_name
+ cdef FieldRuntimeInfo *field_info
+
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_value = self._read_field_value(buffer, field_info)
+ if field_info.field_exists == 0:
+ continue
+ field_name = <object> field_info.field_name
+ obj_dict[field_name] = field_value
+
+ cdef inline void _read_slots(self, Buffer buffer, object obj):
+ cdef Py_ssize_t i
+ cdef Py_ssize_t field_count = self._field_runtime_infos.size()
+ cdef object field_value
+ cdef object field_name
+ cdef FieldRuntimeInfo *field_info
+
+ for i in range(field_count):
+ field_info = &self._field_runtime_infos[i]
+ field_value = self._read_field_value(buffer, field_info)
+ if field_info.field_exists == 0:
+ continue
+ field_name = <object> field_info.field_name
+ setattr(obj, field_name, field_value)
+
+ cdef inline object _read_field_value(self, Buffer buffer, FieldRuntimeInfo *field_info):
+ cdef uint8_t type_id = field_info.basic_type_id
+ cdef bint is_nullable = field_info.is_nullable != 0
+ cdef bint is_tracking_ref = field_info.track_ref != 0
+ cdef bint is_dynamic = field_info.is_dynamic != 0
+ cdef Serializer serializer
+
+ if type_id != _BASIC_FIELD_UNSUPPORTED:
+ if is_nullable and buffer.read_int8() == NULL_FLAG:
+ return None
+ return Fory_PyReadBasicFieldFromBuffer(&buffer.c_buffer, type_id)
+
+ serializer = <object> field_info.serializer
+ if is_tracking_ref:
+ if is_dynamic:
+ return self.fory.read_ref(buffer)
+ return self.fory.read_ref(buffer, serializer=serializer)
+
+ if is_nullable and buffer.read_int8() == NULL_FLAG:
+ return None
+
+ if is_dynamic:
+ return self.fory.read_no_ref(buffer)
+ return self.fory.read_no_ref(buffer, serializer=serializer)
+
+ cdef inline void _apply_missing_defaults_dict(self, dict obj_dict):
+ cdef object field_name
+ cdef object default_factory
+
+ for field_name, default_factory in self._missing_field_defaults:
+ obj_dict[field_name] = default_factory()
+
+ cdef inline void _apply_missing_defaults_slots(self, object obj):
+ cdef object field_name
+ cdef object default_factory
+
+ for field_name, default_factory in self._missing_field_defaults:
+ setattr(obj, field_name, default_factory())
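Both the Cython and the Python serializer in this commit truncate the 64-bit murmur fingerprint to a signed 32-bit struct hash the same way. A minimal standalone sketch of that conversion (the function name is illustrative, not a pyfory API):

```python
def to_signed_int32(full_hash: int) -> int:
    # Mask to the low 32 bits, then reinterpret as a two's-complement
    # int32, matching the logic in DataClassSerializer.__init__ above.
    type_hash_32 = full_hash & 0xFFFFFFFF
    if full_hash & 0x80000000:
        type_hash_32 -= 0x100000000
    return type_hash_32
```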
diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py
index e29cecb7b..7923dd673 100644
--- a/python/pyfory/struct.py
+++ b/python/pyfory/struct.py
@@ -21,9 +21,9 @@ import dataclasses
import datetime
import enum
import inspect
-import itertools
import logging
import os
+import sys
import typing
from typing import List, Dict
@@ -62,11 +62,6 @@ from pyfory.type_util import (
unwrap_optional,
)
from pyfory.serialization import Buffer
-from pyfory.codegen import (
- gen_write_nullable_basic_stmts,
- gen_read_nullable_basic_stmts,
- compile_function,
-)
from pyfory.error import TypeNotCompatibleError
from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
from pyfory.field import (
@@ -89,6 +84,30 @@ from pyfory import (
logger = logging.getLogger(__name__)
+_MISSING_DEFAULT_INT_TYPES = {
+ int,
+ int8,
+ int16,
+ int32,
+ fixed_int32,
+ int64,
+ fixed_int64,
+ tagged_int64,
+ uint8,
+ uint16,
+ uint32,
+ fixed_uint32,
+ uint64,
+ fixed_uint64,
+ tagged_uint64,
+}
+
+_MISSING_DEFAULT_FLOAT_TYPES = {
+ float,
+ float32,
+ float64,
+}
+
@dataclasses.dataclass
class FieldInfo:
@@ -264,20 +283,74 @@ def _extract_field_infos(
return field_infos, field_metas
-_jit_context = locals()
-
-
-_ENABLE_FORY_PYTHON_JIT = os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower() in (
- "true",
- "1",
-)
-_ENABLE_FORY_DEBUG_OUTPUT = os.environ.get("ENABLE_FORY_DEBUG_OUTPUT", "False").lower() in (
- "true",
- "1",
-)
+def resolve_missing_field_default(
+ dc_field: dataclasses.Field,
+ fory,
+ type_hints: dict[str, typing.Any],
+) -> typing.Callable[[], typing.Any]:
+ type_hint = type_hints.get(dc_field.name, typing.Any)
+ unwrapped_type, is_optional = unwrap_optional(type_hint)
+ meta = extract_field_meta(dc_field)
+ effective_nullable = (meta.nullable if meta is not None else fory.field_nullable) or is_optional
+
+ if dc_field.default is not dataclasses.MISSING:
+ default_value = dc_field.default
+ if default_value is None and not effective_nullable and is_subclass(unwrapped_type, enum.Enum):
+ members = tuple(unwrapped_type)
+ if members:
+ default_value = members[0]
+ return lambda value=default_value: value
+
+ if dc_field.default_factory is not dataclasses.MISSING:
+ return dc_field.default_factory
+
+ if not effective_nullable:
+ origin = typing.get_origin(unwrapped_type) if hasattr(typing, "get_origin") else getattr(unwrapped_type, "__origin__", None)
+ origin = origin or unwrapped_type
+ if is_subclass(unwrapped_type, enum.Enum):
+ members = tuple(unwrapped_type)
+ if members:
+ default_value = members[0]
+ return lambda value=default_value: value
+ if origin is list or origin == typing.List:
+ return lambda: []
+ if origin is set or origin == typing.Set:
+ return lambda: set()
+ if origin is dict or origin == typing.Dict:
+ return lambda: {}
+ if unwrapped_type is bool:
+ return lambda: False
+ if unwrapped_type in _MISSING_DEFAULT_INT_TYPES:
+ return lambda: 0
+ if unwrapped_type in _MISSING_DEFAULT_FLOAT_TYPES:
+ return lambda: 0.0
+ if unwrapped_type is str:
+ return lambda: ""
+ if unwrapped_type is bytes:
+ return lambda: b""
+ return lambda: None
+
+
+def _resolve_missing_field_default(dc_field, fory, type_hints):
+ return resolve_missing_field_default(dc_field, fory, type_hints)
+
+
+def build_default_values_factory(fory, type_hints, dc_fields=()):
+ return {dc_field.name: _resolve_missing_field_default(dc_field, fory, type_hints) for dc_field in dc_fields}
class DataClassSerializer(Serializer):
+ _BASIC_SERIALIZERS = (
+ BooleanSerializer,
+ ByteSerializer,
+ Int16Serializer,
+ Int32Serializer,
+ Int64Serializer,
+ Float32Serializer,
+ Float64Serializer,
+ StringSerializer,
+ )
+
def __init__(
self,
fory,
@@ -286,685 +359,187 @@ class DataClassSerializer(Serializer):
serializers: List[Serializer] = None,
nullable_fields: Dict[str, bool] = None,
dynamic_fields: Dict[str, bool] = None,
+ ref_fields: Dict[str, bool] = None,
):
super().__init__(fory, clz)
self._type_hints = get_type_hints(clz)
self._has_slots = hasattr(clz, "__slots__")
- # When field_names is explicitly passed (from TypeDef.create_serializer during schema evolution),
- # use those fields instead of extracting from the class. This is critical for schema evolution
- # where the sender's schema (in TypeDef) differs from the receiver's registered class.
- # Track whether field order comes from wire (TypeDef) - don't re-sort these
self._fields_from_typedef = field_names is not None and serializers is not None
if self._fields_from_typedef:
- # Use the passed-in field_names and serializers from TypeDef
- self._field_names = field_names
- self._serializers = serializers
+ self._field_names = list(field_names)
+ self._serializers = list(serializers)
self._nullable_fields = nullable_fields or {}
- self._ref_fields = {}
+ self._ref_fields = ref_fields or {}
self._dynamic_fields = dynamic_fields or {}
self._field_infos = []
self._field_metas = {}
else:
- # Extract field infos using new pyfory.field() metadata
self._field_infos, self._field_metas = _extract_field_infos(fory, clz, self._type_hints)
-
if self._field_infos:
- # Use new field info based approach
self._field_names = [fi.name for fi in self._field_infos]
self._serializers = [fi.serializer for fi in self._field_infos]
self._nullable_fields = {fi.name: fi.nullable for fi in self._field_infos}
self._ref_fields = {fi.name: fi.runtime_ref_tracking for fi in self._field_infos}
self._dynamic_fields = {fi.name: fi.dynamic for fi in self._field_infos}
else:
- # Fallback for non-dataclass types
self._field_names = field_names or self._get_field_names(clz)
self._nullable_fields = nullable_fields or {}
self._ref_fields = {}
- self._dynamic_fields = {} # Empty dict, will use mode defaults
-
+ self._dynamic_fields = {}
if self._field_names and not self._nullable_fields:
for field_name in self._field_names:
if field_name in self._type_hints:
unwrapped_type, is_optional = unwrap_optional(self._type_hints[field_name])
- is_nullable = is_optional or not is_primitive_type(unwrapped_type)
- self._nullable_fields[field_name] = is_nullable
-
+ self._nullable_fields[field_name] = is_optional or not is_primitive_type(unwrapped_type)
self._serializers = serializers or [None] * len(self._field_names)
if serializers is None:
visitor = StructFieldSerializerVisitor(fory)
for index, key in enumerate(self._field_names):
unwrapped_type, _ = unwrap_optional(self._type_hints.get(key, typing.Any))
- serializer = infer_field(key, unwrapped_type, visitor, types_path=[])
- self._serializers[index] = serializer
+ self._serializers[index] = infer_field(key, unwrapped_type, visitor, types_path=[])
- # Cache unwrapped type hints
self._unwrapped_hints = self._compute_unwrapped_hints()
-
- # Compute struct hash and field order.
- # If fields come from TypeDef (wire schema), preserve field order and only compute hash.
if self._fields_from_typedef:
hash_str = compute_struct_fingerprint(fory.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos)
hash_bytes = hash_str.encode("utf-8")
if len(hash_bytes) == 0:
self._hash = 47
else:
- from pyfory.lib.mmh3 import hash_buffer
-
full_hash = hash_buffer(hash_bytes, seed=47)[0]
type_hash_32 = full_hash & 0xFFFFFFFF
if full_hash & 0x80000000:
- type_hash_32 = type_hash_32 - 0x100000000
+ type_hash_32 -= 0x100000000
self._hash = type_hash_32
else:
self._hash, self._field_names, self._serializers = compute_struct_meta(
fory.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos
)
- self._generated_write_method = self._gen_generated_write_method()
- self._generated_read_method = self._gen_generated_read_method()
- if _ENABLE_FORY_PYTHON_JIT:
- self.write = self._generated_write_method
- self.read = self._generated_read_method
+ self._field_name_interned = {name: sys.intern(name) for name in self._field_names}
+ self._current_class_field_names = set(self._get_field_names(self.type_))
+ self._default_values_factory = (
+ build_default_values_factory(self.fory, self._type_hints, dataclasses.fields(self.type_)) if dataclasses.is_dataclass(self.type_) else {}
+ )
+ self._missing_field_defaults = self._build_missing_field_defaults()
+ self._basic_field_flags = [
+ (not self._dynamic_fields.get(field_name, False)) and isinstance(self._serializers[index], self._BASIC_SERIALIZERS)
+ for index, field_name in enumerate(self._field_names)
+ ]
def _get_field_names(self, clz):
if hasattr(clz, "__dict__"):
- # Regular object with __dict__
- # For dataclasses, preserve field definition order
- # In compatible mode, stable field ordering is critical for schema evolution
if dataclasses.is_dataclass(clz):
- # Use dataclasses.fields() to get fields in definition order
return [field.name for field in dataclasses.fields(clz)]
- # For non-dataclass objects, sort by key names for consistency
return sorted(self._type_hints.keys())
- elif hasattr(clz, "__slots__"):
- # Object with __slots__
- return sorted(clz.__slots__)
+ if hasattr(clz, "__slots__"):
+ slots = clz.__slots__
+ if isinstance(slots, str):
+ return [slots]
+ return sorted(slots)
return []
def _compute_unwrapped_hints(self):
- """Compute unwrapped type hints once and cache."""
- from pyfory.type_util import unwrap_optional
-
return {field_name: unwrap_optional(hint)[0] for field_name, hint in self._type_hints.items()}
- def _write_header(self, buffer):
- """Write serialization header (hash or field count based on compatible
mode)."""
- if not self.fory.compatible:
- buffer.write_int32(self._hash)
- else:
- buffer.write_var_uint32(len(self._field_names))
-
- def _read_header(self, buffer):
- """Read serialization header and return number of fields written.
-
- Returns:
- int: Number of fields that were written
-
- Raises:
- TypeNotCompatibleError: If hash doesn't match in non-compatible mode
- """
- if not self.fory.compatible:
- hash_ = buffer.read_int32()
- expected_hash = self._hash
- if hash_ != expected_hash:
- raise TypeNotCompatibleError(f"Hash {hash_} is not consistent
with {expected_hash} for type {self.type_}")
- return len(self._field_names)
- else:
- return buffer.read_var_uint32()
-
- def _get_write_stmt_for_codegen(self, serializer, buffer, field_value):
- """Generate write statement for code generation based on serializer
type."""
- if isinstance(serializer, BooleanSerializer):
- return f"{buffer}.write_bool({field_value})"
- elif isinstance(serializer, ByteSerializer):
- return f"{buffer}.write_int8({field_value})"
- elif isinstance(serializer, Int16Serializer):
- return f"{buffer}.write_int16({field_value})"
- elif isinstance(serializer, Int32Serializer):
- return f"{buffer}.write_varint32({field_value})"
- elif isinstance(serializer, Int64Serializer):
- return f"{buffer}.write_varint64({field_value})"
- elif isinstance(serializer, Float32Serializer):
- return f"{buffer}.write_float32({field_value})"
- elif isinstance(serializer, Float64Serializer):
- return f"{buffer}.write_float64({field_value})"
- elif isinstance(serializer, StringSerializer):
- return f"{buffer}.write_string({field_value})"
- else:
- return None # Complex type, needs ref handling
-
- def _get_read_stmt_for_codegen(self, serializer, buffer, field_value):
- """Generate read statement for code generation based on serializer
type."""
- if isinstance(serializer, BooleanSerializer):
- return f"{field_value} = {buffer}.read_bool()"
- elif isinstance(serializer, ByteSerializer):
- return f"{field_value} = {buffer}.read_int8()"
- elif isinstance(serializer, Int16Serializer):
- return f"{field_value} = {buffer}.read_int16()"
- elif isinstance(serializer, Int32Serializer):
- return f"{field_value} = {buffer}.read_varint32()"
- elif isinstance(serializer, Int64Serializer):
- return f"{field_value} = {buffer}.read_varint64()"
- elif isinstance(serializer, Float32Serializer):
- return f"{field_value} = {buffer}.read_float32()"
- elif isinstance(serializer, Float64Serializer):
- return f"{field_value} = {buffer}.read_float64()"
- elif isinstance(serializer, StringSerializer):
- return f"{field_value} = {buffer}.read_string()"
- else:
- return None # Complex type, needs ref handling
-
- def _write_non_nullable_field(self, buffer, field_value, serializer, typeinfo=None):
- """Write a non-nullable field value at runtime."""
- if isinstance(serializer, BooleanSerializer):
- buffer.write_bool(field_value)
- elif isinstance(serializer, ByteSerializer):
- buffer.write_int8(field_value)
- elif isinstance(serializer, Int16Serializer):
- buffer.write_int16(field_value)
- elif isinstance(serializer, Int32Serializer):
- buffer.write_varint32(field_value)
- elif isinstance(serializer, Int64Serializer):
- buffer.write_varint64(field_value)
- elif isinstance(serializer, Float32Serializer):
- buffer.write_float32(field_value)
- elif isinstance(serializer, Float64Serializer):
- buffer.write_float64(field_value)
- elif isinstance(serializer, StringSerializer):
- buffer.write_string(field_value)
- else:
- self.fory.write_ref_pyobject(buffer, field_value, typeinfo=typeinfo)
-
- def _read_non_nullable_field(self, buffer, serializer):
- """Read a non-nullable field value at runtime."""
- if isinstance(serializer, BooleanSerializer):
- return buffer.read_bool()
- elif isinstance(serializer, ByteSerializer):
- return buffer.read_int8()
- elif isinstance(serializer, Int16Serializer):
- return buffer.read_int16()
- elif isinstance(serializer, Int32Serializer):
- return buffer.read_varint32()
- elif isinstance(serializer, Int64Serializer):
- return buffer.read_varint64()
- elif isinstance(serializer, Float32Serializer):
- return buffer.read_float32()
- elif isinstance(serializer, Float64Serializer):
- return buffer.read_float64()
- elif isinstance(serializer, StringSerializer):
- return buffer.read_string()
- else:
- return self.fory.read_ref_pyobject(buffer)
-
- def _write_nullable_field(self, buffer, field_value, serializer, typeinfo=None):
- """Write a nullable field value at runtime."""
- if field_value is None:
- buffer.write_int8(NULL_FLAG)
- else:
- buffer.write_int8(NOT_NULL_VALUE_FLAG)
- if isinstance(serializer, StringSerializer):
- buffer.write_string(field_value)
- else:
- self.fory.write_ref_pyobject(buffer, field_value, typeinfo=typeinfo)
-
- def _read_nullable_field(self, buffer, serializer):
- """Read a nullable field value at runtime."""
- flag = buffer.read_int8()
- if flag == NULL_FLAG:
- return None
- else:
- if isinstance(serializer, StringSerializer):
- return buffer.read_string()
- else:
- return self.fory.read_ref_pyobject(buffer)
-
- def _gen_write_method(self):
- context = {}
- counter = itertools.count(0)
- buffer, fory, value, value_dict = "buffer", "fory", "value", "value_dict"
- context[fory] = self.fory
- context["_serializers"] = self._serializers
-
- stmts = [
- f'"""write method for {self.type_}"""',
- ]
-
- # Write hash only in non-compatible mode; in compatible mode, write field count
- if not self.fory.compatible:
- stmts.append(f"{buffer}.write_int32({self._hash})")
- else:
- stmts.append(f"{buffer}.write_var_uint32({len(self._field_names)})")
-
- if not self._has_slots:
- stmts.append(f"{value_dict} = {value}.__dict__")
-
- # Write field values in order
- for index, field_name in enumerate(self._field_names):
- field_value = f"field_value{next(counter)}"
- serializer_var = f"serializer{index}"
- serializer = self._serializers[index]
- context[serializer_var] = serializer
-
- if not self._has_slots:
- stmts.append(f"{field_value} = {value_dict}['{field_name}']")
- else:
- stmts.append(f"{field_value} = {value}.{field_name}")
+ def _build_missing_field_defaults(self):
+ if not self.fory.compatible or not self._default_values_factory:
+ return []
+ missing_fields = self._current_class_field_names - set(self._field_names)
+ if not missing_fields:
+ return []
+ return [(field_name, default_factory) for field_name, default_factory in self._default_values_factory.items() if field_name in missing_fields]
- is_nullable = self._nullable_fields.get(field_name, False)
- is_dynamic = self._dynamic_fields.get(field_name, False)
- # For dynamic=False, get typeinfo for declared type to use its serializer
- typeinfo_var = f"typeinfo{index}"
- if not is_dynamic and serializer is not None:
- context[typeinfo_var] = self.fory.type_resolver.get_type_info(serializer.type_)
+ def _write_field_value(self, buffer, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref):
+ if is_basic:
if is_nullable:
- # Use gen_write_nullable_basic_stmts for nullable basic types
- if isinstance(serializer, BooleanSerializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, bool))
- elif isinstance(serializer, ByteSerializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, "int8"))
- elif isinstance(serializer, Int16Serializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, "int16"))
- elif isinstance(serializer, Int32Serializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, "int32"))
- elif isinstance(serializer, Int64Serializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, "int64"))
- elif isinstance(serializer, Float32Serializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, "float32"))
- elif isinstance(serializer, Float64Serializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, "float64"))
- elif isinstance(serializer, StringSerializer):
- stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, str))
+ if field_value is None:
+ buffer.write_int8(NULL_FLAG)
else:
- # For complex types, use write_ref_pyobject
- # dynamic=True or serializer is None: pass None to use actual type
- # dynamic=False: pass typeinfo to use declared type
- typeinfo_arg = "None" if is_dynamic or serializer is None else typeinfo_var
- stmts.append(f"{fory}.write_ref_pyobject({buffer}, {field_value}, typeinfo={typeinfo_arg})")
+ buffer.write_int8(NOT_NULL_VALUE_FLAG)
+ serializer.write(buffer, field_value)
else:
- stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value)
- if stmt is None:
- # dynamic=True or serializer is None: pass None to use actual type
- # dynamic=False: pass typeinfo to use declared type
- typeinfo_arg = "None" if is_dynamic or serializer is None else typeinfo_var
- stmt = f"{fory}.write_ref_pyobject({buffer}, {field_value}, typeinfo={typeinfo_arg})"
- stmts.append(stmt)
-
- self._write_method_code, func = compile_function(
- f"write_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
- [buffer, value],
- stmts,
- context,
- )
- return func
-
- def _gen_read_method(self):
- context = dict(_jit_context)
- buffer, fory, obj_class, obj, obj_dict = (
- "buffer",
- "fory",
- "obj_class",
- "obj",
- "obj_dict",
- )
- ref_resolver = "ref_resolver"
- context[fory] = self.fory
- context[obj_class] = self.type_
- context[ref_resolver] = self.fory.ref_resolver
- context["_serializers"] = self._serializers
- current_class_field_names = set(self._get_field_names(self.type_))
-
- stmts = [
- f'"""read method for {self.type_}"""',
- ]
- if not self.fory.strict:
- context["checker"] = self.fory.policy
- stmts.append(f"checker.authorize_instantiation({obj_class})")
-
- # Read hash only in non-compatible mode; in compatible mode, read field count
- if not self.fory.compatible:
- stmts.extend(
- [
- f"read_hash = {buffer}.read_int32()",
- f"if read_hash != {self._hash}:",
- f""" raise TypeNotCompatibleError(
- f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""",
- ]
- )
+ serializer.write(buffer, field_value)
+ return
+ if is_tracking_ref:
+ self.fory.write_ref(buffer, field_value, serializer=None if is_dynamic else serializer)
+ return
+ if is_nullable:
+ if field_value is None:
+ buffer.write_int8(NULL_FLAG)
+ return
+ buffer.write_int8(NOT_NULL_VALUE_FLAG)
+ if is_dynamic:
+ self.fory.write_no_ref(buffer, field_value)
else:
- stmts.append(f"num_fields_written = {buffer}.read_var_uint32()")
-
- stmts.extend(
- [
- f"{obj} = {obj_class}.__new__({obj_class})",
- f"{ref_resolver}.reference({obj})",
- ]
- )
-
- if not self._has_slots:
- stmts.append(f"{obj_dict} = {obj}.__dict__")
-
- # Read field values in order
- for index, field_name in enumerate(self._field_names):
- serializer_var = f"serializer{index}"
- serializer = self._serializers[index]
- context[serializer_var] = serializer
- field_value = f"field_value{index}"
- is_nullable = self._nullable_fields.get(field_name, False)
-
- # Build field reading statements
- field_stmts = []
-
- if is_nullable:
- # Use gen_read_nullable_basic_stmts for nullable basic types
- if isinstance(serializer, BooleanSerializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, bool, lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, ByteSerializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, "int8", lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, Int16Serializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, "int16", lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, Int32Serializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, "int32", lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, Int64Serializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, "int64", lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, Float32Serializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, "float32", lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, Float64Serializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, "float64", lambda v: f"{field_value} = {v}"))
- elif isinstance(serializer, StringSerializer):
- field_stmts.extend(gen_read_nullable_basic_stmts(buffer, str, lambda v: f"{field_value} = {v}"))
- else:
- # For complex types, use read_ref_pyobject
- field_stmts.append(f"{field_value} = {fory}.read_ref_pyobject({buffer})")
- else:
- stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value)
- if stmt is None:
- stmt = f"{field_value} = {fory}.read_ref_pyobject({buffer})"
- field_stmts.append(stmt)
-
- # Set field value if it exists in current class
- if field_name not in current_class_field_names:
- field_stmts.append(f"# {field_name} is not in {self.type_}")
- else:
- if not self._has_slots:
- field_stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
- else:
- field_stmts.append(f"{obj}.{field_name} = {field_value}")
-
- # In compatible mode, wrap field reading in a check
- if self.fory.compatible:
- stmts.append(f"if {index} < num_fields_written:")
- # Indent all field statements
- from pyfory.codegen import ident_lines
-
- field_stmts = ident_lines(field_stmts)
- stmts.extend(field_stmts)
- else:
- stmts.extend(field_stmts)
-
- stmts.append(f"return {obj}")
- self._read_method_code, func = compile_function(
- f"read_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
- [buffer],
- stmts,
- context,
- )
- return func
-
- def _gen_generated_write_method(self):
- """Generate JIT-compiled write method.
-
- Per xlang spec, struct format is:
- - Schema consistent mode: |4-byte hash|field values|
- Schema evolution mode (compatible): |field values| (no field count prefix!)
- The field count is in TypeDef meta written at the end, not in object data.
- """
- context = {}
- counter = itertools.count(0)
- buffer, fory, value, value_dict = "buffer", "fory", "value", "value_dict"
- context[fory] = self.fory
- context["_serializers"] = self._serializers
- stmts = [
- f'"""write method for {self.type_}"""',
- ]
- if not self.fory.compatible:
- stmts.append(f"{buffer}.write_int32({self._hash})")
- if not self._has_slots:
- stmts.append(f"{value_dict} = {value}.__dict__")
- for index, field_name in enumerate(self._field_names):
- field_value = f"field_value{next(counter)}"
- serializer_var = f"serializer{index}"
- serializer = self._serializers[index]
- context[serializer_var] = serializer
- is_nullable = self._nullable_fields.get(field_name, False)
- # For schema evolution: use safe access with None default to handle
- # cases where the field might not exist on the object (missing from remote schema)
- # In compatible mode, always use safe access even for non-nullable fields
- if not self._has_slots:
- if is_nullable or self.fory.compatible:
- stmts.append(f"{field_value} = {value_dict}.get('{field_name}')")
- else:
- stmts.append(f"{field_value} = {value_dict}['{field_name}']")
- else:
- if is_nullable or self.fory.compatible:
- stmts.append(f"{field_value} = getattr({value}, '{field_name}', None)")
- else:
- stmts.append(f"{field_value} = {value}.{field_name}")
- is_dynamic = self._dynamic_fields.get(field_name, False)
- if is_nullable:
- if isinstance(serializer, StringSerializer):
- stmts.extend(
- [
- f"if {field_value} is None:",
- f" {buffer}.write_int8({NULL_FLAG})",
- "else:",
- f" {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
- f" {buffer}.write_string({field_value})",
- ]
- )
- else:
- # dynamic=True: don't pass serializer, write actual type info
- # dynamic=False: pass serializer, use declared type
- serializer_arg = "None" if is_dynamic else serializer_var
- stmts.append(f"{fory}.write_ref({buffer}, {field_value}, serializer={serializer_arg})")
- else:
- stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value)
- if stmt is None:
- # For non-nullable complex types, use write_no_ref
- # dynamic=True: don't pass serializer, write actual type info
- if is_dynamic:
- stmt = f"{fory}.write_no_ref({buffer}, {field_value})"
- else:
- stmt = f"{fory}.write_no_ref({buffer}, {field_value}, serializer={serializer_var})"
- # In compatible mode, handle None for non-nullable fields (schema evolution)
- # Write zero/default value when field is None due to missing from remote schema
- if self.fory.compatible:
- from pyfory.serializer import EnumSerializer
-
- if isinstance(serializer, EnumSerializer):
- # For enums, write ordinal 0 when None
- stmts.extend(
- [
- f"if {field_value} is None:",
- f" {buffer}.write_var_uint32(0)",
- "else:",
- f" {stmt}",
- ]
- )
- else:
- stmts.append(stmt)
- else:
- stmts.append(stmt)
- self._generated_write_method_code, func = compile_function(
- f"write_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
- [buffer, value],
- stmts,
- context,
- )
- return func
-
- def _gen_generated_read_method(self):
- """Generate JIT-compiled read method.
-
- Per xlang spec, struct format is:
- - Schema consistent mode: |4-byte hash|field values|
- Schema evolution mode (compatible): |field values| (no field count prefix!)
- The field count is in TypeDef meta written at the end, not in object data.
- """
- context = dict(_jit_context)
- buffer, fory, obj_class, obj, obj_dict = (
- "buffer",
- "fory",
- "obj_class",
- "obj",
- "obj_dict",
- )
- ref_resolver = "ref_resolver"
- context[fory] = self.fory
- context[obj_class] = self.type_
- context[ref_resolver] = self.fory.ref_resolver
- context["_serializers"] = self._serializers
- current_class_field_names = set(self._get_field_names(self.type_))
- stmts = [
- f'"""read method for {self.type_}"""',
- ]
- if not self.fory.strict:
- context["checker"] = self.fory.policy
- stmts.append(f"checker.authorize_instantiation({obj_class})")
- if not self.fory.compatible:
- stmts.extend(
- [
- f"read_hash = {buffer}.read_int32()",
- f"if read_hash != {self._hash}:",
- f""" raise TypeNotCompatibleError(
- f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""",
- ]
- )
- stmts.extend(
- [
- f"{obj} = {obj_class}.__new__({obj_class})",
- f"{ref_resolver}.reference({obj})",
- ]
- )
-
- if not self._has_slots:
- stmts.append(f"{obj_dict} = {obj}.__dict__")
-
- for index, field_name in enumerate(self._field_names):
- serializer_var = f"serializer{index}"
- serializer = self._serializers[index]
- context[serializer_var] = serializer
- field_value = f"field_value{index}"
- is_nullable = self._nullable_fields.get(field_name, False)
-
- is_dynamic = self._dynamic_fields.get(field_name, False)
- if is_nullable:
- if isinstance(serializer, StringSerializer):
- stmts.extend(
- [
- f"if {buffer}.read_int8() >= {NOT_NULL_VALUE_FLAG}:",
- f" {field_value} = {buffer}.read_string()",
- "else:",
- f" {field_value} = None",
- ]
- )
- else:
- # dynamic=True: don't pass serializer, read type info from buffer
- # dynamic=False: pass serializer, use declared type
- serializer_arg = "None" if is_dynamic else serializer_var
- stmts.append(f"{field_value} = {fory}.read_ref({buffer}, serializer={serializer_arg})")
- else:
- stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value)
- if stmt is None:
- # For non-nullable complex types, use read_no_ref
- # dynamic=True: don't pass serializer, read type info from buffer
- if is_dynamic:
- stmt = f"{field_value} = {fory}.read_no_ref({buffer})"
- else:
- stmt = f"{field_value} = {fory}.read_no_ref({buffer}, serializer={serializer_var})"
- stmts.append(stmt)
-
- if field_name not in current_class_field_names:
- stmts.append(f"# {field_name} is not in {self.type_}")
- elif not self._has_slots:
- stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
- else:
- stmts.append(f"{obj}.{field_name} = {field_value}")
-
- # For schema evolution: initialize missing fields with default values
- # This handles cases where the sender's schema has fewer fields than the receiver's
- if self.fory.compatible:
- read_field_names = set(self._field_names)
- missing_fields = current_class_field_names - read_field_names
- if missing_fields and dataclasses.is_dataclass(self.type_):
- for dc_field in dataclasses.fields(self.type_):
- if dc_field.name in missing_fields:
- if dc_field.default is not dataclasses.MISSING:
- default_val = repr(dc_field.default)
- if not self._has_slots:
- stmts.append(f"{obj_dict}['{dc_field.name}'] = {default_val}")
- else:
- stmts.append(f"{obj}.{dc_field.name} = {default_val}")
- elif dc_field.default_factory is not dataclasses.MISSING:
- factory_var = f"_default_factory_{dc_field.name}"
- context[factory_var] = dc_field.default_factory
- if not self._has_slots:
- stmts.append(f"{obj_dict}['{dc_field.name}'] = {factory_var}()")
- else:
- stmts.append(f"{obj}.{dc_field.name} = {factory_var}()")
- # else: field has no default, leave it unset
-
- stmts.append(f"return {obj}")
- self._generated_read_method_code, func = compile_function(
- f"read_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
- [buffer],
- stmts,
- context,
- )
- return func
+ self.fory.write_no_ref(buffer, field_value, serializer=serializer)
+
+ def _read_field_value(self, buffer, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref):
+ if is_nullable and is_basic:
+ if buffer.read_int8() == NULL_FLAG:
+ return None
+ return serializer.read(buffer)
+ if is_basic:
+ return serializer.read(buffer)
+ if is_tracking_ref:
+ return self.fory.read_ref(buffer, serializer=None if is_dynamic else serializer)
+ if is_nullable and buffer.read_int8() == NULL_FLAG:
+ return None
+ if is_dynamic:
+ return self.fory.read_no_ref(buffer)
+ return self.fory.read_no_ref(buffer, serializer=serializer)
def write(self, buffer: Buffer, value):
- """Write dataclass instance to buffer.
-
- Struct format:
- - Schema consistent mode: |4-byte hash|field values|
- Schema evolution mode (compatible): |field values| (no field count prefix!)
- The field count is in TypeDef meta written at the end, not in object data.
- """
if not self.fory.compatible:
buffer.write_int32(self._hash)
- for index, field_name in enumerate(self._field_names):
- field_value = getattr(value, field_name)
- serializer = self._serializers[index]
- is_nullable = self._nullable_fields.get(field_name, False)
- is_dynamic = self._dynamic_fields.get(field_name, False)
- if _ENABLE_FORY_DEBUG_OUTPUT:
- print(
- f"write field '{field_name}': {field_value!r}, writer_index={buffer.get_writer_index()}, "
- f"nullable={is_nullable}, dynamic={is_dynamic}, serializer={serializer}"
- )
- if is_nullable:
- if field_value is None:
- buffer.write_int8(-3)
- else:
- # dynamic=True: don't pass serializer, write actual type info
- # dynamic=False: pass serializer, use declared type
- self.fory.write_ref(buffer, field_value, serializer=None if is_dynamic else serializer)
+ value_dict = value.__dict__ if not self._has_slots else None
+ if value_dict is not None:
+ if self.fory.compatible:
+ for index, field_name in enumerate(self._field_names):
+ interned_name = self._field_name_interned[field_name]
+ field_value = value_dict.get(interned_name)
+ serializer = self._serializers[index]
+ is_nullable = self._nullable_fields.get(field_name, False)
+ is_dynamic = self._dynamic_fields.get(field_name, False)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ self._write_field_value(buffer, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref)
else:
- if is_dynamic:
- self.fory.write_no_ref(buffer, field_value)
- else:
- self.fory.write_no_ref(buffer, field_value, serializer=serializer)
+ for index, field_name in enumerate(self._field_names):
+ interned_name = self._field_name_interned[field_name]
+ field_value = value_dict[interned_name]
+ serializer = self._serializers[index]
+ is_nullable = self._nullable_fields.get(field_name, False)
+ is_dynamic = self._dynamic_fields.get(field_name, False)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ self._write_field_value(buffer, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ else:
+ if self.fory.compatible:
+ for index, field_name in enumerate(self._field_names):
+ interned_name = self._field_name_interned[field_name]
+ field_value = getattr(value, interned_name, None)
+ serializer = self._serializers[index]
+ is_nullable = self._nullable_fields.get(field_name, False)
+ is_dynamic = self._dynamic_fields.get(field_name, False)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ self._write_field_value(buffer, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ else:
+ for index, field_name in enumerate(self._field_names):
+ interned_name = self._field_name_interned[field_name]
+ field_value = getattr(value, interned_name)
+ serializer = self._serializers[index]
+ is_nullable = self._nullable_fields.get(field_name, False)
+ is_dynamic = self._dynamic_fields.get(field_name, False)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ self._write_field_value(buffer, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref)
def read(self, buffer):
- """Read dataclass instance from buffer.
-
- Struct format:
- - Schema consistent mode: |4-byte hash|field values|
- Schema evolution mode (compatible): |field values| (no field count prefix!)
- The field count is in TypeDef meta written at the end, not in object data.
- """
+ if not self.fory.strict:
+ self.fory.policy.authorize_instantiation(self.type_)
if not self.fory.compatible:
hash_ = buffer.read_int32()
if hash_ != self._hash:
@@ -973,46 +548,29 @@ class DataClassSerializer(Serializer):
)
obj = self.type_.__new__(self.type_)
self.fory.ref_resolver.reference(obj)
- current_class_field_names = set(self._get_field_names(self.type_))
- read_field_names = set()
+ obj_dict = obj.__dict__ if not self._has_slots else None
for index, field_name in enumerate(self._field_names):
serializer = self._serializers[index]
is_nullable = self._nullable_fields.get(field_name, False)
is_dynamic = self._dynamic_fields.get(field_name, False)
- if _ENABLE_FORY_DEBUG_OUTPUT:
- print(
- f"read field '{field_name}': reader_index={buffer.get_reader_index()}, "
- f"nullable={is_nullable}, dynamic={is_dynamic}, serializer={serializer}"
- )
- if is_nullable:
- ref_id = buffer.read_int8()
- if ref_id == -3:
- field_value = None
- else:
- buffer.set_reader_index(buffer.get_reader_index() - 1)
- # dynamic=True: don't pass serializer, read type info from buffer
- # dynamic=False: pass serializer, use declared type
- field_value = self.fory.read_ref(buffer, serializer=None if is_dynamic else serializer)
+ is_tracking_ref = self._ref_fields.get(field_name, False)
+ is_basic = self._basic_field_flags[index]
+ field_value = self._read_field_value(buffer, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref)
+ if field_name not in self._current_class_field_names:
+ continue
+ interned_name = self._field_name_interned[field_name]
+ if obj_dict is not None:
+ obj_dict[interned_name] = field_value
else:
- if is_dynamic:
- field_value = self.fory.read_no_ref(buffer)
+ setattr(obj, interned_name, field_value)
+
+ if self._missing_field_defaults:
+ for field_name, default_factory in self._missing_field_defaults:
+ value = default_factory()
+ if obj_dict is not None:
+ obj_dict[field_name] = value
else:
- field_value = self.fory.read_no_ref(buffer, serializer=serializer)
- if field_name in current_class_field_names:
- setattr(obj, field_name, field_value)
- read_field_names.add(field_name)
- # For schema evolution: initialize missing fields with default values
- # This handles cases where the sender's schema has fewer fields than the receiver's
- if self.fory.compatible:
- missing_fields = current_class_field_names - read_field_names
- if missing_fields and dataclasses.is_dataclass(self.type_):
- for dc_field in dataclasses.fields(self.type_):
- if dc_field.name in missing_fields:
- if dc_field.default is not dataclasses.MISSING:
- setattr(obj, dc_field.name, dc_field.default)
- elif dc_field.default_factory is not dataclasses.MISSING:
- setattr(obj, dc_field.name, dc_field.default_factory())
- # else: field has no default, leave it unset (will be None for nullable)
+ setattr(obj, field_name, value)
return obj
@@ -1032,6 +590,21 @@ class DataClassStubSerializer(DataClassSerializer):
return typeinfo.serializer
+PythonDataClassSerializer = DataClassSerializer
+try:
+ from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION
+except ImportError:
+ ENABLE_FORY_CYTHON_SERIALIZATION = False
+
+if ENABLE_FORY_CYTHON_SERIALIZATION:
+ try:
+ from pyfory.serialization import DataClassSerializer as _CythonDataClassSerializer
+
+ DataClassSerializer = _CythonDataClassSerializer
+ except ImportError:
+ DataClassSerializer = PythonDataClassSerializer
+
+
basic_types = {
bool,
# Signed integers
diff --git a/python/pyfory/tests/test_meta_share.py b/python/pyfory/tests/test_meta_share.py
index 7e2c9630c..67f68d721 100644
--- a/python/pyfory/tests/test_meta_share.py
+++ b/python/pyfory/tests/test_meta_share.py
@@ -199,7 +199,7 @@ class TestMetaShareMode:
assert deserialized.name == obj.name
assert deserialized.age == obj.age
assert deserialized.active == obj.active
- assert not hasattr(deserialized, "email")
+ assert deserialized.email == ""
def test_schema_evolution_fewer_fields(self):
# Serialize with original schema
diff --git a/python/pyfory/tests/test_ref_tracking.py b/python/pyfory/tests/test_ref_tracking.py
index beecc9b2f..66e695034 100644
--- a/python/pyfory/tests/test_ref_tracking.py
+++ b/python/pyfory/tests/test_ref_tracking.py
@@ -16,11 +16,7 @@
# under the License.
from dataclasses import dataclass
-import os
-import subprocess
-import sys
-import textwrap
-from typing import Any
+from typing import Any, List
import pytest
@@ -54,6 +50,40 @@ class RefNode:
self_ref: Any = pyfory.field(default=None, ref=True, nullable=True)
+@dataclass
+class RefOverrideDisabled:
+ left: Any = pyfory.field(default=None, ref=False, nullable=True)
+ right: Any = pyfory.field(default=None, ref=False, nullable=True)
+
+
+@dataclass
+class RefOverrideEnabled:
+ left: Any = pyfory.field(default=None, ref=True, nullable=True)
+ right: Any = pyfory.field(default=None, ref=True, nullable=True)
+
+
+@dataclass
+class FixedUint64Pair:
+ a: pyfory.fixed_uint64 = None
+ b: pyfory.fixed_uint64 = None
+
+
+@dataclass
+class Holder:
+ values: List[pyfory.int64]
+
+
+class EvilIndex:
+ def __init__(self):
+ self.owner = None
+
+ def __index__(self):
+ # Reallocate list storage and inject invalid element types.
+ self.owner.clear()
+ self.owner.extend([bytearray(16)] * 1024)
+ return 7
+
+
@pytest.mark.parametrize("xlang", [False, True])
def test_collection_list_mixed_type_shared_reference(xlang):
fory = pyfory.Fory(xlang=xlang, ref=True, strict=False)
@@ -137,6 +167,29 @@ def test_struct_shared_fields_and_cross_container_alias_python_mode():
assert restored.mapping["alias"] is restored.left
+@pytest.mark.parametrize("xlang", [False, True])
+def test_struct_field_ref_override_controls_alias_preservation(xlang):
+ fory = pyfory.Fory(xlang=xlang, ref=True, strict=False)
+ if xlang:
+ fory.register_type(RefOverrideDisabled, typename="example.RefOverrideDisabled")
+ fory.register_type(RefOverrideEnabled, typename="example.RefOverrideEnabled")
+ else:
+ fory.register(RefOverrideDisabled)
+ fory.register(RefOverrideEnabled)
+
+ shared = {"v": [1, 2, 3]}
+
+ disabled = _roundtrip(fory, RefOverrideDisabled(shared, shared))
+ assert disabled.left == shared
+ assert disabled.right == shared
+ assert disabled.left is not disabled.right
+
+ enabled = _roundtrip(fory, RefOverrideEnabled(shared, shared))
+ assert enabled.left == shared
+ assert enabled.right == shared
+ assert enabled.left is enabled.right
+
+
def test_struct_self_cycle_and_nested_alias_python_mode():
fory = pyfory.Fory(xlang=False, ref=True, strict=False)
fory.register(RefNode)
@@ -211,51 +264,28 @@ def test_invalid_collection_element_ref_id_raises_value_error():
fory.deserialize(payload)
-def test_primitive_list_fastpath_mutation_no_crash_subprocess():
- py_root = os.path.abspath(os.path.join(os.path.dirname(pyfory.__file__), ".."))
- env = os.environ.copy()
- env["ENABLE_FORY_CYTHON_SERIALIZATION"] = "1"
- env["PYTHONPATH"] = py_root
- code = textwrap.dedent(
- """
- from dataclasses import dataclass
- from typing import List
- import pyfory
-
- @dataclass
- class Holder:
- values: List[pyfory.int64]
-
- class EvilIndex:
- def __init__(self):
- self.owner = None
- def __index__(self):
- # Reallocate list storage and inject invalid element types.
- self.owner.clear()
- self.owner.extend([bytearray(16)] * 1024)
- return 7
-
- fory = pyfory.Fory(xlang=False, ref=True, strict=False)
- fory.register(Holder)
- for _ in range(10):
- lst = [EvilIndex() for _ in range(64)]
- for e in lst:
- e.owner = lst
- try:
- fory.serialize(Holder(values=lst))
- except TypeError:
- continue
- print("UNEXPECTED:SUCCESS")
- raise SystemExit(2)
- print("OK:TYPEERROR")
- """
- )
- proc = subprocess.run(
- [sys.executable, "-c", code],
- env=env,
- capture_output=True,
- text=True,
- check=False,
- )
- assert proc.returncode == 0, f"subprocess failed rc={proc.returncode}, stderr={proc.stderr}"
- assert "OK:TYPEERROR" in proc.stdout, proc.stdout
+@pytest.mark.parametrize("xlang", [False, True])
+def test_optional_fixed_uint64_roundtrip(xlang):
+ value = 1234567890123456789
+ fory = pyfory.Fory(xlang=xlang, ref=True, strict=False)
+ if xlang:
+ fory.register_type(FixedUint64Pair, typename="example.FixedUint64Pair")
+ else:
+ fory.register(FixedUint64Pair)
+
+ serializer = fory.type_resolver.get_serializer(pyfory.fixed_uint64)
+ assert serializer.need_to_write_ref is False
+ restored = _roundtrip(fory, FixedUint64Pair(value, value))
+ assert restored.a == value
+ assert restored.b == value
+
+
+def test_primitive_list_fastpath_mutation_typeerror():
+ fory = pyfory.Fory(xlang=False, ref=True, strict=False)
+ fory.register(Holder)
+ for _ in range(10):
+ lst = [EvilIndex() for _ in range(64)]
+ for element in lst:
+ element.owner = lst
+ with pytest.raises(TypeError):
+ fory.serialize(Holder(values=lst))
diff --git a/python/pyfory/tests/test_struct.py b/python/pyfory/tests/test_struct.py
index 08df9f5a9..bfd06d22a 100644
--- a/python/pyfory/tests/test_struct.py
+++ b/python/pyfory/tests/test_struct.py
@@ -15,18 +15,19 @@
# specific language governing permissions and limitations
# under the License.
+import dataclasses
from dataclasses import dataclass
import datetime
+import enum
from typing import Dict, Any, List, Set, Optional
-import os
import pytest
import typing
import pyfory
from pyfory import Fory
from pyfory.error import TypeUnregisteredError
-from pyfory.struct import DataClassSerializer
+from pyfory.struct import DataClassSerializer, build_default_values_factory
from pyfory.types import TypeId
@@ -139,6 +140,11 @@ class DataClassObject:
)
+@dataclass
+class BoolCoercionObject:
+ b: bool
+
+
def test_sort_fields():
@dataclass
class TestClass:
@@ -174,6 +180,57 @@ def test_sort_fields():
assert serializer._field_names == ["f13", "f5", "f11", "f7", "f12", "f1", "f4", "f15", "f6", "f10", "f2", "f14", "f3", "f9", "f8"]
+@pytest.mark.parametrize(
+ "value, expected",
+ [
+ (1, True),
+ (0, False),
+ ],
+)
+def test_bool_field_coercion(value, expected):
+ fory = Fory(xlang=False, ref=True, strict=False)
+ result = ser_de(fory, BoolCoercionObject(value))
+ assert result.b is expected
+
+
+def test_bool_field_coercion_numpy_bool():
+ np = pytest.importorskip("numpy")
+ fory = Fory(xlang=False, ref=True, strict=False)
+
+ result_true = ser_de(fory, BoolCoercionObject(np.bool_(True)))
+ assert result_true.b is True
+
+ result_false = ser_de(fory, BoolCoercionObject(np.bool_(False)))
+ assert result_false.b is False
+
+
+@pytest.mark.parametrize(
+ "numeric_type",
+ [
+ pyfory.int8,
+ pyfory.int16,
+ pyfory.int32,
+ pyfory.fixed_int32,
+ pyfory.int64,
+ pyfory.fixed_int64,
+ pyfory.tagged_int64,
+ pyfory.uint8,
+ pyfory.uint16,
+ pyfory.uint32,
+ pyfory.fixed_uint32,
+ pyfory.uint64,
+ pyfory.fixed_uint64,
+ pyfory.tagged_uint64,
+ pyfory.float32,
+ pyfory.float64,
+ ],
+)
+def test_numeric_serializer_need_to_write_ref_disabled(numeric_type):
+ fory = Fory(xlang=False, ref=True, strict=False)
+ serializer = fory.type_resolver.get_serializer(numeric_type)
+ assert serializer.need_to_write_ref is False
+
+
def test_data_class_serializer_xlang():
fory = Fory(xlang=True, ref=True)
fory.register_type(ComplexObject, typename="example.ComplexObject")
@@ -250,8 +307,8 @@ def test_struct_evolving_override():
assert fixed_info.type_id == TypeId.NAMED_STRUCT
-def test_data_class_serializer_xlang_codegen():
- """Test that DataClassSerializer generates write/read methods correctly in
xlang mode."""
+def test_data_class_serializer_xlang_serializer():
+ """Test DataClassSerializer round-trip behavior in xlang mode."""
fory = Fory(xlang=True, ref=True)
# Register types first
@@ -263,12 +320,6 @@ def test_data_class_serializer_xlang_codegen():
# Get the serializer that was created during registration
serializer = fory.type_resolver.get_serializer(DataClassObject)
- # Check that the generated methods exist
- assert hasattr(serializer, "_generated_write_method"), "Generated write method should exist"
- assert hasattr(serializer, "_generated_read_method"), "Generated read method should exist"
- assert hasattr(serializer, "_generated_write_method_code"), "Generated write method code should exist"
- assert hasattr(serializer, "_generated_read_method_code"), "Generated read method code should exist"
-
# Serializer API is unified: no mode-specific serializer attribute.
assert not hasattr(serializer, "_xlang")
assert hasattr(serializer, "_serializers")
@@ -287,7 +338,6 @@ def test_data_class_serializer_xlang_codegen():
)
# Test serialization and deserialization using the normal fory flow
- # This will use the generated methods internally
binary = fory.serialize(test_obj)
deserialized_obj = fory.deserialize(binary)
@@ -302,108 +352,6 @@ def test_data_class_serializer_xlang_codegen():
assert deserialized_obj.f_complex == test_obj.f_complex
-def test_data_class_serializer_xlang_codegen_with_jit():
- """Test that DataClassSerializer JIT compilation works correctly when
enabled."""
- # Save the original environment variable
- original_jit_setting = os.environ.get("ENABLE_FORY_PYTHON_JIT")
-
- try:
- # Enable JIT
- os.environ["ENABLE_FORY_PYTHON_JIT"] = "True"
-
- # Import after setting environment variable to ensure it takes effect
- import importlib
- import pyfory.serializer
-
- importlib.reload(pyfory.serializer)
-
- fory = Fory(xlang=True, ref=True)
-
- # Register types first
- fory.register_type(ComplexObject, typename="example.ComplexObject")
- fory.register_type(DataClassObject, typename="example.TestDataClassObject")
-
- # Get the serializer that was created during registration
- serializer = fory.type_resolver.get_serializer(DataClassObject)
-
- # Check that JIT methods are assigned when JIT is enabled
- # The methods should be the generated functions, not the original instance methods
- assert callable(serializer.write)
- assert callable(serializer.read)
-
- # Test that the JIT-compiled methods work through normal serialization
- test_obj = DataClassObject(
- f_int=123,
- f_float=45.67,
- f_str="jit_test",
- f_bool=False,
- f_list=[10, 20, 30],
- f_dict={"jit": 2.5},
- f_any={"nested": "data"},
- f_complex=None,
- )
-
- # Use normal serialization flow which will use the JIT-compiled methods internally
- binary = fory.serialize(test_obj)
- deserialized_obj = fory.deserialize(binary)
-
- assert deserialized_obj.f_int == test_obj.f_int
- assert deserialized_obj.f_float == test_obj.f_float
- assert deserialized_obj.f_str == test_obj.f_str
- assert deserialized_obj.f_bool == test_obj.f_bool
- assert deserialized_obj.f_list == test_obj.f_list
- assert deserialized_obj.f_dict == test_obj.f_dict
- assert deserialized_obj.f_any == test_obj.f_any
- assert deserialized_obj.f_complex == test_obj.f_complex
-
- finally:
- # Restore original environment variable
- if original_jit_setting is None:
- os.environ.pop("ENABLE_FORY_PYTHON_JIT", None)
- else:
- os.environ["ENABLE_FORY_PYTHON_JIT"] = original_jit_setting
-
- # Reload to restore the original state
- importlib.reload(pyfory.serializer)
-
-
-def test_data_class_serializer_xlang_codegen_generated_code():
- """Test that the generated code contains expected elements."""
- fory = Fory(xlang=True, ref=True)
-
- # Register types first
- fory.register_type(ComplexObject, typename="example.ComplexObject")
- fory.register_type(DataClassObject, typename="example.TestDataClassObject")
-
- # trigger lazy serializer replace
- fory.serialize(DataClassObject.create())
- # Get the serializer that was created during registration
- serializer = fory.type_resolver.get_serializer(DataClassObject)
-
- # Check that generated code exists and contains expected elements
- write_code = serializer._generated_write_method_code
- read_code = serializer._generated_read_method_code
-
- assert isinstance(write_code, str)
- assert isinstance(read_code, str)
-
- # Check that write code contains expected elements
- assert "def write_" in write_code
- assert "buffer.write_int32" in write_code # Hash writing
- assert "fory.write_ref" in write_code # Field serialization
-
- # Check that read code contains expected elements
- assert "def read_" in read_code
- assert "buffer.read_int32" in read_code # Hash reading
- assert "fory.read_ref" in read_code # Field deserialization
- assert "TypeNotCompatibleError" in read_code # Hash validation
-
- # Check that field names are referenced in the code
- for field_name in serializer._field_names:
- # Field names should appear in the generated code
- assert field_name in write_code or field_name in read_code
-
-
def test_data_class_serializer_xlang_vs_non_xlang():
"""Test that xlang and non-xlang modes use the same dataclass serializer
behavior."""
fory_xlang = Fory(xlang=True, ref=True)
@@ -421,10 +369,6 @@ def test_data_class_serializer_xlang_vs_non_xlang():
assert not hasattr(serializer_xlang, "_xlang")
assert not hasattr(serializer_python, "_xlang")
- assert hasattr(serializer_xlang, "_generated_write_method")
- assert hasattr(serializer_xlang, "_generated_read_method")
- assert hasattr(serializer_python, "_generated_write_method")
- assert hasattr(serializer_python, "_generated_read_method")
# Unified serializer metadata should be mode-independent.
assert serializer_xlang._field_names == serializer_python._field_names
@@ -433,6 +377,73 @@ def test_data_class_serializer_xlang_vs_non_xlang():
assert serializer_xlang._hash == serializer_python._hash
+class MissingDefaultEnum(enum.Enum):
+ A = 1
+ B = 2
+
+
+@dataclass
+class MissingDefaultFactoryFields:
+ required: int
+ required_float: float
+ required_str: str
+ required_bytes: bytes
+ required_list: List[int]
+ required_set: Set[int]
+ required_dict: Dict[str, int]
+ plain_default: int = 7
+ list_default: List[int] = dataclasses.field(default_factory=list)
+ enum_default_none: MissingDefaultEnum = None
+
+
+def test_build_default_values_factory():
+ fory = Fory(xlang=False, ref=True, strict=False)
+ type_hints = typing.get_type_hints(MissingDefaultFactoryFields)
+ default_factories = build_default_values_factory(
+ fory,
+ type_hints,
+ dataclasses.fields(MissingDefaultFactoryFields),
+ )
+
+ assert callable(default_factories["required"])
+ assert callable(default_factories["required_float"])
+ assert callable(default_factories["required_str"])
+ assert callable(default_factories["required_bytes"])
+ assert callable(default_factories["required_list"])
+ assert callable(default_factories["required_set"])
+ assert callable(default_factories["required_dict"])
+ assert callable(default_factories["plain_default"])
+ assert callable(default_factories["list_default"])
+ assert callable(default_factories["enum_default_none"])
+
+ assert default_factories["required"]() == 0
+ assert default_factories["required_float"]() == 0.0
+ assert default_factories["required_str"]() == ""
+ assert default_factories["required_bytes"]() == b""
+ list_required_one = default_factories["required_list"]()
+ list_required_two = default_factories["required_list"]()
+ assert list_required_one == []
+ assert list_required_two == []
+ assert list_required_one is not list_required_two
+ set_required_one = default_factories["required_set"]()
+ set_required_two = default_factories["required_set"]()
+ assert set_required_one == set()
+ assert set_required_two == set()
+ assert set_required_one is not set_required_two
+ dict_required_one = default_factories["required_dict"]()
+ dict_required_two = default_factories["required_dict"]()
+ assert dict_required_one == {}
+ assert dict_required_two == {}
+ assert dict_required_one is not dict_required_two
+ assert default_factories["plain_default"]() == 7
+ assert default_factories["enum_default_none"]() is MissingDefaultEnum.A
+ list_one = default_factories["list_default"]()
+ list_two = default_factories["list_default"]()
+ assert list_one == []
+ assert list_two == []
+ assert list_one is not list_two
+
+
@dataclass
class OptionalFieldsObject:
f1: Optional[int] = None
@@ -546,6 +557,34 @@ class CompatibleV3:
f2: str = ""
+@dataclass
+class CompatibleRequiredFieldV1:
+ f1: int
+
+
+@dataclass
+class CompatibleRequiredFieldV2:
+ f1: int
+ f2: int
+
+
+@dataclass
+class CompatibleRequiredDefaultsV1:
+ f1: int
+
+
+@dataclass
+class CompatibleRequiredDefaultsV2:
+ f1: int
+ f_int: int
+ f_float: float
+ f_str: str
+ f_bytes: bytes
+ f_list: List[int]
+ f_set: Set[int]
+ f_dict: Dict[str, int]
+
+
@pytest.mark.parametrize("xlang", [False, True])
def test_compatible_mode_add_field(xlang):
"""Test that adding a field with default value works in compatible mode."""
@@ -614,6 +653,50 @@ def test_compatible_mode_bidirectional(xlang):
assert v1_result.f3 == 2.71
[email protected]("xlang", [False, True])
+def test_compatible_mode_add_required_field_without_default_uses_zero_value(xlang):
+ fory_v1 = Fory(xlang=xlang, ref=True, compatible=True, strict=False)
+ fory_v2 = Fory(xlang=xlang, ref=True, compatible=True, strict=False)
+
+ fory_v1.register_type(CompatibleRequiredFieldV1, typename="example.CompatibleRequiredField")
+ fory_v2.register_type(CompatibleRequiredFieldV2, typename="example.CompatibleRequiredField")
+
+ v1_binary = fory_v1.serialize(CompatibleRequiredFieldV1(f1=321))
+ v2_result = fory_v2.deserialize(v1_binary)
+
+ assert v2_result.f1 == 321
+ assert hasattr(v2_result, "f2")
+ assert v2_result.f2 == 0
+
+ serializer_v2 = fory_v2.type_resolver.get_serializer(CompatibleRequiredFieldV2)
+ assert hasattr(serializer_v2, "_default_values_factory")
+ assert callable(serializer_v2._default_values_factory["f2"])
+ assert serializer_v2._default_values_factory["f2"]() == 0
+ assert ser_de(fory_v2, v2_result) == v2_result
+
+
[email protected]("xlang", [False, True])
+def test_compatible_mode_add_required_fields_use_type_defaults(xlang):
+ fory_v1 = Fory(xlang=xlang, ref=True, compatible=True, strict=False)
+ fory_v2 = Fory(xlang=xlang, ref=True, compatible=True, strict=False)
+
+ fory_v1.register_type(CompatibleRequiredDefaultsV1, typename="example.CompatibleRequiredDefaults")
+ fory_v2.register_type(CompatibleRequiredDefaultsV2, typename="example.CompatibleRequiredDefaults")
+
+ v1_binary = fory_v1.serialize(CompatibleRequiredDefaultsV1(f1=11))
+ v2_result = fory_v2.deserialize(v1_binary)
+
+ assert v2_result.f1 == 11
+ assert v2_result.f_int == 0
+ assert v2_result.f_float == 0.0
+ assert v2_result.f_str == ""
+ assert v2_result.f_bytes == b""
+ assert v2_result.f_list == []
+ assert v2_result.f_set == set()
+ assert v2_result.f_dict == {}
+ assert ser_de(fory_v2, v2_result) == v2_result
+
+
@dataclass
class CompatibleWithOptional:
f1: Optional[int] = None
diff --git a/python/pyfory/tests/xlang_test_main.py b/python/pyfory/tests/xlang_test_main.py
index 24a2a6369..f23810df2 100644
--- a/python/pyfory/tests/xlang_test_main.py
+++ b/python/pyfory/tests/xlang_test_main.py
@@ -824,9 +824,9 @@ def test_enum_schema_evolution_compatible_reverse():
debug_print(f"Deserialized as TwoEnumFieldStruct: {obj}")
assert isinstance(obj, TwoEnumFieldStruct), f"Expected TwoEnumFieldStruct, got {type(obj)}"
assert obj.f1 == TestEnum.VALUE_C, f"Expected f1=VALUE_C, got f1={obj.f1}"
- # f2 should be None (missing field due to schema evolution)
+ # f2 is missing from source schema; non-nullable enum should use zero value.
f2_value = getattr(obj, "f2", None)
- assert f2_value is None, f"Expected f2=None, got f2={f2_value}"
+ assert f2_value == TestEnum.VALUE_A, f"Expected f2=VALUE_A, got f2={f2_value}"
new_bytes = fory.serialize(obj)
with open(data_file, "wb") as f:
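
The zero-value defaulting these tests exercise can be sketched roughly as follows. This is a hypothetical, standalone illustration (the `zero_value` and `fill_missing` helpers are invented for this sketch and are not the actual pyfory API): when a peer's schema lacks a required field, the reader fills it with the field type's zero value instead of `None`.

```python
import dataclasses
from typing import Dict, List, Set, get_type_hints

# Zero values for common scalar field types.
ZERO_FACTORIES = {int: int, float: float, str: str, bytes: bytes}


def zero_value(tp):
    # Collections get a fresh empty instance per call, so instances
    # deserialized from old payloads never share mutable state.
    origin = getattr(tp, "__origin__", tp)
    if origin is list:
        return []
    if origin is set:
        return set()
    if origin is dict:
        return {}
    factory = ZERO_FACTORIES.get(tp)
    return factory() if factory else None


def fill_missing(cls, data: dict):
    # Fields absent from the incoming (older-schema) payload fall back
    # to their type's zero value.
    hints = get_type_hints(cls)
    kwargs = {
        f.name: data.get(f.name, zero_value(hints[f.name]))
        for f in dataclasses.fields(cls)
    }
    return cls(**kwargs)


@dataclasses.dataclass
class V2:
    f1: int
    f2: int
    f_list: List[int]


# A v1 payload only carried f1; f2 and f_list get zero values.
obj = fill_missing(V2, {"f1": 321})
```

This mirrors the behavior asserted above, where `CompatibleRequiredFieldV2.f2` deserializes to `0` and collection fields to empty containers rather than `None`.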
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]