This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fury.git


The following commit(s) were added to refs/heads/main by this push:
     new fb2172b7 feat(python): Implement collection serialization protocol (#1942)
fb2172b7 is described below

commit fb2172b7f88e3c64fb54bd4dfe1a46c7eaad8c61
Author: penguin_wwy <[email protected]>
AuthorDate: Mon Nov 18 10:35:57 2024 +0800

    feat(python): Implement collection serialization protocol (#1942)
    
    ## What does this PR do?
    
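    Implement a new collection serialization format in pyfury: the collection
    length and a 4-bit flag field are packed into a single varint64 header, and
    when every element has the same type, the element type information is written
    once for the whole collection instead of once per element.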
    
    ## Related issues
    
    ## Does this PR introduce any user-facing change?
    
    - [ ] Does this PR introduce any public API change?
    - [ ] Does this PR introduce any binary protocol compatibility change?
    
    ## Benchmark
    
    ```
    fury_tuple: Mean +- std dev: [base] 259 us +- 6 us -> [collection] 256 us +- 5 us: 1.01x faster
    fury_large_tuple: Mean +- std dev: [base] 92.7 ms +- 5.5 ms -> [collection] 63.7 ms +- 4.8 ms: 1.46x faster
    fury_list: Mean +- std dev: [base] 277 us +- 6 us -> [collection] 267 us +- 3 us: 1.04x faster
    fury_large_list: Mean +- std dev: [base] 92.8 ms +- 5.3 ms -> [collection] 66.5 ms +- 3.0 ms: 1.40x faster
    
    Geometric mean: 1.21x faster
    ```
---
 python/pyfury/_serialization.pyx | 301 ++++++++++++++++++++++++++++++++-------
 1 file changed, 253 insertions(+), 48 deletions(-)

diff --git a/python/pyfury/_serialization.pyx b/python/pyfury/_serialization.pyx
index 1e2280ac..0175f615 100644
--- a/python/pyfury/_serialization.pyx
+++ b/python/pyfury/_serialization.pyx
@@ -69,6 +69,14 @@ logger = logging.getLogger(__name__)
 ENABLE_FURY_CYTHON_SERIALIZATION = os.environ.get(
     "ENABLE_FURY_CYTHON_SERIALIZATION", "True").lower() in ("true", "1")
 
+cdef extern from *:
+    """
+    #define int2obj(obj_addr) ((PyObject *)(obj_addr))
+    #define obj2int(obj_ref) (Py_INCREF(obj_ref), ((int64_t)(obj_ref)))
+    """
+    object int2obj(int64_t obj_addr)
+    int64_t obj2int(object obj_ref)
+
 
 cdef int8_t NULL_FLAG = -3
 # This flag indicates that object is a not-null value.
@@ -1630,6 +1638,18 @@ cdef class BytesSerializer(CrossLanguageCompatibleSerializer):
         return fury_buf.to_pybytes()
 
 
+"""
+Collection serialization format:
+https://fury.apache.org/docs/specification/fury_xlang_serialization_spec/#list
+Has the following changes:
+* None has an independent NonType type, so COLLECTION_NOT_SAME_TYPE can also 
cover the concept of being nullable.
+* No flag is needed to indicate that the element type is not the declared type.
+"""
+cdef int8_t COLLECTION_DEFAULT_FLAG = 0b0
+cdef int8_t COLLECTION_TRACKING_REF = 0b1
+cdef int8_t COLLECTION_NOT_SAME_TYPE = 0b1000
+
+
 cdef class CollectionSerializer(Serializer):
     cdef ClassResolver class_resolver
     cdef MapRefResolver ref_resolver
@@ -1644,29 +1664,143 @@ cdef class CollectionSerializer(Serializer):
     cpdef int16_t get_xtype_id(self):
         return -FuryType.LIST.value
 
+    cdef pair[int8_t, int64_t] write_header(self, Buffer buffer, value):
+        cdef int8_t collect_flag = COLLECTION_DEFAULT_FLAG
+        elem_type = type(next(iter(value)))
+        for s in value:
+            if type(s) is not elem_type:
+                collect_flag |= COLLECTION_NOT_SAME_TYPE
+                break
+        if self.fury.ref_tracking:
+            collect_flag |= COLLECTION_TRACKING_REF
+        buffer.write_varint64((len(value) << 4) | collect_flag)
+        return pair[int8_t, int64_t](collect_flag, obj2int(elem_type))
+
     cpdef write(self, Buffer buffer, value):
-        buffer.write_varint32(len(value))
+        if len(value) == 0:
+            buffer.write_varint64(0)
+            return
+        cdef pair[int8_t, int64_t] header_pair = self.write_header(buffer, 
value)
+        cdef int8_t collect_flag = header_pair.first
+        cdef int64_t elem_type_ptr = header_pair.second
+        cdef elem_type = <type>int2obj(elem_type_ptr)
         cdef MapRefResolver ref_resolver = self.ref_resolver
         cdef ClassResolver class_resolver = self.class_resolver
+        if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
+            if elem_type is str:
+                self._write_string(buffer, value)
+            elif elem_type is int:
+                self._write_int(buffer, value)
+            elif elem_type is bool:
+                self._write_bool(buffer, value)
+            elif elem_type is float:
+                self._write_float(buffer, value)
+            else:
+                if (collect_flag & COLLECTION_TRACKING_REF) == 0:
+                    self._write_same_type_no_ref(buffer, value, elem_type)
+                else:
+                    self._write_same_type_ref(buffer, value, elem_type)
+        else:
+            for s in value:
+                cls = type(s)
+                if cls is str:
+                    buffer.write_int16(NOT_NULL_STRING_FLAG)
+                    buffer.write_string(s)
+                elif cls is int:
+                    buffer.write_int16(NOT_NULL_PYINT_FLAG)
+                    buffer.write_varint64(s)
+                elif cls is bool:
+                    buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
+                    buffer.write_bool(s)
+                elif cls is float:
+                    buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
+                    buffer.write_double(s)
+                else:
+                    if not ref_resolver.write_ref_or_null(buffer, s):
+                        classinfo = class_resolver.get_or_create_classinfo(cls)
+                        class_resolver.write_classinfo(buffer, classinfo)
+                        classinfo.serializer.write(buffer, s)
+
+    cdef inline _write_string(self, Buffer buffer, value):
+        buffer.write_int16(NOT_NULL_STRING_FLAG)
         for s in value:
-            cls = type(s)
-            if cls is str:
-                buffer.write_int16(NOT_NULL_STRING_FLAG)
-                buffer.write_string(s)
-            elif cls is int:
-                buffer.write_int16(NOT_NULL_PYINT_FLAG)
-                buffer.write_varint64(s)
-            elif cls is bool:
-                buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
-                buffer.write_bool(s)
-            elif cls is float:
-                buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
-                buffer.write_double(s)
+            buffer.write_string(s)
+
+    cdef inline _read_string(self, Buffer buffer, int64_t len_, object 
collection_):
+        assert buffer.read_int16() == NOT_NULL_STRING_FLAG
+        for i in range(len_):
+            self._add_element(collection_, i, buffer.read_string())
+
+    cdef inline _write_int(self, Buffer buffer, value):
+        buffer.write_int16(NOT_NULL_PYINT_FLAG)
+        for s in value:
+            buffer.write_varint64(s)
+
+    cdef inline _read_int(self, Buffer buffer, int64_t len_, object 
collection_):
+        assert buffer.read_int16() == NOT_NULL_PYINT_FLAG
+        for i in range(len_):
+            self._add_element(collection_, i, buffer.read_varint64())
+
+    cdef inline _write_bool(self, Buffer buffer, value):
+        buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
+        for s in value:
+            buffer.write_bool(s)
+
+    cdef inline _read_bool(self, Buffer buffer, int64_t len_, object 
collection_):
+        assert buffer.read_int16() == NOT_NULL_PYBOOL_FLAG
+        for i in range(len_):
+            self._add_element(collection_, i, buffer.read_bool())
+
+    cdef inline _write_float(self, Buffer buffer, value):
+        buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
+        for s in value:
+            buffer.write_double(s)
+
+    cdef inline _read_float(self, Buffer buffer, int64_t len_, object 
collection_):
+        assert buffer.read_int16() == NOT_NULL_PYFLOAT_FLAG
+        for i in range(len_):
+            self._add_element(collection_, i, buffer.read_double())
+
+    cpdef _write_same_type_no_ref(self, Buffer buffer, value, elem_type):
+        cdef MapRefResolver ref_resolver = self.ref_resolver
+        cdef ClassResolver class_resolver = self.class_resolver
+        classinfo = class_resolver.get_or_create_classinfo(elem_type)
+        class_resolver.write_classinfo(buffer, classinfo)
+        for s in value:
+            classinfo.serializer.write(buffer, s)
+
+    cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object 
collection_):
+        cdef MapRefResolver ref_resolver = self.ref_resolver
+        cdef ClassResolver class_resolver = self.class_resolver
+        classinfo = class_resolver.read_classinfo(buffer)
+        for i in range(len_):
+            obj = classinfo.serializer.read(buffer)
+            self._add_element(collection_, i, obj)
+
+    cpdef _write_same_type_ref(self, Buffer buffer, value, elem_type):
+        cdef MapRefResolver ref_resolver = self.ref_resolver
+        cdef ClassResolver class_resolver = self.class_resolver
+        classinfo = class_resolver.get_or_create_classinfo(elem_type)
+        class_resolver.write_classinfo(buffer, classinfo)
+        for s in value:
+            if not ref_resolver.write_ref_or_null(buffer, s):
+                classinfo.serializer.write(buffer, s)
+
+    cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object 
collection_):
+        cdef MapRefResolver ref_resolver = self.ref_resolver
+        cdef ClassResolver class_resolver = self.class_resolver
+        classinfo = class_resolver.read_classinfo(buffer)
+        for i in range(len_):
+            ref_id = ref_resolver.try_preserve_ref_id(buffer)
+            if ref_id < NOT_NULL_VALUE_FLAG:
+                obj = ref_resolver.get_read_object()
             else:
-                if not ref_resolver.write_ref_or_null(buffer, s):
-                    classinfo = class_resolver.get_or_create_classinfo(cls)
-                    class_resolver.write_classinfo(buffer, classinfo)
-                    classinfo.serializer.write(buffer, s)
+                obj = classinfo.serializer.read(buffer)
+                ref_resolver.set_read_object(ref_id, obj)
+            self._add_element(collection_, i, obj)
+
+    cpdef _add_element(self, object collection_, int64_t index, object 
element):
+        raise NotImplementedError
 
     cpdef xwrite(self, Buffer buffer, value):
         cdef int32_t len_ = 0
@@ -1690,15 +1824,39 @@ cdef class ListSerializer(CollectionSerializer):
     cpdef read(self, Buffer buffer):
         cdef MapRefResolver ref_resolver = self.fury.ref_resolver
         cdef ClassResolver class_resolver = self.fury.class_resolver
-        cdef int32_t len_ = buffer.read_varint32()
+        cdef int64_t len_and_flag = buffer.read_varint64()
+        cdef int64_t len_ = len_and_flag >> 4
+        cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
         cdef list list_ = PyList_New(len_)
         ref_resolver.reference(list_)
-        for i in range(len_):
-            elem = get_next_elenment(buffer, ref_resolver, class_resolver)
-            Py_INCREF(elem)
-            PyList_SET_ITEM(list_, i, elem)
+        if len_ == 0:
+            return list_
+        if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
+            type_flag = buffer.get_int16(buffer.reader_index)
+            if type_flag == NOT_NULL_STRING_FLAG:
+                self._read_string(buffer, len_, list_)
+            elif type_flag == NOT_NULL_PYINT_FLAG:
+                self._read_int(buffer, len_, list_)
+            elif type_flag == NOT_NULL_PYBOOL_FLAG:
+                self._read_bool(buffer, len_, list_)
+            elif type_flag == NOT_NULL_PYFLOAT_FLAG:
+                self._read_float(buffer, len_, list_)
+            else:
+                if (collect_flag & COLLECTION_TRACKING_REF) == 0:
+                    self._read_same_type_no_ref(buffer, len_, list_)
+                else:
+                    self._read_same_type_ref(buffer, len_, list_)
+        else:
+            for i in range(len_):
+                elem = get_next_elenment(buffer, ref_resolver, class_resolver)
+                Py_INCREF(elem)
+                PyList_SET_ITEM(list_, i, elem)
         return list_
 
+    cpdef _add_element(self, object collection_, int64_t index, object 
element):
+        Py_INCREF(element)
+        PyList_SET_ITEM(collection_, index, element)
+
     cpdef xread(self, Buffer buffer):
         cdef int32_t len_ = buffer.read_varint32()
         cdef list collection_ = PyList_New(len_)
@@ -1746,14 +1904,38 @@ cdef class TupleSerializer(CollectionSerializer):
     cpdef inline read(self, Buffer buffer):
         cdef MapRefResolver ref_resolver = self.fury.ref_resolver
         cdef ClassResolver class_resolver = self.fury.class_resolver
-        cdef int32_t len_ = buffer.read_varint32()
+        cdef int64_t len_and_flag = buffer.read_varint64()
+        cdef int64_t len_ = len_and_flag >> 4
+        cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
         cdef tuple tuple_ = PyTuple_New(len_)
-        for i in range(len_):
-            elem = get_next_elenment(buffer, ref_resolver, class_resolver)
-            Py_INCREF(elem)
-            PyTuple_SET_ITEM(tuple_, i, elem)
+        if len_ == 0:
+            return tuple_
+        if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
+            type_flag = buffer.get_int16(buffer.reader_index)
+            if type_flag == NOT_NULL_STRING_FLAG:
+                self._read_string(buffer, len_, tuple_)
+            elif type_flag == NOT_NULL_PYINT_FLAG:
+                self._read_int(buffer, len_, tuple_)
+            elif type_flag == NOT_NULL_PYBOOL_FLAG:
+                self._read_bool(buffer, len_, tuple_)
+            elif type_flag == NOT_NULL_PYFLOAT_FLAG:
+                self._read_float(buffer, len_, tuple_)
+            else:
+                if (collect_flag & COLLECTION_TRACKING_REF) == 0:
+                    self._read_same_type_no_ref(buffer, len_, tuple_)
+                else:
+                    self._read_same_type_ref(buffer, len_, tuple_)
+        else:
+            for i in range(len_):
+                elem = get_next_elenment(buffer, ref_resolver, class_resolver)
+                Py_INCREF(elem)
+                PyTuple_SET_ITEM(tuple_, i, elem)
         return tuple_
 
+    cpdef inline _add_element(self, object collection_, int64_t index, object 
element):
+        Py_INCREF(element)
+        PyTuple_SET_ITEM(collection_, index, element)
+
     cpdef inline xread(self, Buffer buffer):
         cdef int32_t len_ = buffer.read_varint32()
         cdef tuple tuple_ = PyTuple_New(len_)
@@ -1785,31 +1967,54 @@ cdef class SetSerializer(CollectionSerializer):
         cdef ClassResolver class_resolver = self.fury.class_resolver
         cdef set instance = set()
         ref_resolver.reference(instance)
-        cdef int32_t len_ = buffer.read_varint32()
+        cdef int64_t len_and_flag = buffer.read_varint64()
+        cdef int64_t len_ = len_and_flag >> 4
+        cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
         cdef int32_t ref_id
         cdef ClassInfo classinfo
-        for i in range(len_):
-            ref_id = ref_resolver.try_preserve_ref_id(buffer)
-            if ref_id < NOT_NULL_VALUE_FLAG:
-                instance.add(ref_resolver.get_read_object())
-                continue
-            # indicates that the object is first read.
-            classinfo = class_resolver.read_classinfo(buffer)
-            cls = classinfo.cls
-            if cls is str:
-                instance.add(buffer.read_string())
-            elif cls is int:
-                instance.add(buffer.read_varint64())
-            elif cls is bool:
-                instance.add(buffer.read_bool())
-            elif cls is float:
-                instance.add(buffer.read_double())
+        if len_ == 0:
+            return instance
+        if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
+            type_flag = buffer.get_int16(buffer.reader_index)
+            if type_flag == NOT_NULL_STRING_FLAG:
+                self._read_string(buffer, len_, instance)
+            elif type_flag == NOT_NULL_PYINT_FLAG:
+                self._read_int(buffer, len_, instance)
+            elif type_flag == NOT_NULL_PYBOOL_FLAG:
+                self._read_bool(buffer, len_, instance)
+            elif type_flag == NOT_NULL_PYFLOAT_FLAG:
+                self._read_float(buffer, len_, instance)
             else:
-                o = classinfo.serializer.read(buffer)
-                ref_resolver.set_read_object(ref_id, o)
-                instance.add(o)
+                if (collect_flag & COLLECTION_TRACKING_REF) == 0:
+                    self._read_same_type_no_ref(buffer, len_, instance)
+                else:
+                    self._read_same_type_ref(buffer, len_, instance)
+        else:
+            for i in range(len_):
+                ref_id = ref_resolver.try_preserve_ref_id(buffer)
+                if ref_id < NOT_NULL_VALUE_FLAG:
+                    instance.add(ref_resolver.get_read_object())
+                    continue
+                # indicates that the object is first read.
+                classinfo = class_resolver.read_classinfo(buffer)
+                cls = classinfo.cls
+                if cls is str:
+                    instance.add(buffer.read_string())
+                elif cls is int:
+                    instance.add(buffer.read_varint64())
+                elif cls is bool:
+                    instance.add(buffer.read_bool())
+                elif cls is float:
+                    instance.add(buffer.read_double())
+                else:
+                    o = classinfo.serializer.read(buffer)
+                    ref_resolver.set_read_object(ref_id, o)
+                    instance.add(o)
         return instance
 
+    cpdef inline _add_element(self, object collection_, int64_t index, object 
element):
+        collection_.add(element)
+
     cpdef inline xread(self, Buffer buffer):
         cdef int32_t len_ = buffer.read_varint32()
         cdef set instance = set()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to