This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch releases-0.12
in repository https://gitbox.apache.org/repos/asf/fory.git

commit f693964ad72f26990ac36b986897c4dfa725c3c5
Author: chaokunyang <[email protected]>
AuthorDate: Wed Sep 17 23:24:08 2025 +0800

    limit pyfory depth
---
 python/pyfory/_fory.py           | 28 +++++++++++++++++--
 python/pyfory/_serialization.pyx | 50 ++++++++++++++++++++++++++++++++--
 python/pyfory/_serializer.py     | 59 ++++++++++++++--------------------------
 3 files changed, 94 insertions(+), 43 deletions(-)

diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py
index 6bc87ba1a..76ef34e72 100644
--- a/python/pyfory/_fory.py
+++ b/python/pyfory/_fory.py
@@ -112,6 +112,8 @@ class Fory:
         "_unsupported_callback",
         "_unsupported_objects",
         "_peer_language",
+        "max_depth",
+        "depth",
     )
     serialization_context: "SerializationContext"
 
@@ -120,6 +122,7 @@ class Fory:
         language=Language.PYTHON,
         ref_tracking: bool = False,
         require_type_registration: bool = True,
+        max_depth: int = 50,
     ):
         """
         :param require_type_registration:
@@ -130,6 +133,10 @@ class Fory:
           Do not disable type registration if you can't ensure your environment are
           *indeed secure*. We are not responsible for security risks if
           you disable this option.
+        :param max_depth:
+         The maximum depth of the deserialization data.
+         If the depth exceeds the maximum depth, an exception will be raised.
+         The default value is 50.
         """
         self.language = language
         self.is_py = language == Language.PYTHON
@@ -163,6 +170,8 @@ class Fory:
         self._unsupported_callback = None
         self._unsupported_objects = None
         self._peer_language = None
+        self.max_depth = max_depth
+        self.depth = 0
 
     def register(
         self,
@@ -381,7 +390,11 @@ class Fory:
         # indicates that the object is first read.
         if ref_id >= NOT_NULL_VALUE_FLAG:
             typeinfo = self.type_resolver.read_typeinfo(buffer)
+            self.depth += 1
+            if self.depth > self.max_depth:
+                self.throw_depth_limit_exceeded_exception()
             o = typeinfo.serializer.read(buffer)
+            self.depth -= 1
             ref_resolver.set_read_object(ref_id, o)
             return o
         else:
@@ -390,7 +403,12 @@ class Fory:
     def deserialize_nonref(self, buffer):
         """Deserialize not-null and non-reference object from buffer."""
         typeinfo = self.type_resolver.read_typeinfo(buffer)
-        return typeinfo.serializer.read(buffer)
+        self.depth += 1
+        if self.depth > self.max_depth:
+            self.throw_depth_limit_exceeded_exception()
+        o = typeinfo.serializer.read(buffer)
+        self.depth -= 1
+        return o
 
     def xdeserialize_ref(self, buffer, serializer=None):
         if serializer is None or serializer.need_to_write_ref:
@@ -411,7 +429,12 @@ class Fory:
     def xdeserialize_nonref(self, buffer, serializer=None):
         if serializer is None:
             serializer = self.type_resolver.read_typeinfo(buffer).serializer
-        return serializer.xread(buffer)
+        self.depth += 1
+        if self.depth > self.max_depth:
+            self.throw_depth_limit_exceeded_exception()
+        o = serializer.xread(buffer)
+        self.depth -= 1
+        return o
 
     def write_buffer_object(self, buffer, buffer_object: BufferObject):
         if self._buffer_callback is None or 
self._buffer_callback(buffer_object):
@@ -477,6 +500,7 @@ class Fory:
         self._unsupported_callback = None
 
     def reset_read(self):
+        self.depth = 0
         self.ref_resolver.reset_read()
         self.type_resolver.reset_read()
         self.serialization_context.reset()
diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx
index d98a248d5..833fbed24 100644
--- a/python/pyfory/_serialization.pyx
+++ b/python/pyfory/_serialization.pyx
@@ -599,12 +599,15 @@ cdef class Fory:
     cdef object _unsupported_callback
     cdef object _unsupported_objects  # iterator
     cdef object _peer_language
+    cdef int32_t max_depth
+    cdef int32_t depth
 
     def __init__(
             self,
             language=Language.PYTHON,
             ref_tracking: bool = False,
             require_type_registration: bool = True,
+            max_depth: int = 50,
     ):
         """
        :param require_type_registration:
@@ -615,6 +618,10 @@ cdef class Fory:
          Do not disable type registration if you can't ensure your environment are
          *indeed secure*. We are not responsible for security risks if
          you disable this option.
+       :param max_depth:
+        The maximum depth of the deserialization data.
+        If the depth exceeds the maximum depth, an exception will be raised.
+        The default value is 50.
        """
         self.language = language
         if _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration:
@@ -646,6 +653,8 @@ cdef class Fory:
         self._unsupported_callback = None
         self._unsupported_objects = None
         self._peer_language = None
+        self.max_depth = max_depth
+        self.depth = 0
 
     def register_serializer(self, cls: Union[type, TypeVar], Serializer 
serializer):
         self.type_resolver.register_serializer(cls, serializer)
@@ -881,7 +890,9 @@ cdef class Fory:
             return buffer.read_bool()
         elif cls is float:
             return buffer.read_double()
+        self.inc_depth()
         o = typeinfo.serializer.read(buffer)
+        self.depth -= 1
         ref_resolver.set_read_object(ref_id, o)
         return o
 
@@ -897,7 +908,10 @@ cdef class Fory:
             return buffer.read_bool()
         elif cls is float:
             return buffer.read_double()
-        return typeinfo.serializer.read(buffer)
+        self.inc_depth()
+        o = typeinfo.serializer.read(buffer)
+        self.depth -= 1
+        return o
 
     cpdef inline xdeserialize_ref(self, Buffer buffer, Serializer 
serializer=None):
         cdef MapRefResolver ref_resolver
@@ -925,7 +939,25 @@ cdef class Fory:
             self, Buffer buffer, Serializer serializer=None):
         if serializer is None:
             serializer = self.type_resolver.read_typeinfo(buffer).serializer
-        return serializer.xread(buffer)
+        # inc_depth() bumps the depth and enforces max_depth in one step.
+        self.inc_depth()
+        o = serializer.xread(buffer)
+        self.depth -= 1
+        return o
+
+    cpdef inline inc_depth(self):
+        self.depth += 1
+        if self.depth > self.max_depth:
+            self.throw_depth_limit_exceeded_exception()
+
+    cpdef inline dec_depth(self):
+        self.depth -= 1
+
+    cpdef inline throw_depth_limit_exceeded_exception(self):
+        raise Exception(
+            f"Read depth exceed max depth: {self.depth}, the deserialization 
data may be malicious. If it's not malicious, "
+            "please increase max read depth by Fory(..., max_depth=...)"
+        )
 
     cpdef inline write_buffer_object(self, Buffer buffer, buffer_object):
         if self._buffer_callback is not None and 
self._buffer_callback(buffer_object):
@@ -997,6 +1029,7 @@ cdef class Fory:
         self._unsupported_callback = None
 
     cpdef inline reset_read(self):
+        self.depth = 0
         self.ref_resolver.reset_read()
         self.type_resolver.reset_read()
         self.metastring_resolver.reset_read()
@@ -1447,6 +1480,7 @@ cdef class CollectionSerializer(Serializer):
     cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object 
collection_, TypeInfo typeinfo):
         cdef MapRefResolver ref_resolver = self.ref_resolver
         cdef TypeResolver type_resolver = self.type_resolver
+        self.fory.inc_depth()
         if self.is_py:
             for i in range(len_):
                 obj = typeinfo.serializer.read(buffer)
@@ -1455,6 +1489,7 @@ cdef class CollectionSerializer(Serializer):
             for i in range(len_):
                 obj = typeinfo.serializer.xread(buffer)
                 self._add_element(collection_, i, obj)
+        self.fory.dec_depth()
 
     cpdef _write_same_type_ref(self, Buffer buffer, value, TypeInfo typeinfo):
         cdef MapRefResolver ref_resolver = self.ref_resolver
@@ -1472,6 +1507,7 @@ cdef class CollectionSerializer(Serializer):
         cdef MapRefResolver ref_resolver = self.ref_resolver
         cdef TypeResolver type_resolver = self.type_resolver
         cdef c_bool is_py = self.is_py
+        self.fory.inc_depth()
         for i in range(len_):
             ref_id = ref_resolver.try_preserve_ref_id(buffer)
             if ref_id < NOT_NULL_VALUE_FLAG:
@@ -1483,6 +1519,7 @@ cdef class CollectionSerializer(Serializer):
                     obj = typeinfo.serializer.xread(buffer)
                 ref_resolver.set_read_object(ref_id, obj)
             self._add_element(collection_, i, obj)
+        self.fory.dec_depth()
 
     cpdef _add_element(self, object collection_, int64_t index, object 
element):
         raise NotImplementedError
@@ -1527,10 +1564,12 @@ cdef class ListSerializer(CollectionSerializer):
             else:
                 self._read_same_type_ref(buffer, len_, list_, typeinfo)
         else:
+            self.fory.inc_depth()
             for i in range(len_):
                 elem = get_next_element(buffer, ref_resolver, type_resolver, 
is_py)
                 Py_INCREF(elem)
                 PyList_SET_ITEM(list_, i, elem)
+            self.fory.dec_depth()
         return list_
 
     cpdef _add_element(self, object collection_, int64_t index, object 
element):
@@ -1670,6 +1709,7 @@ cdef class SetSerializer(CollectionSerializer):
             else:
                 self._read_same_type_ref(buffer, len_, instance, typeinfo)
         else:
+            self.fory.inc_depth()
             for i in range(len_):
                 ref_id = ref_resolver.try_preserve_ref_id(buffer)
                 if ref_id < NOT_NULL_VALUE_FLAG:
@@ -1693,6 +1733,7 @@ cdef class SetSerializer(CollectionSerializer):
                         o = typeinfo.serializer.xread(buffer)
                     ref_resolver.set_read_object(ref_id, o)
                     instance.add(o)
+            self.fory.dec_depth()
         return instance
 
     cpdef inline _add_element(self, object collection_, int64_t index, object 
element):
@@ -1927,6 +1968,7 @@ cdef class MapSerializer(Serializer):
         cdef type key_serializer_type, value_serializer_type
         cdef int32_t chunk_size
         cdef c_bool is_py = self.is_py
+        self.fory.inc_depth()
         while size > 0:
             while True:
                 key_has_null = (chunk_header & KEY_HAS_NULL) != 0
@@ -1982,6 +2024,7 @@ cdef class MapSerializer(Serializer):
                         map_[None] = None
                 size -= 1
                 if size == 0:
+                    self.fory.dec_depth()
                     return map_
                 else:
                     chunk_header = buffer.read_uint8()
@@ -2055,6 +2098,7 @@ cdef class MapSerializer(Serializer):
                 size -= 1
             if size != 0:
                 chunk_header = buffer.read_uint8()
+        self.fory.dec_depth()
         return map_
 
     cpdef inline xwrite(self, Buffer buffer, o):
@@ -2121,6 +2165,7 @@ cdef class SubMapSerializer(Serializer):
         cdef int32_t ref_id
         cdef TypeInfo key_typeinfo
         cdef TypeInfo value_typeinfo
+        self.fory.inc_depth()
         for i in range(len_):
             ref_id = ref_resolver.try_preserve_ref_id(buffer)
             if ref_id < NOT_NULL_VALUE_FLAG:
@@ -2150,6 +2195,7 @@ cdef class SubMapSerializer(Serializer):
                     value = value_typeinfo.serializer.read(buffer)
                     ref_resolver.set_read_object(ref_id, value)
             map_[key] = value
+        self.fory.dec_depth()
         return map_
 
 
diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py
index dacdae670..77691c6a8 100644
--- a/python/pyfory/_serializer.py
+++ b/python/pyfory/_serializer.py
@@ -53,15 +53,11 @@ KV_NULL = KEY_HAS_NULL | VALUE_HAS_NULL
 # Key is null, value type is declared type, and ref tracking for value is disabled.
 NULL_KEY_VALUE_DECL_TYPE = KEY_HAS_NULL | VALUE_DECL_TYPE
 # Key is null, value type is declared type, and ref tracking for value is enabled.
-NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = (
-    KEY_HAS_NULL | VALUE_DECL_TYPE | TRACKING_VALUE_REF
-)
+NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = KEY_HAS_NULL | VALUE_DECL_TYPE | 
TRACKING_VALUE_REF
 # Value is null, key type is declared type, and ref tracking for key is disabled.
 NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE
 # Value is null, key type is declared type, and ref tracking for key is enabled.
-NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = (
-    VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_VALUE_REF
-)
+NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE | 
TRACKING_VALUE_REF
 
 
 class Serializer(ABC):
@@ -182,11 +178,7 @@ _base_date = datetime.date(1970, 1, 1)
 class DateSerializer(CrossLanguageCompatibleSerializer):
     def write(self, buffer, value: datetime.date):
         if not isinstance(value, datetime.date):
-            raise TypeError(
-                "{} should be {} instead of {}".format(
-                    value, datetime.date, type(value)
-                )
-            )
+            raise TypeError("{} should be {} instead of {}".format(value, 
datetime.date, type(value)))
         days = (value - _base_date).days
         buffer.write_int32(days)
 
@@ -208,9 +200,7 @@ class 
TimestampSerializer(CrossLanguageCompatibleSerializer):
 
     def write(self, buffer, value: datetime.datetime):
         if not isinstance(value, datetime.datetime):
-            raise TypeError(
-                "{} should be {} instead of {}".format(value, datetime, 
type(value))
-            )
+            raise TypeError("{} should be {} instead of {}".format(value, 
datetime, type(value)))
         # TimestampType represent micro seconds
         buffer.write_int64(self._get_timestamp(value))
 
@@ -287,10 +277,7 @@ class CollectionSerializer(Serializer):
                     collect_flag |= COLLECTION_TRACKING_REF
         buffer.write_varuint32(len(value))
         buffer.write_int8(collect_flag)
-        if (
-            not has_different_type
-            and (collect_flag & COLLECTION_NOT_DECL_ELEMENT_TYPE) != 0
-        ):
+        if not has_different_type and (collect_flag & 
COLLECTION_NOT_DECL_ELEMENT_TYPE) != 0:
             self.type_resolver.write_typeinfo(buffer, elem_typeinfo)
         return collect_flag, elem_typeinfo
 
@@ -361,14 +348,17 @@ class CollectionSerializer(Serializer):
         raise NotImplementedError
 
     def _read_same_type_no_ref(self, buffer, len_, collection_, typeinfo):
+        self.fory.inc_depth()
         if self.is_py:
             for _ in range(len_):
                 self._add_element(collection_, 
typeinfo.serializer.read(buffer))
         else:
             for _ in range(len_):
                 self._add_element(collection_, 
typeinfo.serializer.xread(buffer))
+        self.fory.dec_depth()
 
     def _read_same_type_ref(self, buffer, len_, collection_, typeinfo):
+        self.fory.inc_depth()
         for _ in range(len_):
             ref_id = self.ref_resolver.try_preserve_ref_id(buffer)
             if ref_id < NOT_NULL_VALUE_FLAG:
@@ -380,15 +370,16 @@ class CollectionSerializer(Serializer):
                     obj = typeinfo.serializer.xread(buffer)
                 self.ref_resolver.set_read_object(ref_id, obj)
             self._add_element(collection_, obj)
+        self.fory.dec_depth()
 
     def _read_different_types(self, buffer, len_, collection_):
+        self.fory.inc_depth()
         for _ in range(len_):
             self._add_element(
                 collection_,
-                get_next_element(
-                    buffer, self.ref_resolver, self.type_resolver, self.is_py
-                ),
+                get_next_element(buffer, self.ref_resolver, 
self.type_resolver, self.is_py),
             )
+        self.fory.dec_depth()
 
     def xwrite(self, buffer, value):
         self.write(buffer, value)
@@ -532,12 +523,8 @@ class MapSerializer(Serializer):
                 type_resolver.write_typeinfo(buffer, value_typeinfo)
                 value_serializer = value_typeinfo.serializer
 
-            key_write_ref = (
-                key_serializer.need_to_write_ref if key_serializer else False
-            )
-            value_write_ref = (
-                value_serializer.need_to_write_ref if value_serializer else 
False
-            )
+            key_write_ref = key_serializer.need_to_write_ref if key_serializer 
else False
+            value_write_ref = value_serializer.need_to_write_ref if 
value_serializer else False
             if key_write_ref:
                 chunk_header |= TRACKING_KEY_REF
             if value_write_ref:
@@ -547,18 +534,11 @@ class MapSerializer(Serializer):
             chunk_size = 0
 
             while chunk_size < MAX_CHUNK_SIZE:
-                if (
-                    key is None
-                    or value is None
-                    or type(key) is not key_cls
-                    or type(value) is not value_cls
-                ):
+                if key is None or value is None or type(key) is not key_cls or 
type(value) is not value_cls:
                     break
                 if not key_write_ref or not 
ref_resolver.write_ref_or_null(buffer, key):
                     self._write_obj(key_serializer, buffer, key)
-                if not value_write_ref or not ref_resolver.write_ref_or_null(
-                    buffer, value
-                ):
+                if not value_write_ref or not 
ref_resolver.write_ref_or_null(buffer, value):
                     value_serializer.write(buffer, value)
 
                 chunk_size += 1
@@ -583,9 +563,8 @@ class MapSerializer(Serializer):
         if size != 0:
             chunk_header = buffer.read_uint8()
         key_serializer, value_serializer = self.key_serializer, 
self.value_serializer
-        deserialize_ref = (
-            fory.deserialize_ref if self.fory.is_py else fory.xdeserialize_ref
-        )
+        deserialize_ref = fory.deserialize_ref if self.fory.is_py else 
fory.xdeserialize_ref
+        fory.inc_depth()
         while size > 0:
             while True:
                 key_has_null = (chunk_header & KEY_HAS_NULL) != 0
@@ -626,6 +605,7 @@ class MapSerializer(Serializer):
                         map_[None] = None
                 size -= 1
                 if size == 0:
+                    fory.dec_depth()
                     return map_
                 else:
                     chunk_header = buffer.read_uint8()
@@ -662,6 +642,7 @@ class MapSerializer(Serializer):
                 size -= 1
             if size != 0:
                 chunk_header = buffer.read_uint8()
+        fory.dec_depth()
         return map_
 
     def _write_obj(self, serializer, buffer, obj):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to