This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 314186095 feat(python): meta share mode for pyfory compatible
serialization (#2593)
314186095 is described below
commit 3141860955fed098811a21be0424243ae412d08e
Author: Shawn Yang <[email protected]>
AuthorDate: Wed Sep 10 01:01:00 2025 +0800
feat(python): meta share mode for pyfory compatible serialization (#2593)
## Why?
Implement meta share mode for pyfory, so that a struct can add or delete
fields and the schema used for serialization may differ from the one used
for deserialization.
## What does this PR do?
<!-- Describe the details of this PR. -->
## Related issues
#2509
https://github.com/apache/fory/issues/1938
https://github.com/apache/fory/issues/2160
https://github.com/apache/fory/pull/2278
## Does this PR introduce any user-facing change?
<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fory/issues/new/choose) describing the
need to do so and update the document if necessary.
Delete section if not applicable.
-->
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
<!--
When the PR has an impact on performance (if you don't know whether the
PR will have an impact on performance, you can submit the PR first, and
if it will have impact on performance, the code reviewer will explain
it), be sure to attach a benchmark data here.
Delete section if not applicable.
-->
---
ci/format.sh | 10 +-
python/pyfory/_fory.py | 77 ++++----
python/pyfory/_registry.py | 130 +++++++++++-
python/pyfory/_serialization.pyx | 242 +++++++++++++++++++----
python/pyfory/_serializer.py | 50 ++---
python/pyfory/_struct.py | 3 +-
python/pyfory/format/__init__.py | 3 +-
python/pyfory/format/tests/test_encoder.py | 8 +-
python/pyfory/meta/typedef.py | 43 +++-
python/pyfory/meta/typedef_decoder.py | 41 ++--
python/pyfory/meta/typedef_encoder.py | 30 +--
python/pyfory/serializer.py | 59 ++++--
python/pyfory/tests/benchmark.py | 20 +-
python/pyfory/tests/record.py | 4 +-
python/pyfory/tests/test_buffer.py | 5 +-
python/pyfory/tests/test_codegen.py | 4 +-
python/pyfory/tests/test_meta_share.py | 284 +++++++++++++++++++++++++++
python/pyfory/tests/test_metastring.py | 4 +-
python/pyfory/tests/test_typedef_encoding.py | 5 +-
python/pyfory/type.py | 34 ++--
20 files changed, 830 insertions(+), 226 deletions(-)
diff --git a/ci/format.sh b/ci/format.sh
index b80b26a9d..fed6ab9df 100755
--- a/ci/format.sh
+++ b/ci/format.sh
@@ -125,11 +125,11 @@ format_files() {
format_all_scripts() {
echo "$(date)" "Ruff format...."
- git ls-files -- '*.py' '*.pyx' '*.pxd' '*.pxi' "${GIT_LS_EXCLUDES[@]}" |
xargs -P 10 \
+ git ls-files -- '*.py' "${GIT_LS_EXCLUDES[@]}" | xargs -P 10 \
ruff format
echo "$(date)" "Ruff check...."
- git ls-files -- '*.py' '*.pyx' '*.pxd' '*.pxi' "${GIT_LS_EXCLUDES[@]}" |
xargs \
+ git ls-files -- '*.py' "${GIT_LS_EXCLUDES[@]}" | xargs \
ruff check --fix
}
@@ -193,10 +193,10 @@ format_changed() {
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
- if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" --
'*.py' '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
- git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' '*.pyx'
'*.pxd' '*.pxi' | xargs -P 5 \
+ if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" --
'*.py' &>/dev/null; then
+ git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs
-P 5 \
ruff format
- git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' '*.pyx'
'*.pxd' '*.pxi' | xargs -P 5 \
+ git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs
-P 5 \
ruff check --fix
fi
diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py
index 6bc87ba1a..b05b37ffc 100644
--- a/python/pyfory/_fory.py
+++ b/python/pyfory/_fory.py
@@ -98,6 +98,7 @@ class Fory:
__slots__ = (
"language",
"is_py",
+ "compatible",
"ref_tracking",
"ref_resolver",
"type_resolver",
@@ -113,13 +114,13 @@ class Fory:
"_unsupported_objects",
"_peer_language",
)
- serialization_context: "SerializationContext"
def __init__(
self,
language=Language.PYTHON,
ref_tracking: bool = False,
require_type_registration: bool = True,
+ compatible: bool = False,
):
"""
:param require_type_registration:
@@ -130,10 +131,14 @@ class Fory:
Do not disable type registration if you can't ensure your
environment are
*indeed secure*. We are not responsible for security risks if
you disable this option.
+ :param compatible:
+ Whether to enable compatible mode for cross-language serialization.
+ When enabled, type forward/backward compatibility for struct fields
will be enabled.
"""
self.language = language
self.is_py = language == Language.PYTHON
self.require_type_registration = _ENABLE_TYPE_REGISTRATION_FORCIBLY or
require_type_registration
+ self.compatible = compatible
self.ref_tracking = ref_tracking
if self.ref_tracking:
self.ref_resolver = MapRefResolver()
@@ -143,9 +148,11 @@ class Fory:
from pyfory._registry import TypeResolver
self.metastring_resolver = MetaStringResolver()
- self.type_resolver = TypeResolver(self)
+ self.type_resolver = TypeResolver(self, meta_share=compatible)
self.type_resolver.initialize()
- self.serialization_context = SerializationContext()
+ from pyfory._serialization import SerializationContext
+
+ self.serialization_context =
SerializationContext(scoped_meta_share_enabled=compatible)
self.buffer = Buffer.allocate(32)
if not require_type_registration:
warnings.warn(
@@ -255,10 +262,26 @@ class Fory:
set_bit(buffer, mask_index, 3)
else:
clear_bit(buffer, mask_index, 3)
+ # Reserve space for type definitions offset, similar to Java
implementation
+ type_defs_offset_pos = None
+ if self.serialization_context.scoped_meta_share_enabled:
+ type_defs_offset_pos = buffer.writer_index
+ buffer.write_int32(-1) # Reserve 4 bytes for type definitions
offset
+
if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
self.xserialize_ref(buffer, obj)
+
+ # Write type definitions at the end, similar to Java implementation
+ if self.serialization_context.scoped_meta_share_enabled:
+ meta_context = self.serialization_context.meta_context
+ if meta_context is not None and
len(meta_context.get_writing_type_defs()) > 0:
+ # Update the offset to point to current position
+ current_pos = buffer.writer_index
+ buffer.put_int32(type_defs_offset_pos, current_pos -
type_defs_offset_pos - 4)
+ self.type_resolver.write_type_defs(buffer)
+
self.reset_write()
if buffer is not self.buffer:
return buffer
@@ -369,6 +392,20 @@ class Fory:
self._buffers = iter(buffers)
else:
assert buffers is None, "buffers should be null when the
serialized stream is produced with buffer_callback null."
+
+ # Read type definitions at the start, similar to Java implementation
+ if self.serialization_context.scoped_meta_share_enabled:
+ relative_type_defs_offset = buffer.read_int32()
+ if relative_type_defs_offset != -1:
+ # Save current reader position
+ current_reader_index = buffer.reader_index
+ # Jump to type definitions
+ buffer.reader_index = current_reader_index +
relative_type_defs_offset
+ # Read type definitions
+ self.type_resolver.read_type_defs(buffer)
+ # Jump back to continue with object deserialization
+ buffer.reader_index = current_reader_index
+
if is_target_x_lang:
obj = self.xdeserialize_ref(buffer)
else:
@@ -470,7 +507,7 @@ class Fory:
def reset_write(self):
self.ref_resolver.reset_write()
self.type_resolver.reset_write()
- self.serialization_context.reset()
+ self.serialization_context.reset_write()
self.metastring_resolver.reset_write()
self.pickler.clear_memo()
self._buffer_callback = None
@@ -479,7 +516,7 @@ class Fory:
def reset_read(self):
self.ref_resolver.reset_read()
self.type_resolver.reset_read()
- self.serialization_context.reset()
+ self.serialization_context.reset_read()
self.metastring_resolver.reset_write()
self.unpickler = None
self._buffers = None
@@ -490,36 +527,6 @@ class Fory:
self.reset_read()
-class SerializationContext:
- """
- A context is used to add some context-related information, so that the
- serializers can setup relation between serializing different objects.
- The context will be reset after finished serializing/deserializing the
- object tree.
- """
-
- __slots__ = ("objects",)
-
- def __init__(self):
- self.objects = dict()
-
- def add(self, key, obj):
- self.objects[id(key)] = obj
-
- def __contains__(self, key):
- return id(key) in self.objects
-
- def __getitem__(self, key):
- return self.objects[id(key)]
-
- def get(self, key):
- return self.objects.get(id(key))
-
- def reset(self):
- if len(self.objects) > 0:
- self.objects.clear()
-
-
_ENABLE_TYPE_REGISTRATION_FORCIBLY =
os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0") in {
"1",
"true",
diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py
index 2333ff782..8b421f036 100644
--- a/python/pyfory/_registry.py
+++ b/python/pyfory/_registry.py
@@ -76,12 +76,16 @@ from pyfory.type import (
Float32Type,
Float64Type,
load_class,
+ is_struct_type,
)
from pyfory._fory import (
DYNAMIC_TYPE_ID,
# preserve 0 as flag for type id not set in TypeInfo`
NO_TYPE_ID,
)
+from pyfory.meta.typedef import TypeDef
+from pyfory.meta.typedef_decoder import decode_typedef, skip_typedef
+from pyfory.meta.typedef_encoder import encode_typedef
try:
import numpy as np
@@ -104,6 +108,7 @@ else:
"namespace_bytes",
"typename_bytes",
"dynamic_type",
+ "type_def",
)
def __init__(
@@ -114,6 +119,7 @@ else:
namespace_bytes=None,
typename_bytes=None,
dynamic_type: bool = False,
+ type_def: TypeDef = None,
):
self.cls = cls
self.type_id = type_id
@@ -121,6 +127,7 @@ else:
self.namespace_bytes = namespace_bytes
self.typename_bytes = typename_bytes
self.dynamic_type = dynamic_type
+ self.type_def = None
def __repr__(self):
return f"TypeInfo(cls={self.cls}, type_id={self.type_id},
serializer={self.serializer})"
@@ -160,9 +167,11 @@ class TypeResolver:
"metastring_resolver",
"language",
"_type_id_to_typeinfo",
+ "_meta_shared_typeinfo",
+ "meta_share",
)
- def __init__(self, fory):
+ def __init__(self, fory, meta_share=False):
self.fory = fory
self.metastring_resolver = fory.metastring_resolver
self.language = fory.language
@@ -182,9 +191,12 @@ class TypeResolver:
self._named_type_to_typeinfo = dict()
self.namespace_encoder = MetaStringEncoder(".", "_")
self.namespace_decoder = MetaStringDecoder(".", "_")
+ # Cache for TypeDef and TypeInfo tuples (similar to Java's
classIdToDef)
+ self._meta_shared_typeinfo = {}
self.typename_encoder = MetaStringEncoder("$", "_")
self.typename_decoder = MetaStringDecoder("$", "_")
self.meta_compressor = DeflaterMetaCompressor()
+ self.meta_share = meta_share
def initialize(self):
self._initialize_xlang()
@@ -356,7 +368,7 @@ class TypeResolver:
serializer = FunctionSerializer(self.fory, cls)
type_id = TypeId.NAMED_EXT if type_id is None else ((type_id
<< 8) + TypeId.EXT)
else:
- serializer = DataClassSerializer(self.fory, cls, xlang=True)
+ serializer = None
type_id = TypeId.NAMED_STRUCT if type_id is None else
((type_id << 8) + TypeId.STRUCT)
elif not internal:
type_id = TypeId.NAMED_EXT if type_id is None else ((type_id << 8)
+ TypeId.EXT)
@@ -460,7 +472,7 @@ class TypeResolver:
type_info = self._types_info.get(cls)
if type_info is not None:
if type_info.serializer is None:
- type_info.serializer = self._create_serializer(cls)
+ self._set_typeinfo(type_info)
return type_info
elif not create:
return None
@@ -491,6 +503,20 @@ class TypeResolver:
serializer=serializer,
)
+ def _set_typeinfo(self, typeinfo):
+ type_id = typeinfo.type_id & 0xFF
+ if is_struct_type(type_id):
+ if self.meta_share:
+ type_def = encode_typedef(self, typeinfo.cls)
+ typeinfo.serializer = type_def.create_serializer(self)
+ typeinfo.type_def = type_def
+ else:
+ typeinfo.serializer = DataClassSerializer(self.fory,
typeinfo.cls, xlang=not self.fory.is_py)
+ else:
+ typeinfo.serializer = self._create_serializer(typeinfo.cls)
+
+ return typeinfo
+
def _create_serializer(self, cls):
for clz in cls.__mro__:
type_info = self._types_info.get(clz)
@@ -502,7 +528,11 @@ class TypeResolver:
# Use FunctionSerializer for function types (including lambdas)
serializer = FunctionSerializer(self.fory, cls)
elif dataclasses.is_dataclass(cls):
- serializer = DataClassSerializer(self.fory, cls)
+ if not self.meta_share:
+ serializer = DataClassSerializer(self.fory, cls, xlang=not
self.fory.is_py)
+ else:
+ # lazy create serializer to handle nested struct fields.
+ serializer = None
elif issubclass(cls, enum.Enum):
serializer = EnumSerializer(self.fory, cls)
elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not
object.__reduce__) or (
@@ -536,6 +566,28 @@ class TypeResolver:
serializer = PickleSerializer(self.fory, cls)
return serializer
+ def is_registered_by_name(self, cls):
+ typeinfo = self._types_info.get(cls)
+ if typeinfo is None:
+ return False
+ return TypeId.is_namespaced_type(typeinfo.type_id & 0xFF)
+
+ def is_registered_by_id(self, cls):
+ typeinfo = self._types_info.get(cls)
+ if typeinfo is None:
+ return False
+ return not TypeId.is_namespaced_type(typeinfo.type_id & 0xFF)
+
+ def get_registered_name(self, cls):
+ typeinfo = self._types_info.get(cls)
+ assert typeinfo is not None, f"{cls} not registered"
+ return typeinfo.decode_namespace(), typeinfo.decode_typename()
+
+ def get_registered_id(self, cls):
+ typeinfo = self._types_info.get(cls)
+ assert typeinfo is not None, f"{cls} not registered"
+ return typeinfo.type_id
+
def _load_metabytes_to_typeinfo(self, ns_metabytes, type_metabytes):
typeinfo = self._ns_type_to_typeinfo.get((ns_metabytes,
type_metabytes))
if typeinfo is not None:
@@ -557,12 +609,22 @@ class TypeResolver:
return
type_id = typeinfo.type_id
internal_type_id = type_id & 0xFF
+
+ # Check if meta share is enabled first
+ if self.meta_share:
+ self.write_shared_type_meta(buffer, typeinfo)
+ return
+
buffer.write_varuint32(type_id)
if TypeId.is_namespaced_type(internal_type_id):
self.metastring_resolver.write_meta_string_bytes(buffer,
typeinfo.namespace_bytes)
self.metastring_resolver.write_meta_string_bytes(buffer,
typeinfo.typename_bytes)
def read_typeinfo(self, buffer):
+ # Check if meta share is enabled first
+ if self.meta_share:
+ return self.read_shared_type_meta(buffer)
+
type_id = buffer.read_varuint32()
internal_type_id = type_id & 0xFF
if TypeId.is_namespaced_type(internal_type_id):
@@ -595,6 +657,66 @@ class TypeResolver:
def get_meta_compressor(self):
return self.meta_compressor
+ def write_shared_type_meta(self, buffer, typeinfo):
+ """Write shared type meta information."""
+ assert typeinfo.type_def is not None, "Type info must be set when meta
share is enabled"
+ meta_context = self.fory.serialization_context.meta_context
+ meta_context.write_typeinfo(buffer, typeinfo)
+
+ def read_shared_type_meta(self, buffer):
+ """Read shared type meta information."""
+ meta_context = self.fory.serialization_context.meta_context
+ assert meta_context is not None, "Meta context must be set when meta
share is enabled"
+ type_id = buffer.read_varuint32()
+ typeinfo = meta_context.get_read_type_info(type_id)
+ assert typeinfo is not None, f"Type info not found for ID {type_id}"
+ return typeinfo
+
+ def write_type_defs(self, buffer):
+ """Write all type definitions that need to be sent."""
+ meta_context = self.fory.serialization_context.meta_context
+ if meta_context is None:
+ return
+ writing_type_defs = meta_context.get_writing_type_defs()
+ buffer.write_varuint32(len(writing_type_defs))
+ for type_def in writing_type_defs:
+ # Just copy the encoded bytes directly
+ buffer.write_bytes(type_def.encoded)
+
+ def read_type_defs(self, buffer):
+ """Read all type definitions from the buffer."""
+ meta_context = self.fory.serialization_context.meta_context
+ if meta_context is None:
+ return
+
+ num_type_defs = buffer.read_varuint32()
+ for i in range(num_type_defs):
+ # Read the header (first 8 bytes) to get the type ID
+ header = buffer.read_int64()
+ # Check if we already have this TypeDef cached
+ type_info = self._meta_shared_typeinfo.get(header)
+ if type_info is not None:
+ # Skip the rest of the TypeDef binary for faster performance
+ skip_typedef(buffer, header)
+ else:
+ # Read the TypeDef and create TypeInfo
+ type_def = decode_typedef(buffer, self, header=header)
+ type_info = self._build_type_info_from_typedef(type_def)
+ # Cache the tuple for future use
+ self._meta_shared_typeinfo[header] = type_info
+ meta_context.add_read_type_info(type_info)
+
+ def _build_type_info_from_typedef(self, type_def):
+ """Build TypeInfo from TypeDef using TypeDef's create_serializer
method."""
+ # Create serializer using TypeDef's create_serializer method
+ serializer = type_def.create_serializer(self)
+ ns_metastr = self.namespace_encoder.encode(type_def.namespace or "")
+ ns_meta_bytes = self.metastring_resolver.get_metastr_bytes(ns_metastr)
+ type_metastr = self.typename_encoder.encode(type_def.typename)
+ type_meta_bytes =
self.metastring_resolver.get_metastr_bytes(type_metastr)
+ typeinfo = TypeInfo(type_def.cls, type_def.type_id, serializer,
ns_meta_bytes, type_meta_bytes, False, type_def)
+ return typeinfo
+
def reset(self):
pass
diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx
index 315333a70..458eb640f 100644
--- a/python/pyfory/_serialization.pyx
+++ b/python/pyfory/_serialization.pyx
@@ -399,6 +399,7 @@ cdef class TypeInfo:
cdef public MetaStringBytes namespace_bytes
cdef public MetaStringBytes typename_bytes
cdef public c_bool dynamic_type
+ cdef public object type_def
def __init__(
self,
@@ -408,6 +409,7 @@ cdef class TypeInfo:
namespace_bytes: MetaStringBytes = None,
typename_bytes: MetaStringBytes = None,
dynamic_type: bool = False,
+ type_def: object = None
):
self.cls = cls
self.type_id = type_id
@@ -415,6 +417,7 @@ cdef class TypeInfo:
self.namespace_bytes = namespace_bytes
self.typename_bytes = typename_bytes
self.dynamic_type = dynamic_type
+ self.type_def = type_def
def __repr__(self):
return f"TypeInfo(cls={self.cls}, type_id={self.type_id}, " \
@@ -443,12 +446,16 @@ cdef class TypeResolver:
# hash -> TypeInfo
flat_hash_map[pair[int64_t, int64_t], PyObject *]
_c_meta_hash_to_typeinfo
MetaStringResolver meta_string_resolver
+ c_bool meta_share
+ SerializationContext serialization_context
- def __init__(self, fory):
+ def __init__(self, fory, meta_share=False):
self.fory = fory
self.metastring_resolver = fory.metastring_resolver
+ self.meta_share = meta_share
from pyfory._registry import TypeResolver
- self._resolver = TypeResolver(fory)
+ self._resolver = TypeResolver(fory, meta_share=meta_share)
+ self.serialization_context = fory.serialization_context
def initialize(self):
self._resolver.initialize()
@@ -518,7 +525,7 @@ cdef class TypeResolver:
if type_info.serializer is not None:
return type_info
else:
- type_info.serializer = self._resolver._create_serializer(cls)
+ type_info.serializer =
self._resolver.get_typeinfo(cls).serializer
return type_info
elif not create:
return None
@@ -527,6 +534,18 @@ cdef class TypeResolver:
self._c_types_info[<uintptr_t> <PyObject *> cls] = <PyObject *>
type_info
self._populate_typeinfo(type_info)
return type_info
+
+    def is_registered_by_name(self, cls):
+        """Forward to the pure-Python resolver."""
+        python_resolver = self._resolver
+        return python_resolver.is_registered_by_name(cls)
+
+    def is_registered_by_id(self, cls):
+        """Forward to the pure-Python resolver."""
+        python_resolver = self._resolver
+        return python_resolver.is_registered_by_id(cls)
+
+    def get_registered_name(self, cls):
+        """Forward to the pure-Python resolver."""
+        python_resolver = self._resolver
+        return python_resolver.get_registered_name(cls)
+
+    def get_registered_id(self, cls):
+        """Forward to the pure-Python resolver."""
+        python_resolver = self._resolver
+        return python_resolver.get_registered_id(cls)
cdef inline TypeInfo _load_bytes_to_typeinfo(
self, int32_t type_id, MetaStringBytes ns_metabytes,
MetaStringBytes type_metabytes):
@@ -546,12 +565,20 @@ cdef class TypeResolver:
cdef:
int32_t type_id = typeinfo.type_id
int32_t internal_type_id = type_id & 0xFF
+
+ if self.meta_share:
+ self.write_shared_type_meta(buffer, typeinfo)
+ return
+
buffer.write_varuint32(type_id)
if IsNamespacedType(internal_type_id):
self.metastring_resolver.write_meta_string_bytes(buffer,
typeinfo.namespace_bytes)
self.metastring_resolver.write_meta_string_bytes(buffer,
typeinfo.typename_bytes)
cpdef inline TypeInfo read_typeinfo(self, Buffer buffer):
+ if self.meta_share:
+ return self.read_shared_type_meta(buffer)
+
cdef:
int32_t type_id = buffer.read_varuint32()
if type_id < 0:
@@ -580,6 +607,29 @@ cdef class TypeResolver:
def get_meta_compressor(self):
return self._resolver.get_meta_compressor()
+    cpdef write_shared_type_meta(self, Buffer buffer, TypeInfo typeinfo):
+        """Write the meta-context index that refers to ``typeinfo``'s TypeDef."""
+        ctx = self.serialization_context.meta_context
+        assert ctx is not None, "Meta context must be set when meta share is enabled"
+        ctx.write_typeinfo(buffer, typeinfo)
+
+    cpdef TypeInfo read_shared_type_meta(self, Buffer buffer):
+        """Read a meta-context index and return the TypeInfo stored under it."""
+        ctx = self.serialization_context.meta_context
+        assert ctx is not None, "Meta context must be set when meta share is enabled"
+        type_id = buffer.read_varuint32()
+        typeinfo = ctx.get_read_type_info(type_id)
+        assert typeinfo is not None, f"Type info not found for ID {type_id}"
+        return typeinfo
+
+    cpdef write_type_defs(self, Buffer buffer):
+        """Delegate writing of queued TypeDefs to the pure-Python resolver."""
+        python_resolver = self._resolver
+        python_resolver.write_type_defs(buffer)
+
+    cpdef read_type_defs(self, Buffer buffer):
+        """Delegate reading of TypeDefs to the pure-Python resolver."""
+        python_resolver = self._resolver
+        python_resolver.read_type_defs(buffer)
+
cpdef inline reset(self):
pass
@@ -590,12 +640,124 @@ cdef class TypeResolver:
pass
+@cython.final
+cdef class MetaContext:
+    """
+    Tracks type definitions shared within a serialize/deserialize call.
+
+    On write, a class gets the next index the first time it is seen and its
+    TypeDef is queued once; every occurrence is written as that index. On
+    read, TypeInfos are stored in index order and looked up by the index read
+    from the stream.
+
+    This is the Cython-optimized equivalent of Java's MetaContext class.
+    """
+    cdef:
+        # Address of each class already written -> index of its TypeDef.
+        flat_hash_map[uint64_t, int32_t] _c_type_map
+        # TypeDefs queued for writing; position i holds the TypeDef for index i.
+        list _writing_type_defs
+        # TypeInfos read from the peer; position i holds the TypeInfo for index i.
+        list _read_type_infos
+
+    def __cinit__(self):
+        self._writing_type_defs = []
+        self._read_type_infos = []
+
+    cpdef inline int32_t write_typeinfo(self, Buffer buffer, typeinfo):
+        """
+        Write the index of ``typeinfo``'s TypeDef to ``buffer`` and return it.
+
+        A class seen for the first time in this write gets the next index and
+        its ``type_def`` is queued; later occurrences only write the existing
+        index, so each TypeDef is queued at most once.
+        """
+        cdef object type_cls = typeinfo.cls
+        cdef uint64_t type_addr = <uint64_t> <PyObject *> type_cls
+        cdef flat_hash_map[uint64_t, int32_t].iterator it = self._c_type_map.find(type_addr)
+        cdef int32_t index
+        if it != self._c_type_map.end():
+            index = deref(it).second
+            buffer.write_varuint32(index)
+            # Must stop here: falling through would write a second index and
+            # queue the TypeDef again, corrupting the stream.
+            return index
+        index = <int32_t> self._c_type_map.size()
+        buffer.write_varuint32(index)
+        self._c_type_map[type_addr] = index
+        self._writing_type_defs.append(typeinfo.type_def)
+        return index
+
+    cpdef inline list get_writing_type_defs(self):
+        """Return the TypeDefs queued for writing, in index order."""
+        return self._writing_type_defs
+
+    cpdef inline reset_write(self):
+        """Forget all queued TypeDefs and assigned write indices."""
+        self._writing_type_defs.clear()
+        self._c_type_map.clear()
+
+    cpdef inline add_read_type_info(self, type_info):
+        """Append a TypeInfo read from the peer; it gets the next read index."""
+        self._read_type_infos.append(type_info)
+
+    cpdef inline get_read_type_info(self, int32_t index):
+        """Return the TypeInfo stored at ``index`` (raises IndexError if absent)."""
+        return self._read_type_infos[index]
+
+    cpdef inline reset_read(self):
+        """Forget all TypeInfos read from the peer."""
+        self._read_type_infos.clear()
+
+    cpdef inline reset(self):
+        """Reset both read and write state."""
+        self.reset_write()
+        self.reset_read()
+
+    def __repr__(self):
+        # Only report attributes that exist; the former ``_read_type_defs``
+        # field was never declared and made repr() raise AttributeError.
+        return (f"MetaContext("
+                f"read_infos={len(self._read_type_infos)}, "
+                f"writing_defs={len(self._writing_type_defs)})")
+
+
+@cython.final
+cdef class SerializationContext:
+    """
+    Context shared by serializers while one object tree is serialized or
+    deserialized.
+
+    ``objects`` maps ``id(key)`` to a value registered via ``add``. When
+    ``scoped_meta_share_enabled`` is true, ``meta_context`` holds a
+    ``MetaContext`` tracking TypeDefs written and read during the call;
+    otherwise it is ``None``.
+    """
+    cdef dict objects
+    cdef readonly bint scoped_meta_share_enabled
+    cdef public object meta_context
+
+    def __init__(self, scoped_meta_share_enabled: bool = False):
+        self.objects = dict()
+        self.scoped_meta_share_enabled = scoped_meta_share_enabled
+        if scoped_meta_share_enabled:
+            self.meta_context = MetaContext()
+        else:
+            self.meta_context = None
+
+    def add(self, key, obj):
+        # Keyed by identity (``id``), so keys need not be hashable.
+        self.objects[id(key)] = obj
+
+    def __contains__(self, key):
+        return id(key) in self.objects
+
+    def __getitem__(self, key):
+        return self.objects[id(key)]
+
+    def get(self, key):
+        return self.objects.get(id(key))
+
+    cpdef reset(self):
+        """Clear registered objects and both read and write meta state."""
+        if len(self.objects) > 0:
+            self.objects.clear()
+        # A full reset must not leave stale shared-meta indices behind.
+        if self.scoped_meta_share_enabled and self.meta_context is not None:
+            self.meta_context.reset()
+
+    cpdef reset_write(self):
+        """Clear registered objects and the write-side meta state."""
+        if len(self.objects) > 0:
+            self.objects.clear()
+        if self.scoped_meta_share_enabled and self.meta_context is not None:
+            self.meta_context.reset_write()
+
+    cpdef reset_read(self):
+        """Clear registered objects and the read-side meta state."""
+        if len(self.objects) > 0:
+            self.objects.clear()
+        if self.scoped_meta_share_enabled and self.meta_context is not None:
+            self.meta_context.reset_read()
+
+
@cython.final
cdef class Fory:
cdef readonly object language
cdef readonly c_bool ref_tracking
cdef readonly c_bool require_type_registration
cdef readonly c_bool is_py
+ cdef readonly c_bool compatible
cdef readonly MapRefResolver ref_resolver
cdef readonly TypeResolver type_resolver
cdef readonly MetaStringResolver metastring_resolver
@@ -614,6 +776,7 @@ cdef class Fory:
language=Language.PYTHON,
ref_tracking: bool = False,
require_type_registration: bool = True,
+ compatible: bool = False,
):
"""
:param require_type_registration:
@@ -621,22 +784,26 @@ cdef class Fory:
If disabled, unknown insecure types can be deserialized, which can be
insecure and cause remote code execution attack if the types
`__new__`/`__init__`/`__eq__`/`__hash__` method contain malicious
code.
- Do not disable type registration if you can't ensure your environment
are
- *indeed secure*. We are not responsible for security risks if
- you disable this option.
- """
+ Do not disable type registration if you can't ensure your
environment are
+ *indeed secure*. We are not responsible for security risks if
+ you disable this option.
+ :param compatible:
+ Whether to enable compatible mode for cross-language serialization.
+ When enabled, type forward/backward compatibility for struct fields
will be enabled.
+ """
self.language = language
if _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration:
self.require_type_registration = True
else:
self.require_type_registration = False
+ self.compatible = compatible
self.ref_tracking = ref_tracking
self.ref_resolver = MapRefResolver(ref_tracking)
self.is_py = self.language == Language.PYTHON
self.metastring_resolver = MetaStringResolver()
- self.type_resolver = TypeResolver(self)
+ self.serialization_context =
SerializationContext(scoped_meta_share_enabled=compatible)
+ self.type_resolver = TypeResolver(self, meta_share=compatible)
self.type_resolver.initialize()
- self.serialization_context = SerializationContext()
self.buffer = Buffer.allocate(32)
if not require_type_registration:
warnings.warn(
@@ -735,11 +902,27 @@ cdef class Fory:
set_bit(buffer, mask_index, 3)
else:
clear_bit(buffer, mask_index, 3)
+ # Reserve space for type definitions offset, similar to Java
implementation
+ cdef int32_t type_defs_offset_pos = -1
+ if self.serialization_context.scoped_meta_share_enabled:
+ type_defs_offset_pos = buffer.writer_index
+ buffer.write_int32(-1) # Reserve 4 bytes for type definitions
offset
+
cdef int32_t start_offset
if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
self.xserialize_ref(buffer, obj)
+
+ # Write type definitions at the end, similar to Java implementation
+ if self.serialization_context.scoped_meta_share_enabled:
+ meta_context = self.serialization_context.meta_context
+ if meta_context is not None and
len(meta_context.get_writing_type_defs()) > 0:
+ # Update the offset to point to current position
+ current_pos = buffer.writer_index
+ buffer.put_int32(type_defs_offset_pos, current_pos -
type_defs_offset_pos - 4)
+ self.type_resolver.write_type_defs(buffer)
+
if buffer is not self.buffer:
return buffer
else:
@@ -870,6 +1053,20 @@ cdef class Fory:
"buffers should be null when the serialized stream is "
"produced with buffer_callback null."
)
+
+ # Read type definitions at the start, similar to Java implementation
+ if self.serialization_context.scoped_meta_share_enabled:
+ relative_type_defs_offset = buffer.read_int32()
+ if relative_type_defs_offset != -1:
+ # Save current reader position
+ current_reader_index = buffer.reader_index
+ # Jump to type definitions
+ buffer.reader_index = current_reader_index +
relative_type_defs_offset
+ # Read type definitions
+ self.type_resolver.read_type_defs(buffer)
+ # Jump back to continue with object deserialization
+ buffer.reader_index = current_reader_index
+
if not is_target_x_lang:
return self.deserialize_ref(buffer)
return self.xdeserialize_ref(buffer)
@@ -1001,7 +1198,7 @@ cdef class Fory:
self.ref_resolver.reset_write()
self.type_resolver.reset_write()
self.metastring_resolver.reset_write()
- self.serialization_context.reset()
+ self.serialization_context.reset_write()
self.pickler.clear_memo()
self._unsupported_callback = None
@@ -1009,7 +1206,7 @@ cdef class Fory:
self.ref_resolver.reset_read()
self.type_resolver.reset_read()
self.metastring_resolver.reset_read()
- self.serialization_context.reset()
+ self.serialization_context.reset_read()
self._buffers = None
self.unpickler = None
self._unsupported_objects = None
@@ -1071,29 +1268,6 @@ cpdef inline read_nullable_pystr(Buffer buffer):
return None
-@cython.final
-cdef class SerializationContext:
- cdef dict objects
-
- def __init__(self):
- self.objects = dict()
-
- def add(self, key, obj):
- self.objects[id(key)] = obj
-
- def __contains__(self, key):
- return id(key) in self.objects
-
- def __getitem__(self, key):
- return self.objects[id(key)]
-
- def get(self, key):
- return self.objects.get(id(key))
-
- def reset(self):
- if len(self.objects) > 0:
- self.objects.clear()
-
cdef class Serializer:
cdef readonly Fory fory
cdef readonly object type_
diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py
index dacdae670..a12b8c096 100644
--- a/python/pyfory/_serializer.py
+++ b/python/pyfory/_serializer.py
@@ -53,15 +53,11 @@ KV_NULL = KEY_HAS_NULL | VALUE_HAS_NULL
# Key is null, value type is declared type, and ref tracking for value is
disabled.
NULL_KEY_VALUE_DECL_TYPE = KEY_HAS_NULL | VALUE_DECL_TYPE
# Key is null, value type is declared type, and ref tracking for value is
enabled.
-NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = (
- KEY_HAS_NULL | VALUE_DECL_TYPE | TRACKING_VALUE_REF
-)
+NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = KEY_HAS_NULL | VALUE_DECL_TYPE |
TRACKING_VALUE_REF
# Value is null, key type is declared type, and ref tracking for key is
disabled.
NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE
# Value is null, key type is declared type, and ref tracking for key is
enabled.
-NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = (
- VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_VALUE_REF
-)
+NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE |
TRACKING_VALUE_REF
class Serializer(ABC):
@@ -182,11 +178,7 @@ _base_date = datetime.date(1970, 1, 1)
class DateSerializer(CrossLanguageCompatibleSerializer):
def write(self, buffer, value: datetime.date):
if not isinstance(value, datetime.date):
- raise TypeError(
- "{} should be {} instead of {}".format(
- value, datetime.date, type(value)
- )
- )
+ raise TypeError("{} should be {} instead of {}".format(value,
datetime.date, type(value)))
days = (value - _base_date).days
buffer.write_int32(days)
@@ -208,9 +200,7 @@ class
TimestampSerializer(CrossLanguageCompatibleSerializer):
def write(self, buffer, value: datetime.datetime):
if not isinstance(value, datetime.datetime):
- raise TypeError(
- "{} should be {} instead of {}".format(value, datetime,
type(value))
- )
+ raise TypeError("{} should be {} instead of {}".format(value,
datetime, type(value)))
# TimestampType represent micro seconds
buffer.write_int64(self._get_timestamp(value))
@@ -287,10 +277,7 @@ class CollectionSerializer(Serializer):
collect_flag |= COLLECTION_TRACKING_REF
buffer.write_varuint32(len(value))
buffer.write_int8(collect_flag)
- if (
- not has_different_type
- and (collect_flag & COLLECTION_NOT_DECL_ELEMENT_TYPE) != 0
- ):
+ if not has_different_type and (collect_flag &
COLLECTION_NOT_DECL_ELEMENT_TYPE) != 0:
self.type_resolver.write_typeinfo(buffer, elem_typeinfo)
return collect_flag, elem_typeinfo
@@ -385,9 +372,7 @@ class CollectionSerializer(Serializer):
for _ in range(len_):
self._add_element(
collection_,
- get_next_element(
- buffer, self.ref_resolver, self.type_resolver, self.is_py
- ),
+ get_next_element(buffer, self.ref_resolver,
self.type_resolver, self.is_py),
)
def xwrite(self, buffer, value):
@@ -532,12 +517,8 @@ class MapSerializer(Serializer):
type_resolver.write_typeinfo(buffer, value_typeinfo)
value_serializer = value_typeinfo.serializer
- key_write_ref = (
- key_serializer.need_to_write_ref if key_serializer else False
- )
- value_write_ref = (
- value_serializer.need_to_write_ref if value_serializer else
False
- )
+ key_write_ref = key_serializer.need_to_write_ref if key_serializer
else False
+ value_write_ref = value_serializer.need_to_write_ref if
value_serializer else False
if key_write_ref:
chunk_header |= TRACKING_KEY_REF
if value_write_ref:
@@ -547,18 +528,11 @@ class MapSerializer(Serializer):
chunk_size = 0
while chunk_size < MAX_CHUNK_SIZE:
- if (
- key is None
- or value is None
- or type(key) is not key_cls
- or type(value) is not value_cls
- ):
+ if key is None or value is None or type(key) is not key_cls or
type(value) is not value_cls:
break
if not key_write_ref or not
ref_resolver.write_ref_or_null(buffer, key):
self._write_obj(key_serializer, buffer, key)
- if not value_write_ref or not ref_resolver.write_ref_or_null(
- buffer, value
- ):
+ if not value_write_ref or not
ref_resolver.write_ref_or_null(buffer, value):
value_serializer.write(buffer, value)
chunk_size += 1
@@ -583,9 +557,7 @@ class MapSerializer(Serializer):
if size != 0:
chunk_header = buffer.read_uint8()
key_serializer, value_serializer = self.key_serializer,
self.value_serializer
- deserialize_ref = (
- fory.deserialize_ref if self.fory.is_py else fory.xdeserialize_ref
- )
+ deserialize_ref = fory.deserialize_ref if self.fory.is_py else
fory.xdeserialize_ref
while size > 0:
while True:
key_has_null = (chunk_header & KEY_HAS_NULL) != 0
diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py
index 4e251c1c8..ef1cdc68a 100644
--- a/python/pyfory/_struct.py
+++ b/python/pyfory/_struct.py
@@ -243,7 +243,8 @@ class StructTypeIdVisitor(TypeVisitor):
return TypeId.MAP, key_ids, value_ids
def visit_customized(self, field_name, type_, types_path=None):
- return None, None
+ typeinfo = self.fory.type_resolver.get_typeinfo(type_)
+ return [typeinfo.type_id]
def visit_other(self, field_name, type_, types_path=None):
from pyfory.serializer import PickleSerializer # Local import
diff --git a/python/pyfory/format/__init__.py b/python/pyfory/format/__init__.py
index 3bc70502f..f6fd1d8f5 100644
--- a/python/pyfory/format/__init__.py
+++ b/python/pyfory/format/__init__.py
@@ -41,8 +41,7 @@ try:
)
except (ImportError, AttributeError) as e:
warnings.warn(
- f"Fory format initialization failed, please ensure pyarrow is
installed "
- f"with version which fory is compiled with: {e}",
+ f"Fory format initialization failed, please ensure pyarrow is
installed with version which fory is compiled with: {e}",
RuntimeWarning,
stacklevel=2,
)
diff --git a/python/pyfory/format/tests/test_encoder.py b/python/pyfory/format/tests/test_encoder.py
index ac9dbdfdc..b4b01fc1c 100644
--- a/python/pyfory/format/tests/test_encoder.py
+++ b/python/pyfory/format/tests/test_encoder.py
@@ -63,9 +63,7 @@ def test_encoder_with_schema():
@require_pyarrow
def test_dict():
dict_ = {"f1": 1, "f2": "str"}
- encoder = pyfory.create_row_encoder(
- pa.schema([("f1", pa.int32()), ("f2", pa.utf8())])
- )
+ encoder = pyfory.create_row_encoder(pa.schema([("f1", pa.int32()), ("f2",
pa.utf8())]))
row = encoder.to_row(dict_)
new_obj = encoder.from_row(row)
assert new_obj.f1 == dict_["f1"]
@@ -74,9 +72,7 @@ def test_dict():
@require_pyarrow
def test_ints():
- cls = pyfory.record_class_factory(
- "TestNumeric", ["f" + str(i) for i in range(1, 9)]
- )
+ cls = pyfory.record_class_factory("TestNumeric", ["f" + str(i) for i in
range(1, 9)])
schema = pa.schema(
[
("f1", pa.int64()),
diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py
index dd786e990..1f054c156 100644
--- a/python/pyfory/meta/typedef.py
+++ b/python/pyfory/meta/typedef.py
@@ -19,9 +19,7 @@ from typing import List
import typing
from pyfory.type import TypeId
from pyfory._util import Buffer
-from pyfory.serializer import MapSerializer, ListSerializer, SetSerializer
-from pyfory._struct import _sort_fields, StructTypeIdVisitor, get_field_names
-from pyfory.type import TypeId, infer_field, is_primitive_type,
is_polymorphic_type
+from pyfory.type import infer_field, is_primitive_type, is_polymorphic_type,
is_struct_type
from pyfory.meta.metastring import Encoding
@@ -43,8 +41,12 @@ FIELD_NAME_ENCODINGS = [Encoding.UTF_8, Encoding.LOWER_UPPER_DIGIT_SPECIAL, Enco
class TypeDef:
- def __init__(self, name: str, type_id: int, fields: List["FieldInfo"],
encoded: bytes = None, is_compressed: bool = False):
- self.name = name
+ def __init__(
+ self, namespace: str, typename: str, cls: type, type_id: int, fields:
List["FieldInfo"], encoded: bytes = None, is_compressed: bool = False
+ ):
+ self.namespace = namespace
+ self.typename = typename
+ self.cls = cls
self.type_id = type_id
self.fields = fields
self.encoded = encoded
@@ -54,8 +56,19 @@ class TypeDef:
serializers = [field_info.field_type.create_serializer(resolver) for
field_info in self.fields]
return serializers
+ def get_field_names(self):
+ return [field_info.name for field_info in self.fields]
+
+ def create_serializer(self, resolver):
+ from pyfory.serializer import DataClassSerializer
+
+ fory = resolver.fory
+ return DataClassSerializer(
+ fory, self.cls, xlang=not fory.is_py,
field_names=self.get_field_names(),
serializers=self.create_fields_serializer(resolver)
+ )
+
def __repr__(self):
- return f"TypeDef(name={self.name}, type_id={self.type_id},
fields={self.fields}, is_compressed={self.is_compressed})"
+ return f"TypeDef(namespace={self.namespace}, typename={self.typename},
cls={self.cls}, type_id={self.type_id}, fields={self.fields},
is_compressed={self.is_compressed})"
class FieldInfo:
@@ -121,7 +134,11 @@ class FieldType:
elif xtype_id == TypeId.UNKNOWN:
return DynamicFieldType(xtype_id, False, is_nullable,
is_tracking_ref)
else:
- return FieldType(xtype_id, False, is_nullable, is_tracking_ref)
+ # For primitive types, determine if they are monomorphic based on
the type
+ from pyfory.type import is_polymorphic_type
+
+ is_monomorphic = not is_polymorphic_type(xtype_id)
+ return FieldType(xtype_id, is_monomorphic, is_nullable,
is_tracking_ref)
def create_serializer(self, resolver):
if self.type_id in [TypeId.EXT, TypeId.STRUCT, TypeId.NAMED_STRUCT,
TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, TypeId.UNKNOWN]:
@@ -145,6 +162,8 @@ class CollectionFieldType(FieldType):
self.element_type = element_type
def create_serializer(self, resolver):
+ from pyfory.serializer import ListSerializer, SetSerializer
+
if self.type_id == TypeId.LIST:
return ListSerializer(resolver.fory, list,
self.element_type.create_serializer(resolver))
elif self.type_id == TypeId.SET:
@@ -170,6 +189,8 @@ class MapFieldType(FieldType):
def create_serializer(self, resolver):
key_serializer = self.key_type.create_serializer(resolver)
value_serializer = self.value_type.create_serializer(resolver)
+ from pyfory.serializer import MapSerializer
+
return MapSerializer(resolver.fory, dict, key_serializer,
value_serializer)
def __repr__(self):
@@ -192,6 +213,8 @@ class DynamicFieldType(FieldType):
def build_field_infos(type_resolver, cls):
"""Build field information for the class."""
+ from pyfory._struct import _sort_fields, StructTypeIdVisitor,
get_field_names
+
field_names = get_field_names(cls)
type_hints = typing.get_type_hints(cls)
@@ -205,6 +228,7 @@ def build_field_infos(type_resolver, cls):
field_infos.append(field_info)
serializers = [field_info.field_type.create_serializer(type_resolver) for
field_info in field_infos]
+
field_names, serializers = _sort_fields(type_resolver, field_names,
serializers)
field_infos_map = {field_info.name: field_info for field_info in
field_infos}
new_field_infos = []
@@ -217,12 +241,15 @@ def build_field_infos(type_resolver, cls):
def build_field_type(type_resolver, field_name: str, type_hint, visitor):
"""Build field type from type hint."""
type_ids = infer_field(field_name, type_hint, visitor)
+ print(f"=??????????=> {field_name, type_hint, visitor, type_ids}")
return build_field_type_from_type_ids(type_resolver, field_name, type_ids,
visitor)
def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids,
visitor):
tracking_ref = type_resolver.fory.ref_tracking
type_id = type_ids[0]
+ if type_id is not None and type_id >= 0:
+ type_id = type_id & 0xFF
morphic = not is_polymorphic_type(type_id)
if type_id in [TypeId.SET, TypeId.LIST]:
elem_type = build_field_type_from_type_ids(type_resolver, field_name,
type_ids[1], visitor)
@@ -234,7 +261,7 @@ def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, vis
elif type_id in [TypeId.UNKNOWN, TypeId.EXT, TypeId.STRUCT,
TypeId.NAMED_STRUCT, TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT]:
return DynamicFieldType(type_id, False, True, tracking_ref)
else:
- assert is_primitive_type(type_id) or type_id in [TypeId.STRING,
TypeId.ENUM, TypeId.NAMED_ENUM], (
+ assert is_primitive_type(type_id) or type_id in [TypeId.STRING,
TypeId.ENUM, TypeId.NAMED_ENUM] or is_struct_type(type_id), (
f"Unknown type: {type_id} for field: {field_name}"
)
return FieldType(type_id, morphic, True, tracking_ref)
diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py
index 18f6cdbb4..2cf5633c7 100644
--- a/python/pyfory/meta/typedef_decoder.py
+++ b/python/pyfory/meta/typedef_decoder.py
@@ -25,18 +25,15 @@ from typing import List
from pyfory._util import Buffer
from pyfory.meta.typedef import TypeDef, FieldInfo, FieldType
from pyfory.meta.typedef import (
- FieldInfo,
- TypeDef,
SMALL_NUM_FIELDS_THRESHOLD,
REGISTER_BY_NAME_FLAG,
FIELD_NAME_SIZE_THRESHOLD,
COMPRESS_META_FLAG,
HAS_FIELDS_META_FLAG,
META_SIZE_MASKS,
- NUM_HASH_BITS,
FIELD_NAME_ENCODINGS,
)
-from pyfory.type import TypeId
+from pyfory.type import TypeId, record_class_factory
from pyfory.meta.metastring import MetaStringDecoder, Encoding
@@ -46,7 +43,20 @@ TYPENAME_DECODER = MetaStringDecoder("$", "_")
FIELD_NAME_DECODER = MetaStringDecoder("$", "_")
-def decode_typedef(buffer: Buffer, resolver) -> TypeDef:
+def skip_typedef(buffer: Buffer, header) -> None:
+ """
+ Skip a TypeDef from the buffer.
+ """
+ # Extract components from header
+ meta_size = header & META_SIZE_MASKS
+ # If meta size is at maximum, read additional size
+ if meta_size == META_SIZE_MASKS:
+ meta_size += buffer.read_varuint32()
+ # Read meta data
+ buffer.read_bytes(meta_size)
+
+
+def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef:
"""
Decode a TypeDef from the buffer.
@@ -58,7 +68,8 @@ def decode_typedef(buffer: Buffer, resolver) -> TypeDef:
The decoded TypeDef.
"""
# Read global binary header
- header = buffer.read_int64()
+ if header is None:
+ header = buffer.read_int64()
# Extract components from header
meta_size = header & META_SIZE_MASKS
@@ -90,11 +101,11 @@ def decode_typedef(buffer: Buffer, resolver) -> TypeDef:
# Check if registered by name
is_registered_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0
+ type_cls = None
# Read type info
if is_registered_by_name:
namespace = read_namespace(meta_buffer)
typename = read_typename(meta_buffer)
- name = namespace + "." + typename if namespace else typename
# Look up the type_id from namespace and typename
type_info = resolver.get_typeinfo_by_name(namespace, typename)
if type_info:
@@ -105,15 +116,23 @@ def decode_typedef(buffer: Buffer, resolver) -> TypeDef:
else:
type_id = meta_buffer.read_varuint32()
type_info = resolver.get_typeinfo_by_id(type_id)
- name = type_info.cls.__name__
-
+ if type_info is not None:
+ type_cls = type_info.cls
+ namespace = type_info.decode_namespace()
+ typename = type_info.decode_typename()
+ else:
+ namespace = "fory"
+ typename = f"Nonexistent{type_id}"
+ name = namespace + "." + typename if namespace else typename
# Read fields info if present
field_infos = []
if has_fields_meta:
field_infos = read_fields_info(meta_buffer, resolver, name, num_fields)
+ if type_cls is None:
+ type_cls = record_class_factory(name, [field_info.name for field_info
in field_infos])
# Create TypeDef object
- return TypeDef(name, type_id, field_infos, meta_data, is_compressed)
+ return TypeDef(namespace, typename, type_cls, type_id, field_infos,
meta_data, is_compressed)
def read_namespace(buffer: Buffer) -> str:
@@ -174,7 +193,7 @@ def read_field_info(buffer: Buffer, resolver, defined_class: str) -> FieldInfo:
field_name_size += 1
encoding = FIELD_NAME_ENCODINGS[field_name_encoding]
is_nullable = (header & 0b10) != 0
- is_tracking_ref = header & 0b1
+ is_tracking_ref = (header & 0b1) != 0
# Read field type info (without flags since they're in the header)
xtype_id = buffer.read_varuint32()
diff --git a/python/pyfory/meta/typedef_encoder.py b/python/pyfory/meta/typedef_encoder.py
index f652fc40f..7d8b5fdb3 100644
--- a/python/pyfory/meta/typedef_encoder.py
+++ b/python/pyfory/meta/typedef_encoder.py
@@ -33,7 +33,6 @@ from pyfory.meta.typedef import (
from pyfory.meta.metastring import MetaStringEncoder
from pyfory._util import Buffer
-from pyfory.type import TypeId
from pyfory.lib.mmh3 import hash_buffer
@@ -75,18 +74,17 @@ def encode_typedef(type_resolver, cls):
buffer.write_varuint32(len(field_infos) - SMALL_NUM_FIELDS_THRESHOLD)
# Write type info
- type_info = type_resolver.get_typeinfo(cls)
- assert type_info.type_id > 0
-
- if not TypeId.is_namespaced_type(type_info.type_id):
- buffer.write_varuint32(type_info.type_id)
- else:
+ if type_resolver.is_registered_by_name(cls):
header |= REGISTER_BY_NAME_FLAG
- namespace = type_info.decode_namespace()
- typename = type_info.decode_typename()
+ namespace, typename = type_resolver.get_registered_name(cls)
write_namespace(buffer, namespace)
write_typename(buffer, typename)
-
+ # Use the actual type_id from the resolver, not a generic one
+ type_id = type_resolver.get_registered_id(cls)
+ else:
+ assert type_resolver.is_registered_by_id(cls), "Class must be
registered by name or id"
+ type_id = type_resolver.get_registered_id(cls)
+ buffer.write_varuint32(type_id)
# Update header byte
buffer.put_uint8(0, header)
@@ -103,7 +101,15 @@ def encode_typedef(type_resolver, cls):
binary = compressed_binary
# Prepend header
binary = prepend_header(binary, is_compressed, len(field_infos) > 0)
- return TypeDef(cls.__name__, type_info.type_id, field_infos, binary,
is_compressed)
+ # Extract namespace and typename
+ if type_resolver.is_registered_by_name(cls):
+ namespace, typename = type_resolver.get_registered_name(cls)
+ else:
+ splits = cls.__name__.rsplit(".", 1)
+ if len(splits) == 1:
+ splits.insert(0, "")
+ namespace, typename = splits
+ return TypeDef(namespace, typename, cls, type_id, field_infos, binary,
is_compressed)
def prepend_header(buffer: bytes, is_compressed: bool, has_fields_meta: bool):
@@ -125,7 +131,7 @@ def prepend_header(buffer: bytes, is_compressed: bool, has_fields_meta: bool):
result.write_varuint32(meta_size - META_SIZE_MASKS)
result.write_bytes(buffer)
- return result
+ return result.to_bytes()
def write_namespace(buffer: Buffer, namespace: str):
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 0018232d9..c19bfc777 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -24,6 +24,7 @@ import os
import pickle
import types
import typing
+from typing import List
import warnings
from weakref import WeakValueDictionary
@@ -297,21 +298,22 @@ from pyfory._struct import _get_hash, _sort_fields, ComplexTypeVisitor
class DataClassSerializer(Serializer):
- def __init__(self, fory, clz: type, xlang: bool = False):
+ def __init__(self, fory, clz: type, xlang: bool = False, field_names:
List[str] = None, serializers: List[Serializer] = None):
super().__init__(fory, clz)
self._xlang = xlang
# This will get superclass type hints too.
self._type_hints = typing.get_type_hints(clz)
- self._field_names = self._get_field_names(clz)
+ self._field_names = field_names or self._get_field_names(clz)
self._has_slots = hasattr(clz, "__slots__")
if self._xlang:
- self._serializers = [None] * len(self._field_names)
- visitor = ComplexTypeVisitor(fory)
- for index, key in enumerate(self._field_names):
- serializer = infer_field(key, self._type_hints[key], visitor,
types_path=[])
- self._serializers[index] = serializer
- self._field_names, self._serializers =
_sort_fields(fory.type_resolver, self._field_names, self._serializers)
+ self._serializers = serializers or [None] * len(self._field_names)
+ if serializers is None:
+ visitor = ComplexTypeVisitor(fory)
+ for index, key in enumerate(self._field_names):
+ serializer = infer_field(key, self._type_hints[key],
visitor, types_path=[])
+ self._serializers[index] = serializer
+ self._field_names, self._serializers =
_sort_fields(fory.type_resolver, self._field_names, self._serializers)
self._hash = 0 # Will be computed on first xwrite/xread
self._generated_xwrite_method = self._gen_xwrite_method()
self._generated_xread_method = self._gen_xread_method()
@@ -443,13 +445,14 @@ class DataClassSerializer(Serializer):
context["_field_names"] = self._field_names
context["_type_hints"] = self._type_hints
context["_serializers"] = self._serializers
- # Compute hash at generation time since we're in xlang mode
- if self._hash == 0:
- self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
stmts = [
f'"""xwrite method for {self.type_}"""',
- f"{buffer}.write_int32({self._hash})",
]
+ if not self.fory.compatible:
+ # Compute hash at generation time since we're in xlang mode
+ if self._hash == 0:
+ self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
+ stmts.append(f"{buffer}.write_int32({self._hash})")
if not self._has_slots:
stmts.append(f"{value_dict} = {value}.__dict__")
for index, field_name in enumerate(self._field_names):
@@ -487,18 +490,29 @@ class DataClassSerializer(Serializer):
context["_field_names"] = self._field_names
context["_type_hints"] = self._type_hints
context["_serializers"] = self._serializers
- # Compute hash at generation time since we're in xlang mode
- if self._hash == 0:
- self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
+ current_class_field_names = set(self._get_field_names(self.type_))
stmts = [
f'"""xread method for {self.type_}"""',
- f"read_hash = {buffer}.read_int32()",
- f"if read_hash != {self._hash}:",
- f""" raise TypeNotCompatibleError(
- f"Hash {{read_hash}} is not consistent with {self._hash} for type
{self.type_}")""",
- f"{obj} = {obj_class}.__new__({obj_class})",
- f"{ref_resolver}.reference({obj})",
]
+ if not self.fory.compatible:
+ # Compute hash at generation time since we're in xlang mode
+ if self._hash == 0:
+ self._hash = _get_hash(self.fory, self._field_names,
self._type_hints)
+ stmts.extend(
+ [
+ f"read_hash = {buffer}.read_int32()",
+ f"if read_hash != {self._hash}:",
+ f""" raise TypeNotCompatibleError(
+ f"Hash {{read_hash}} is not consistent with {self._hash} for
type {self.type_}")""",
+ ]
+ )
+ stmts.extend(
+ [
+ f"{obj} = {obj_class}.__new__({obj_class})",
+ f"{ref_resolver}.reference({obj})",
+ ]
+ )
+
if not self._has_slots:
stmts.append(f"{obj_dict} = {obj}.__dict__")
@@ -507,6 +521,9 @@ class DataClassSerializer(Serializer):
context[serializer_var] = self._serializers[index]
field_value = f"field_value{index}"
stmts.append(f"{field_value} = {fory}.xdeserialize_ref({buffer},
serializer={serializer_var})")
+ if field_name not in current_class_field_names:
+ stmts.append(f"# {field_name} is not in {self.type_}")
+ continue
if not self._has_slots:
stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
else:
diff --git a/python/pyfory/tests/benchmark.py b/python/pyfory/tests/benchmark.py
index ab6abe1b7..75c883296 100644
--- a/python/pyfory/tests/benchmark.py
+++ b/python/pyfory/tests/benchmark.py
@@ -33,13 +33,9 @@ def test_encode():
assert foo == encoder.from_row(row)
t1 = timeit.timeit(lambda: encoder.to_row(foo), number=iter_nums)
- print(
- "encoder take {0} for {1} times, avg: {2}".format(t1, iter_nums, t1 /
iter_nums)
- )
+ print("encoder take {0} for {1} times, avg: {2}".format(t1, iter_nums, t1
/ iter_nums))
t2 = timeit.timeit(lambda: pickle.dumps(foo), number=iter_nums)
- print(
- "pickle take {0} for {1} times, avg: {2}".format(t2, iter_nums, t2 /
iter_nums)
- )
+ print("pickle take {0} for {1} times, avg: {2}".format(t2, iter_nums, t2 /
iter_nums))
@pytest.mark.skip(reason="take too long")
@@ -51,18 +47,10 @@ def test_decode():
row = encoder.to_row(foo)
assert foo == encoder.from_row(row)
t1 = timeit.timeit(lambda: encoder.from_row(row), number=iter_nums)
- print(
- "encoder take {0} for {1} times, avg: {2}, size {3}".format(
- t1, iter_nums, t1 / iter_nums, row.size_bytes()
- )
- )
+ print("encoder take {0} for {1} times, avg: {2}, size {3}".format(t1,
iter_nums, t1 / iter_nums, row.size_bytes()))
pickled_data = pickle.dumps(foo)
t2 = timeit.timeit(lambda: pickle.loads(pickled_data), number=iter_nums)
- print(
- "pickle take {0} for {1} times, avg: {2}, size {3}".format(
- t2, iter_nums, t2 / iter_nums, len(pickled_data)
- )
- )
+ print("pickle take {0} for {1} times, avg: {2}, size {3}".format(t2,
iter_nums, t2 / iter_nums, len(pickled_data)))
if __name__ == "__main__":
diff --git a/python/pyfory/tests/record.py b/python/pyfory/tests/record.py
index 2f56a9ad8..31ebd66a8 100644
--- a/python/pyfory/tests/record.py
+++ b/python/pyfory/tests/record.py
@@ -117,9 +117,7 @@ def foo_schema():
("f4", pa.map_(pa.string(), pa.int32())),
("f5", pa.list_(pa.int32())),
("f6", pa.int32()),
- pa.field(
- "f7", bar_struct, metadata={"cls":
fory.get_qualified_classname(Bar)}
- ),
+ pa.field("f7", bar_struct, metadata={"cls":
fory.get_qualified_classname(Bar)}),
],
metadata={"cls": fory.get_qualified_classname(Foo)},
)
diff --git a/python/pyfory/tests/test_buffer.py b/python/pyfory/tests/test_buffer.py
index 3ba9c388e..cefd6abf5 100644
--- a/python/pyfory/tests/test_buffer.py
+++ b/python/pyfory/tests/test_buffer.py
@@ -217,10 +217,7 @@ def check_varuint64(buf: Buffer, value: int, bytes_written: int):
assert buf.writer_index == buf.reader_index
assert value == varint
# test slow read branch in `read_varint64`
- assert (
- buf.slice(reader_index, buf.reader_index -
reader_index).read_varuint64()
- == value
- )
+ assert buf.slice(reader_index, buf.reader_index -
reader_index).read_varuint64() == value
def test_write_buffer():
diff --git a/python/pyfory/tests/test_codegen.py b/python/pyfory/tests/test_codegen.py
index 3b2243b29..b73d2465e 100644
--- a/python/pyfory/tests/test_codegen.py
+++ b/python/pyfory/tests/test_codegen.py
@@ -43,8 +43,6 @@ def test_debug_compiled():
def test_compile_function():
- code, func = codegen.compile_function(
- "test_compile_function", ["x"], ["print(1)", "print(2)", "return x"],
{}
- )
+ code, func = codegen.compile_function("test_compile_function", ["x"],
["print(1)", "print(2)", "return x"], {})
print(code)
assert func(100) == 100
diff --git a/python/pyfory/tests/test_meta_share.py b/python/pyfory/tests/test_meta_share.py
new file mode 100644
index 000000000..d405b24dc
--- /dev/null
+++ b/python/pyfory/tests/test_meta_share.py
@@ -0,0 +1,284 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import dataclasses
+from typing import List, Dict
+from pyfory import Fory, Language
+import pyfory
+
+
[email protected]
+class SimpleDataClass:
+ name: str
+ age: int
+ active: bool
+
+
[email protected]
+class SimpleNestedDataClass:
+ value: int
+ name: str
+
+
[email protected]
+class ExtendedDataClass:
+ name: str
+ age: int
+ active: bool
+ email: str # Additional field
+
+
[email protected]
+class ReducedDataClass:
+ name: str
+ age: int
+ # Missing 'active' field
+
+
[email protected]
+class NestedStructClass:
+ name: str
+ nested: SimpleNestedDataClass
+
+
[email protected]
+class NestedStructClassInconsistent:
+ name: str
+ nested: ExtendedDataClass # Different nested type
+
+
[email protected]
+class ListFieldsClass:
+ name: str
+ int_list: List[pyfory.Int32Type]
+ str_list: List[str]
+
+
[email protected]
+class ListFieldsClassInconsistent:
+ name: str
+ int_list: List[str] # Changed from Int32Type to str
+ str_list: List[pyfory.Int32Type] # Changed from str to Int32Type
+
+
[email protected]
+class DictFieldsClass:
+ name: str
+ int_dict: Dict[str, pyfory.Int32Type]
+ str_dict: Dict[str, str]
+
+
[email protected]
+class DictFieldsClassInconsistent:
+ name: str
+ int_dict: Dict[str, str] # Changed from Int32Type to str
+ str_dict: Dict[str, pyfory.Int32Type] # Changed from str to Int32Type
+
+
+class TestMetaShareMode:
+ def setup_method(self):
+ """Setup method to register dataclasses for each test."""
+ pass
+
+ def test_meta_share_enabled(self):
+ """Test that meta share mode can be enabled."""
+ fory = Fory(language=Language.XLANG, compatible=True)
+ assert fory.serialization_context.scoped_meta_share_enabled
+ assert fory.serialization_context.meta_context is not None
+
+ def test_meta_share_disabled(self):
+ """Test that meta share mode can be disabled."""
+ fory = Fory(language=Language.XLANG, compatible=False)
+ assert not fory.serialization_context.scoped_meta_share_enabled
+ assert fory.serialization_context.meta_context is None
+
+ def test_simple_dataclass_serialization(self):
+ """Test serialization of simple dataclass with meta share."""
+ fory = Fory(language=Language.XLANG, compatible=True)
+
+ # Register the dataclass
+ fory.register_type(SimpleDataClass)
+
+ obj = SimpleDataClass(name="test", age=25, active=True)
+ buffer = fory.serialize(obj)
+
+ # Deserialize
+ deserialized = fory.deserialize(buffer)
+ assert deserialized.name == obj.name
+ assert deserialized.age == obj.age
+ assert deserialized.active == obj.active
+
+ def test_multiple_objects_same_type(self):
+ """Test that multiple objects of same type reuse type definition."""
+ fory = Fory(language=Language.XLANG, compatible=True)
+
+ # Register the dataclass
+ fory.register_type(SimpleDataClass)
+
+ obj1 = SimpleDataClass(name="test1", age=25, active=True)
+ obj2 = SimpleDataClass(name="test2", age=30, active=False)
+
+ # Serialize both objects
+ buffer1 = fory.serialize(obj1)
+ buffer2 = fory.serialize(obj2)
+
+ # Create a new fory instance with the same meta context for
deserialization
+ fory2 = Fory(language=Language.XLANG, compatible=True)
+ fory2.register_type(SimpleDataClass)
+ # Copy the meta context from the first fory instance
+ fory2.serialization_context.meta_context =
fory.serialization_context.meta_context
+
+ # Deserialize both
+ deserialized1 = fory2.deserialize(buffer1)
+ deserialized2 = fory2.deserialize(buffer2)
+
+ assert deserialized1.name == obj1.name
+ assert deserialized2.name == obj2.name
+ assert deserialized1.age == obj1.age
+ assert deserialized2.age == obj2.age
+
+ def test_simple_nested_dataclass_serialization(self):
+ """Test serialization of simple nested dataclass with meta share."""
+ fory = Fory(language=Language.XLANG, compatible=True)
+
+ # Register the dataclass
+ fory.register_type(SimpleNestedDataClass)
+
+ obj = SimpleNestedDataClass(value=42, name="test")
+
+ buffer = fory.serialize(obj)
+ deserialized = fory.deserialize(buffer)
+
+ assert deserialized.value == obj.value
+ assert deserialized.name == obj.name
+
+ def test_serialization_without_meta_share(self):
+ """Test that serialization works without meta share mode."""
+ fory = Fory(language=Language.XLANG, compatible=False)
+
+ # Register the dataclass
+ fory.register_type(SimpleDataClass)
+
+ obj = SimpleDataClass(name="test", age=25, active=True)
+ buffer = fory.serialize(obj)
+ deserialized = fory.deserialize(buffer)
+
+ assert deserialized.name == obj.name
+ assert deserialized.age == obj.age
+ assert deserialized.active == obj.active
+
+ def test_schema_evolution_more_fields(self):
+ # Serialize with original schema
+ fory1 = Fory(language=Language.XLANG, compatible=True)
+ fory1.register_type(SimpleDataClass)
+
+ obj = SimpleDataClass(name="test", age=25, active=True)
+ buffer = fory1.serialize(obj)
+
+ # Deserialize with extended schema (more fields)
+ fory2 = Fory(language=Language.XLANG, compatible=True)
+ fory2.register_type(ExtendedDataClass)
+ deserialized = fory2.deserialize(buffer)
+
+ # Current behavior: deserialized object is of the new registered type
+ assert isinstance(deserialized, ExtendedDataClass)
+ assert deserialized.name == obj.name
+ assert deserialized.age == obj.age
+ assert deserialized.active == obj.active
+ assert not hasattr(deserialized, "email")
+
+ def test_schema_evolution_fewer_fields(self):
+ # Serialize with original schema
+ fory1 = Fory(language=Language.XLANG, compatible=True)
+ fory1.register_type(SimpleDataClass)
+ obj = SimpleDataClass(name="test", age=25, active=True)
+ buffer = fory1.serialize(obj)
+
+ # Deserialize with reduced schema (fewer fields)
+ fory2 = Fory(language=Language.XLANG, compatible=True)
+ fory2.register_type(ReducedDataClass)
+ deserialized = fory2.deserialize(buffer)
+
+ assert isinstance(deserialized, ReducedDataClass)
+ assert deserialized.name == obj.name
+ assert deserialized.age == obj.age
+ # The missing field should not be present
+ assert not hasattr(deserialized, "active")
+
+ def test_schema_inconsistent_nested_struct(self):
+ """Test schema inconsistency with nested struct types."""
+ # Serialize with original schema
+ fory1 = Fory(language=Language.XLANG, compatible=True)
+ fory1.register_type(NestedStructClass)
+ fory1.register_type(SimpleNestedDataClass)
+
+ obj = NestedStructClass(name="test",
nested=SimpleNestedDataClass(value=42, name="nested_test"))
+ buffer = fory1.serialize(obj)
+
+ # Deserialize with inconsistent schema (different nested type)
+ fory2 = Fory(language=Language.XLANG, compatible=True)
+ fory2.register_type(NestedStructClassInconsistent)
+ fory2.register_type(ExtendedDataClass)
+
+ # This should handle the schema inconsistency gracefully
+ deserialized = fory2.deserialize(buffer)
+ assert isinstance(deserialized, NestedStructClassInconsistent)
+ assert deserialized.name == obj.name
+ # The nested field type has changed, so we expect different behavior
+ assert hasattr(deserialized, "nested")
+
+ def test_schema_inconsistent_list_fields(self):
+ """Test schema inconsistency with List field types."""
+ # Serialize with original schema
+ fory1 = Fory(language=Language.XLANG, compatible=True)
+ fory1.register_type(ListFieldsClass)
+
+ obj = ListFieldsClass(name="test", int_list=[1, 2, 3], str_list=["a",
"b", "c"])
+ buffer = fory1.serialize(obj)
+
+ # Deserialize with inconsistent schema (swapped List types)
+ fory2 = Fory(language=Language.XLANG, compatible=True)
+ fory2.register_type(ListFieldsClassInconsistent)
+
+ # This should handle the schema inconsistency gracefully
+ deserialized = fory2.deserialize(buffer)
+ assert isinstance(deserialized, ListFieldsClassInconsistent)
+ assert deserialized.name == obj.name
+ # The field types have been swapped, so we expect different behavior
+ assert hasattr(deserialized, "int_list")
+ assert hasattr(deserialized, "str_list")
+
+ def test_schema_inconsistent_dict_fields(self):
+ """Test schema inconsistency with Dict field types."""
+ # Serialize with original schema
+ fory1 = Fory(language=Language.XLANG, compatible=True)
+ fory1.register_type(DictFieldsClass)
+
+ obj = DictFieldsClass(name="test", int_dict={"key1": 1, "key2": 2},
str_dict={"key1": "value1", "key2": "value2"})
+ buffer = fory1.serialize(obj)
+
+ # Deserialize with inconsistent schema (swapped Dict value types)
+ fory2 = Fory(language=Language.XLANG, compatible=True)
+ fory2.register_type(DictFieldsClassInconsistent)
+
+ # This should handle the schema inconsistency gracefully
+ deserialized = fory2.deserialize(buffer)
+ assert isinstance(deserialized, DictFieldsClassInconsistent)
+ assert deserialized.name == obj.name
+ # The field value types have been swapped, so we expect different
behavior
+ assert hasattr(deserialized, "int_dict")
+ assert hasattr(deserialized, "str_dict")
diff --git a/python/pyfory/tests/test_metastring.py b/python/pyfory/tests/test_metastring.py
index f21e09585..d470de299 100644
--- a/python/pyfory/tests/test_metastring.py
+++ b/python/pyfory/tests/test_metastring.py
@@ -196,7 +196,5 @@ def test_non_ascii_encoding_and_non_utf8():
non_ascii_string = "こんにちは" # Non-ASCII string
- with pytest.raises(
- ValueError, match="Unsupported character for LOWER_SPECIAL encoding: こ"
- ):
+ with pytest.raises(ValueError, match="Unsupported character for
LOWER_SPECIAL encoding: こ"):
encoder.encode_with_encoding(non_ascii_string, Encoding.LOWER_SPECIAL)
diff --git a/python/pyfory/tests/test_typedef_encoding.py
b/python/pyfory/tests/test_typedef_encoding.py
index 70ad35182..b53fa0986 100644
--- a/python/pyfory/tests/test_typedef_encoding.py
+++ b/python/pyfory/tests/test_typedef_encoding.py
@@ -75,9 +75,10 @@ def test_typedef_creation():
FieldInfo("age", FieldType(TypeId.INT32, True, True, False),
"TestTypeDef"),
]
- typedef = TypeDef("TestTypeDef", TypeId.STRUCT, fields, b"encoded_data",
False)
+ typedef = TypeDef("", "TestTypeDef", None, TypeId.STRUCT, fields,
b"encoded_data", False)
- assert typedef.name == "TestTypeDef"
+ assert typedef.namespace == ""
+ assert typedef.typename == "TestTypeDef"
assert typedef.type_id == TypeId.STRUCT
assert len(typedef.fields) == 2
assert typedef.encoded == b"encoded_data"
diff --git a/python/pyfory/type.py b/python/pyfory/type.py
index 7018f504f..1ff93e509 100644
--- a/python/pyfory/type.py
+++ b/python/pyfory/type.py
@@ -129,6 +129,7 @@ class TypeId:
Fory type for cross-language serialization.
See `org.apache.fory.types.Type`
"""
+
UNKNOWN = -1
# null value
NA = 0
@@ -356,7 +357,7 @@ def is_map_type(type_):
return issubclass(type_, typing.Dict)
except TypeError:
return False
-
+
_polymorphic_type_ids = {
TypeId.STRUCT,
@@ -368,11 +369,22 @@ _polymorphic_type_ids = {
TypeId.UNKNOWN,
}
+_struct_type_ids = {
+ TypeId.STRUCT,
+ TypeId.COMPATIBLE_STRUCT,
+ TypeId.NAMED_STRUCT,
+ TypeId.NAMED_COMPATIBLE_STRUCT,
+}
+
def is_polymorphic_type(type_id: int) -> bool:
return type_id in _polymorphic_type_ids
+def is_struct_type(type_id: int) -> bool:
+ return type_id in _struct_type_ids
+
+
def is_subclass(from_type, to_type):
try:
return issubclass(from_type, to_type)
@@ -401,30 +413,18 @@ class TypeVisitor(ABC):
def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
types_path = list(types_path or [])
types_path.append(type_)
- origin = (
- typing.get_origin(type_)
- if hasattr(typing, "get_origin")
- else getattr(type_, "__origin__", type_)
- )
+ origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else
getattr(type_, "__origin__", type_)
origin = origin or type_
- args = (
- typing.get_args(type_)
- if hasattr(typing, "get_args")
- else getattr(type_, "__args__", ())
- )
+ args = typing.get_args(type_) if hasattr(typing, "get_args") else
getattr(type_, "__args__", ())
if args:
if origin is list or origin == typing.List:
elem_type = args[0]
return visitor.visit_list(field_name, elem_type,
types_path=types_path)
elif origin is dict or origin == typing.Dict:
key_type, value_type = args
- return visitor.visit_dict(
- field_name, key_type, value_type, types_path=types_path
- )
+ return visitor.visit_dict(field_name, key_type, value_type,
types_path=types_path)
else:
- raise TypeError(
- f"Collection types should be {list, dict} instead of {type_}"
- )
+ raise TypeError(f"Collection types should be {list, dict} instead
of {type_}")
else:
if is_function(origin) or not hasattr(origin, "__annotations__"):
return visitor.visit_other(field_name, type_,
types_path=types_path)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]