This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 1dc733a64 feat(python): add deserialization policy for more fine-grained control and audit deserialization behaviour (#2811)
1dc733a64 is described below

commit 1dc733a647e119e699004b9bf0727211d9f48e9e
Author: Shawn Yang <[email protected]>
AuthorDate: Wed Oct 22 23:14:55 2025 +0800

    feat(python): add deserialization policy for more fine-grained control and audit deserialization behaviour (#2811)
    
    ## Why?
    
    ## What does this PR do?
    
    Add a `DeserializationPolicy` for more fine-grained control and auditing of
    deserialization behaviour.
    
    ## Related issues
    
    ## Does this PR introduce any user-facing change?
    
    - [ ] Does this PR introduce any public API change?
    - [ ] Does this PR introduce any binary protocol compatibility change?
    
    ## Benchmark
    
---
 python/README.md                   |  63 +++-
 python/pyfory/__init__.py          |   1 +
 python/pyfory/_fory.py             |   8 +
 python/pyfory/_serialization.pyx   |  99 +------
 python/pyfory/policy.py            | 580 +++++++++++++++++++++++++++++++++++++
 python/pyfory/serializer.py        |  83 ++++--
 python/pyfory/tests/test_policy.py | 261 +++++++++++++++++
 7 files changed, 977 insertions(+), 118 deletions(-)
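
Note: for reference, the standard-library baseline that the README and the `policy.py` docstring compare `DeserializationPolicy` against is a `pickle.Unpickler` subclass that overrides `find_class()`. The minimal sketch below is adapted from the "Restricting Globals" example in the Python `pickle` documentation; `RestrictedUnpickler`, `restricted_loads`, `SAFE_BUILTINS`, and `Evil` are illustrative names, not part of pyfory. `find_class()` is consulted for every global the stream references (including a `__reduce__` callable), but it only ever sees the `(module, name)` pair, never the call arguments.

```python
import builtins
import io
import pickle

# Illustrative allow-list, as in the Python docs' "Restricting Globals" example.
SAFE_BUILTINS = {"range", "complex", "set", "frozenset", "slice"}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called for every global the stream references, including the callable
        # of a __reduce__ payload. Only (module, name) is visible here.
        if module == "builtins" and name in SAFE_BUILTINS:
            return getattr(builtins, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Like pickle.loads(), but resolves globals through RestrictedUnpickler."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


class Evil:
    # Pickles as a call to os.system("echo pwned") on load.
    def __reduce__(self):
        import os
        return (os.system, ("echo pwned",))


if __name__ == "__main__":
    print(restricted_loads(pickle.dumps([1, 2, range(3)])))  # prints [1, 2, range(0, 3)]
    try:
        restricted_loads(pickle.dumps(Evil()))
    except pickle.UnpicklingError as exc:
        print("blocked:", exc)
```

The module name reported for the blocked call is platform dependent: `os.system` is defined in `posix` on Unix-like systems and in `nt` on Windows.
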

diff --git a/python/README.md b/python/README.md
index f0bd62a32..3426a4545 100644
--- a/python/README.md
+++ b/python/README.md
@@ -133,7 +133,7 @@ print(result)  # Person(name='Bob', age=25, ...)
 - **For circular references**: Set `ref=True` to enable reference tracking
 - **For functions/classes**: Set `strict=False` to allow deserialization of dynamic types
 
-**⚠️ Security Warning**: When `strict=False`, Fory will deserialize arbitrary types, which can pose security risks if data comes from untrusted sources. Only use `strict=False` in controlled environments where you trust the data source completely.
+**⚠️ Security Warning**: When `strict=False`, Fory will deserialize arbitrary types, which can pose security risks if data comes from untrusted sources. Only use `strict=False` in controlled environments where you trust the data source completely. If you do need to use `strict=False`, configure a `DeserializationPolicy` by passing `policy=your_policy` when creating the `Fory` instance to control deserialization behavior.
 
 #### Common Usage
 
@@ -1141,6 +1141,67 @@ else:
         fory.register(model_class, type_id=100 + idx)
 ```
 
+### DeserializationPolicy
+
+When `strict=False` is necessary (e.g., deserializing functions/lambdas), use `DeserializationPolicy` to implement fine-grained security controls during deserialization. This provides protection similar to `pickle.Unpickler.find_class()` but with more comprehensive hooks.
+
+**Why use DeserializationPolicy?**
+
+- Block dangerous classes/modules (e.g., `subprocess.Popen`)
+- Intercept and validate `__reduce__` callables before invocation
+- Sanitize sensitive data during `__setstate__`
+- Replace or reject deserialized objects based on custom rules
+
+**Example: Blocking Dangerous Classes**
+
+```python
+import pyfory
+from pyfory import DeserializationPolicy
+
+dangerous_modules = {'subprocess', 'os', '__builtin__'}
+
+class SafeDeserializationPolicy(DeserializationPolicy):
+    """Block potentially dangerous classes during deserialization."""
+
+    def validate_class(self, cls, is_local, **kwargs):
+        # Block dangerous modules
+        if cls.__module__ in dangerous_modules:
+            raise ValueError(f"Blocked dangerous class: {cls.__module__}.{cls.__name__}")
+        return None
+
+    def intercept_reduce_call(self, callable_obj, args, **kwargs):
+        # Block specific callable invocations during __reduce__
+        if getattr(callable_obj, '__name__', "") == 'Popen':
+            raise ValueError("Blocked attempt to invoke subprocess.Popen")
+        return None
+
+    def intercept_setstate(self, obj, state, **kwargs):
+        # Sanitize sensitive data
+        if isinstance(state, dict) and 'password' in state:
+            state['password'] = '***REDACTED***'
+        return None
+
+# Create Fory with custom security policy
+policy = SafeDeserializationPolicy()
+fory = pyfory.Fory(xlang=False, ref=True, strict=False, policy=policy)
+
+# Now deserialization is protected by your custom policy
+data = fory.serialize(my_object)
+result = fory.deserialize(data)  # Policy hooks will be invoked
+```
+
+**Available Policy Hooks:**
+
+- `validate_class(cls, is_local)` - Validate/block class types during deserialization
+- `validate_module(module_name)` - Validate/block module imports
+- `validate_function(func, is_local)` / `validate_method(method, is_local)` - Validate/block function and method references
+- `intercept_reduce_call(callable_obj, args)` - Intercept `__reduce__` invocations
+- `inspect_reduced_object(obj)` - Inspect/replace objects created via `__reduce__`
+- `intercept_setstate(obj, state)` - Sanitize state before `__setstate__`
+- `authorize_instantiation(cls)` - Control class instantiation
+
+**See also:** `pyfory/policy.py` contains detailed documentation and examples for each hook.
+
 ## 🐛 Troubleshooting
 
 ### Common Issues
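
The validator and interceptor contracts documented in `policy.py` (return None to accept, return an object to replace, raise to reject) and the auditing use case can be sketched without pyfory installed. `AuditingPolicy`, `resolve_class`, and `invoke_reduce` are hypothetical: the policy mirrors the README hook signatures but does not subclass `pyfory.DeserializationPolicy`, so the sketch runs standalone, and the two helpers only illustrate how a caller might honor the documented return values. They are not pyfory's actual deserializer code.

```python
import subprocess
from typing import List


class AuditingPolicy:
    """Hypothetical stand-in; real code would subclass pyfory.DeserializationPolicy."""

    BLOCKED_MODULES = {"subprocess"}

    def __init__(self) -> None:
        self.audit_log: List[str] = []

    def validate_class(self, cls, is_local, **kwargs):
        self.audit_log.append(f"class {cls.__module__}.{cls.__qualname__}")
        if cls.__module__ in self.BLOCKED_MODULES:
            raise ValueError(f"Blocked class: {cls.__module__}.{cls.__name__}")
        return None  # accept the class unchanged

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        name = getattr(callable_obj, "__name__", repr(callable_obj))
        self.audit_log.append(f"reduce {name}{args!r}")
        return None  # let the original call proceed


def resolve_class(policy, cls, is_local=False):
    # Hypothetical caller: None keeps cls, any other value replaces it,
    # and an exception raised by the hook propagates to abort deserialization.
    replacement = policy.validate_class(cls, is_local=is_local)
    return cls if replacement is None else replacement


def invoke_reduce(policy, callable_obj, args):
    # Hypothetical caller: None means "invoke callable_obj(*args)";
    # any other value is used as the reconstructed object instead.
    result = policy.intercept_reduce_call(callable_obj, args)
    return callable_obj(*args) if result is None else result


if __name__ == "__main__":
    policy = AuditingPolicy()
    print(resolve_class(policy, dict))             # prints <class 'dict'>
    print(invoke_reduce(policy, complex, (1, 2)))  # prints (1+2j)
    try:
        resolve_class(policy, subprocess.Popen)
    except ValueError as exc:
        print(exc)  # prints Blocked class: subprocess.Popen
    print(policy.audit_log)
```

In real code the policy would subclass `pyfory.DeserializationPolicy` and be passed as `pyfory.Fory(xlang=False, ref=True, strict=False, policy=AuditingPolicy())`, as in the README example; the actual hook call sites live in pyfory's serializers and are not reproduced here.
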
diff --git a/python/pyfory/__init__.py b/python/pyfory/__init__.py
index bf070bea4..0ff316d60 100644
--- a/python/pyfory/__init__.py
+++ b/python/pyfory/__init__.py
@@ -54,6 +54,7 @@ from pyfory.type import (  # noqa: F401 # pylint: disable=unused-import
     Float64ArrayType,
     dataslots,
 )
+from pyfory.policy import DeserializationPolicy  # noqa: F401 # pylint: disable=unused-import
 from pyfory._util import Buffer  # noqa: F401 # pylint: disable=unused-import
 
 import warnings
diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py
index f283bc4ea..9398a3b78 100644
--- a/python/pyfory/_fory.py
+++ b/python/pyfory/_fory.py
@@ -30,6 +30,7 @@ from pyfory.resolver import (
 )
 from pyfory.util import is_little_endian, set_bit, get_bit, clear_bit
 from pyfory.type import TypeId
+from pyfory.policy import DeserializationPolicy, DEFAULT_POLICY
 
 try:
     import numpy as np
@@ -122,6 +123,7 @@ class Fory:
         "max_depth",
         "depth",
         "field_nullable",
+        "policy",
     )
 
     def __init__(
@@ -131,6 +133,7 @@ class Fory:
         strict: bool = True,
         compatible: bool = False,
         max_depth: int = 50,
+        policy: DeserializationPolicy = None,
         field_nullable: bool = False,
         **kwargs,
     ):
@@ -157,6 +160,10 @@ class Fory:
           Do not disable strict mode if you can't ensure your environment are
           *indeed secure*. We are not responsible for security risks if
           you disable this option.
+        :param policy:
+         A custom DeserializationPolicy for deserialization security checks.
+         If not None, it is used instead of the default policy to validate,
+         intercept, or reject types and objects during deserialization.
         :param compatible:
          Whether to enable compatible mode for cross-language serialization.
          When enabled, type forward/backward compatibility for dataclass fields will be enabled.
@@ -182,6 +189,7 @@ class Fory:
         if kwargs.get("require_type_registration") is not None:
             strict = kwargs.get("require_type_registration")
         self.strict = _ENABLE_TYPE_REGISTRATION_FORCIBLY or strict
+        self.policy = policy or DEFAULT_POLICY
         self.compatible = compatible
         self.field_nullable = field_nullable if self.is_py else False
         from pyfory._serialization import MetaStringResolver, SerializationContext
diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx
index 188c48320..9f246ff08 100644
--- a/python/pyfory/_serialization.pyx
+++ b/python/pyfory/_serialization.pyx
@@ -34,6 +34,7 @@ from pyfory._fory import _ENABLE_TYPE_REGISTRATION_FORCIBLY
 from pyfory.lib import mmh3
 from pyfory.meta.metastring import Encoding
 from pyfory.type import is_primitive_type
+from pyfory.policy import DeserializationPolicy, DEFAULT_POLICY
 from pyfory.util import is_little_endian
 from pyfory.includes.libserialization cimport \
     (TypeId, IsNamespacedType, IsTypeShareMeta, Fory_PyBooleanSequenceWriteToBuffer, Fory_PyFloatSequenceWriteToBuffer)
@@ -805,6 +806,7 @@ cdef class Fory:
     cdef readonly c_bool is_py
     cdef readonly c_bool compatible
     cdef readonly c_bool field_nullable
+    cdef readonly object policy
     cdef readonly MapRefResolver ref_resolver
     cdef readonly TypeResolver type_resolver
     cdef readonly MetaStringResolver metastring_resolver
@@ -823,6 +825,7 @@ cdef class Fory:
             xlang: bool = False,
             ref: bool = False,
             strict: bool = True,
+            policy: DeserializationPolicy = None,
             compatible: bool = False,
             max_depth: int = 50,
             field_nullable: bool = False,
@@ -851,6 +854,10 @@ cdef class Fory:
           Do not disable strict mode if you can't ensure your environment are
           *indeed secure*. We are not responsible for security risks if
           you disable this option.
+        :param policy:
+         A custom DeserializationPolicy for deserialization security checks.
+         If not None, it is used instead of the default policy to validate,
+         intercept, or reject types and objects during deserialization.
         :param compatible:
          Whether to enable compatible mode for cross-language serialization.
          When enabled, type forward/backward compatibility for struct fields will be enabled.
@@ -873,6 +880,7 @@ cdef class Fory:
             self.strict = True
         else:
             self.strict = False
+        self.policy = policy or DEFAULT_POLICY
         self.compatible = compatible
         self.ref_tracking = ref
         self.ref_resolver = MapRefResolver(ref)
@@ -2368,97 +2376,6 @@ cdef class MapSerializer(Serializer):
         return self.read(buffer)
 
 
-@cython.final
-cdef class SubMapSerializer(Serializer):
-    cdef TypeResolver type_resolver
-    cdef MapRefResolver ref_resolver
-    cdef Serializer key_serializer
-    cdef Serializer value_serializer
-
-    def __init__(self, fory, type_, key_serializer=None, value_serializer=None):
-        super().__init__(fory, type_)
-        self.type_resolver = fory.type_resolver
-        self.ref_resolver = fory.ref_resolver
-        self.key_serializer = key_serializer
-        self.value_serializer = value_serializer
-
-    cpdef inline write(self, Buffer buffer, value):
-        buffer.write_varuint32(len(value))
-        cdef TypeInfo key_typeinfo
-        cdef TypeInfo value_typeinfo
-        for k, v in value.items():
-            key_cls = type(k)
-            if key_cls is str:
-                buffer.write_int16(NOT_NULL_STRING_FLAG)
-                buffer.write_string(k)
-            else:
-                if not self.ref_resolver.write_ref_or_null(buffer, k):
-                    key_typeinfo = self.type_resolver.get_typeinfo(key_cls)
-                    self.type_resolver.write_typeinfo(buffer, key_typeinfo)
-                    key_typeinfo.serializer.write(buffer, k)
-            value_cls = type(v)
-            if value_cls is str:
-                buffer.write_int16(NOT_NULL_STRING_FLAG)
-                buffer.write_string(v)
-            elif value_cls is int:
-                buffer.write_int16(NOT_NULL_INT64_FLAG)
-                buffer.write_varint64(v)
-            elif value_cls is bool:
-                buffer.write_int16(NOT_NULL_BOOL_FLAG)
-                buffer.write_bool(v)
-            elif value_cls is float:
-                buffer.write_int16(NOT_NULL_FLOAT64_FLAG)
-                buffer.write_double(v)
-            else:
-                if not self.ref_resolver.write_ref_or_null(buffer, v):
-                    value_typeinfo = self.type_resolver. \
-                        get_typeinfo(value_cls)
-                    self.type_resolver.write_typeinfo(buffer, value_typeinfo)
-                    value_typeinfo.serializer.write(buffer, v)
-
-    cpdef inline read(self, Buffer buffer):
-        cdef MapRefResolver ref_resolver = self.fory.ref_resolver
-        cdef TypeResolver type_resolver = self.fory.type_resolver
-        map_ = self.type_()
-        ref_resolver.reference(map_)
-        cdef int32_t len_ = buffer.read_varuint32()
-        cdef int32_t ref_id
-        cdef TypeInfo key_typeinfo
-        cdef TypeInfo value_typeinfo
-        self.fory.inc_depth()
-        for i in range(len_):
-            ref_id = ref_resolver.try_preserve_ref_id(buffer)
-            if ref_id < NOT_NULL_VALUE_FLAG:
-                key = ref_resolver.get_read_object()
-            else:
-                key_typeinfo = type_resolver.read_typeinfo(buffer)
-                if key_typeinfo.cls is str:
-                    key = buffer.read_string()
-                else:
-                    key = key_typeinfo.serializer.read(buffer)
-                    ref_resolver.set_read_object(ref_id, key)
-            ref_id = ref_resolver.try_preserve_ref_id(buffer)
-            if ref_id < NOT_NULL_VALUE_FLAG:
-                value = ref_resolver.get_read_object()
-            else:
-                value_typeinfo = type_resolver.read_typeinfo(buffer)
-                cls = value_typeinfo.cls
-                if cls is str:
-                    value = buffer.read_string()
-                elif cls is int:
-                    value = buffer.read_varint64()
-                elif cls is bool:
-                    value = buffer.read_bool()
-                elif cls is float:
-                    value = buffer.read_double()
-                else:
-                    value = value_typeinfo.serializer.read(buffer)
-                    ref_resolver.set_read_object(ref_id, value)
-            map_[key] = value
-        self.fory.dec_depth()
-        return map_
-
-
 @cython.final
 cdef class EnumSerializer(Serializer):
     @classmethod
diff --git a/python/pyfory/policy.py b/python/pyfory/policy.py
new file mode 100644
index 000000000..a47f5d8ae
--- /dev/null
+++ b/python/pyfory/policy.py
@@ -0,0 +1,580 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+class DeserializationPolicy:
+    """Deserialization Security Policy for PyFory.
+
+    DeserializationPolicy provides a comprehensive security layer for controlling deserialization
+    behavior, similar to how pickle.Unpickler can be customized but with finer-grained
+    control over the deserialization process.
+
+    Comparison with pickle.Unpickler
+    --------------------------------
+    Python's pickle.Unpickler provides basic security through the find_class() method,
+    which can be overridden to control class imports:
+
+        >>> class SafeUnpickler(pickle.Unpickler):
+        ...     def find_class(self, module, name):
+        ...         # Only allow safe modules
+        ...         if module in ('builtins', 'datetime'):
+        ...             return super().find_class(module, name)
+        ...         raise ValueError(f"Unsafe module: {module}")
+
+    Fory's DeserializationPolicy provides MORE granular control:
+
+    +---------------------------+----------------------+------------------------------+
+    | Security Feature          | pickle.Unpickler     | Fory DeserializationPolicy   |
+    +---------------------------+----------------------+------------------------------+
+    | Class import control      | ✓ find_class()       | ✓ validate_class()           |
+    | Function import control   | ✗ (via find_class)   | ✓ validate_function()        |
+    | Method validation         | ✗                    | ✓ validate_method()          |
+    | Module import control     | ✗                    | ✓ validate_module()          |
+    | Instantiation control     | ✗                    | ✓ authorize_instantiation()  |
+    | __reduce__ interception   | ✗                    | ✓ intercept_reduce_call()    |
+    | Post-reduce inspection    | ✗                    | ✓ inspect_reduced_object()   |
+    | __setstate__ interception | ✗                    | ✓ intercept_setstate()       |
+    | Object replacement        | ✗                    | ✓ (return from validators)   |
+    | State sanitization        | ✗                    | ✓ (modify in-place)          |
+    | Local class/function      | ✗                    | ✓ (is_local flag)            |
+    +---------------------------+----------------------+------------------------------+
+
+    Example: Blocking subprocess.Popen with pickle vs Fory:
+
+        # pickle.Unpickler - only catches class imports
+        class SafeUnpickler(pickle.Unpickler):
+            def find_class(self, module, name):
+                if module == 'subprocess' and name == 'Popen':
+                    raise ValueError("Blocked")
+                return super().find_class(module, name)
+
+        # Limitation: find_class() sees only the (module, name) pair, never the arguments,
+        # so if Popen is allowed at all it cannot reject (subprocess.Popen, (["rm", "-rf", "/"],)).
+
+        # Fory DeserializationPolicy - catches both imports AND reduce invocations
+        class SafeChecker(DeserializationPolicy):
+            def validate_class(self, cls, is_local, **kwargs):
+                if cls.__module__ == 'subprocess' and cls.__name__ == 'Popen':
+                    raise ValueError("Blocked")
+                return None
+
+            def intercept_reduce_call(self, callable_obj, args, **kwargs):
+                if callable_obj.__name__ == 'Popen':
+                    raise ValueError("Blocked at invocation!")
+                return None
+
+    Security Context
+    ----------------
+    Deserialization of untrusted data is inherently dangerous. Malicious payloads can:
+    - Import and instantiate arbitrary classes (e.g., subprocess.Popen)
+    - Execute arbitrary code through __reduce__ or __setstate__
+    - Access sensitive modules or perform unauthorized operations
+    - Cause denial of service through resource exhaustion
+
+    This DeserializationPolicy interface allows users to implement custom security policies by
+    subclassing and overriding specific hook methods. Each hook is called at a critical
+    point during deserialization, allowing inspection, replacement, or rejection of
+    dangerous constructs.
+
+    Hook Categories
+    ---------------
+    1. **Instantiation Authorization Hooks** (Guards)
+       - Control which classes can be instantiated
+       - Raise exception to block, return None to allow
+
+    2. **Reference Validation Hooks** (Validators)
+       - Validate deserialized type/function/module references
+       - Return None to accept the original, return an object to replace it, or raise an exception to block
+
+    3. **Protocol Interception Hooks** (Interceptors)
+       - Intercept pickle protocol operations (__reduce__, __setstate__)
+       - Return None to continue, return an object to replace the result, modify state in place, or raise an exception to block
+
+    Usage Example
+    -------------
+    >>> class SafeDeserializationPolicy(DeserializationPolicy):
+    ...     ALLOWED_MODULES = {'builtins', 'datetime', 'decimal'}
+    ...
+    ...     def validate_module(self, module_name, **kwargs):
+    ...         # Reject imports from disallowed modules
+    ...         if module_name.split('.')[0] not in self.ALLOWED_MODULES:
+    ...             raise ValueError(f"Module {module_name} is not allowed")
+    ...         return None  # Accept
+    ...
+    ...     def validate_function(self, func, is_local, **kwargs):
+    ...         # Reject dangerous built-in functions
+    ...         if func.__name__ in ('eval', 'exec', 'compile'):
+    ...             raise ValueError(f"Function {func.__name__} is forbidden")
+    ...         return None  # Accept
+    ...
+    ...     def intercept_reduce_call(self, callable_obj, args, **kwargs):
+    ...         # Log all __reduce__ callables for audit
+    ...         print(f"Reducing with {callable_obj.__name__}({args})")
+    ...         return None  # Proceed normally
+    ...
+    >>> fory = Fory(policy=SafeDeserializationPolicy())
+
+    Thread Safety
+    -------------
+    DeserializationPolicy instances should be thread-safe if shared across multiple Fory instances.
+    The default implementation is stateless and thread-safe.
+
+    Performance Considerations
+    --------------------------
+    - Hooks are called frequently during deserialization
+    - Keep validation logic fast to avoid performance degradation
+    - Cache validation results when possible (e.g., maintain allowed/blocked sets)
+    - Avoid I/O operations in hooks unless necessary
+
+    See Also
+    --------
+    - Python's pickle module security warnings: https://docs.python.org/3/library/pickle.html
+    - Fory documentation on secure deserialization: docs/guide/security.md
+    """
+
+    # ============================================================================
+    # Instantiation Authorization Hooks (Guards)
+    # ============================================================================
+
+    def authorize_instantiation(self, cls, **kwargs):
+        """Authorize instantiation of a class during deserialization.
+
+        This hook is called before creating an instance of any class during deserialization.
+        It acts as a security gate to prevent instantiation of dangerous classes.
+
+        When Called
+        -----------
+        - Before creating instances via cls.__new__(cls) in deserializers
+        - For both dataclass and regular object deserialization
+
+        Security Use Cases
+        ------------------
+        - Whitelist/blacklist specific classes by name or module
+        - Reject classes that could execute code in __init__ or __new__
+        - Prevent resource-exhausting classes (e.g., large buffers, threads)
+        - Log instantiation attempts for security auditing
+
+        Args:
+            cls (type): The class about to be instantiated.
+            **kwargs: Reserved for future extensions.
+
+        Raises:
+            Exception: Raise any exception to block instantiation. The exception
+                      will propagate to the caller of Fory.deserialize().
+
+        Returns:
+            None: Always return None to authorize. This method is a guard, not a transformer.
+
+        Example:
+            >>> class WhitelistChecker(DeserializationPolicy):
+            ...     ALLOWED = {'MyClass', 'SafeDataClass'}
+            ...
+            ...     def authorize_instantiation(self, cls, **kwargs):
+            ...         if cls.__name__ not in self.ALLOWED:
+            ...             raise ValueError(f"Class {cls.__name__} not whitelisted")
+
+        Note:
+            This method was previously named check_read_allowed and check_create_object.
+            Those names are kept as aliases for backward compatibility.
+        """
+        pass
+
+    # ============================================================================
+    # Reference Validation Hooks (Validators)
+    # ============================================================================
+
+    def validate_class(self, cls, *, is_local: bool, **kwargs):
+        """Validate a deserialized class reference.
+
+        This hook is called after a class reference has been deserialized (either by
+        importing from a module or reconstructing a local class), but before it is used.
+        It allows inspection, replacement, or rejection of class references.
+
+        When Called
+        -----------
+        - After importing global classes via importlib
+        - After reconstructing local classes from serialized code
+        - Before the class is stored or used in further deserialization
+
+        Security Use Cases
+        ------------------
+        - Block dangerous classes (e.g., subprocess.Popen)
+        - Replace untrusted classes with safe alternatives
+        - Validate that local classes match expected signatures
+        - Implement class versioning/migration logic
+
+        Args:
+            cls (type): The deserialized class object.
+            is_local (bool): True if the class is a local class (defined in __main__
+                           or within a function/method scope), False if it's a global
+                           class from an importable module.
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Return None to accept the class as-is.
+            type: Return a different class to replace the original. The replacement
+                 class will be used instead for deserialization.
+
+        Raises:
+            Exception: Raise any exception to reject the class and abort deserialization.
+
+        Example:
+            >>> class MigrationChecker(DeserializationPolicy):
+            ...     def validate_class(self, cls, is_local, **kwargs):
+            ...         # Migrate old class to new class
+            ...         if cls.__name__ == 'OldUserClass':
+            ...             return NewUserClass
+            ...         # Block dangerous classes
+            ...         if cls.__module__ == 'subprocess':
+            ...             raise ValueError("subprocess classes not allowed")
+            ...         return None  # Accept
+
+        Note:
+            This method was previously named check_class. That name is kept as an
+            alias for backward compatibility.
+        """
+        pass
+
+    def validate_function(self, func, is_local: bool, **kwargs):
+        """Validate a deserialized function reference.
+
+        This hook is called after a function has been deserialized (either by importing
+        from a module or reconstructing from serialized code), but before it is used.
+
+        When Called
+        -----------
+        - After importing global functions via importlib
+        - After reconstructing local functions/lambdas from marshalled code
+        - Before the function is stored or called
+
+        Security Use Cases
+        ------------------
+        - Block dangerous built-in functions (eval, exec, compile, __import__)
+        - Validate that reconstructed functions have expected signatures
+        - Replace untrusted functions with safe stubs
+        - Audit function imports for security logging
+
+        Args:
+            func (function): The deserialized function object.
+            is_local (bool): True if the function is local (defined in __main__ or
+                           within a function scope), False if it's a global function.
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Return None to accept the function as-is.
+            function: Return a different function to replace the original.
+
+        Raises:
+            Exception: Raise any exception to reject the function.
+
+        Example:
+            >>> class SafeFunctionChecker(DeserializationPolicy):
+            ...     BLOCKED = {'eval', 'exec', 'compile', '__import__'}
+            ...
+            ...     def validate_function(self, func, is_local, **kwargs):
+            ...         if func.__name__ in self.BLOCKED:
+            ...             raise ValueError(f"Function {func.__name__} is forbidden")
+            ...         return None
+
+        Note:
+            This method was previously named check_function. That name is kept as an
+            alias for backward compatibility.
+        """
+        pass
+
+    def validate_method(self, method, is_local: bool, **kwargs):
+        """Validate a deserialized method reference.
+
+        This hook is called after a method has been deserialized (either by importing
+        or reconstructing), but before it is used.
+
+        When Called
+        -----------
+        - After deserializing bound methods
+        - After reconstructing local methods from serialized code
+        - Before the method is stored or called
+
+        Security Use Cases
+        ------------------
+        - Validate that methods belong to expected classes
+        - Block methods that could perform dangerous operations
+        - Replace methods with safer alternatives
+
+        Args:
+            method (method): The deserialized bound method object.
+            is_local (bool): True if the method's class is local, False if global.
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Return None to accept the method as-is.
+            method: Return a different method to replace the original.
+
+        Raises:
+            Exception: Raise any exception to reject the method.
+
+        Example:
+            >>> class MethodChecker(DeserializationPolicy):
+            ...     def validate_method(self, method, is_local, **kwargs):
+            ...         # Block methods from dangerous classes
+            ...         if method.__self__.__class__.__name__ == 'FileRemover':
+            ...             raise ValueError("FileRemover methods not allowed")
+            ...         return None
+
+        Note:
+            This method was previously named check_method. That name is kept as an
+            alias for backward compatibility.
+        """
+        pass
+
+    def validate_module(self, module_name: str, **kwargs):
+        """Validate a deserialized module reference.
+
+        This hook is called after a module has been imported during deserialization,
+        but before it is used.
+
+        When Called
+        -----------
+        - After importing modules via importlib.import_module()
+        - Before the module is stored or its contents accessed
+
+        Security Use Cases
+        ------------------
+        - Whitelist/blacklist modules by name or prefix
+        - Prevent imports of system modules (os, subprocess, sys, etc.)
+        - Replace modules with safe alternatives or mocks
+        - Audit module imports for security logging
+
+        Args:
+            module_name (str): The name of the imported module (e.g., 'os.path').
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Return None to accept the module as-is.
+            module: Return a different module object to replace the original.
+
+        Raises:
+            Exception: Raise any exception to reject the imported module.
+
+        Example:
+            >>> class ModuleWhitelistChecker(DeserializationPolicy):
+            ...     ALLOWED = {'builtins', 'datetime', 'decimal', 'collections'}
+            ...
+            ...     def validate_module(self, module_name, **kwargs):
+            ...         root = module_name.split('.')[0]
+            ...         if root not in self.ALLOWED:
+            ...             raise ValueError(f"Module {module_name} not whitelisted")
+            ...         return None
+
+        Note:
+            This method was previously named check_module. That name is kept as an
+            alias for backward compatibility.
+        """
+        pass
+
+    # ============================================================================
+    # Protocol Interception Hooks (Interceptors)
+    # ============================================================================
+
+    def intercept_reduce_call(self, callable_obj, args, **kwargs):
+        """Intercept and validate __reduce__ protocol callable invocation.
+
+        This hook is called when deserializing an object that was serialized using the
+        __reduce__ or __reduce_ex__ protocol, right before the callable is invoked
+        to reconstruct the object.
+
+        When Called
+        -----------
+        - During deserialization of objects using __reduce__/__reduce_ex__
+        - Before callable_obj(*args) is executed
+        - After the callable and args have been deserialized
+
+        Security Use Cases
+        ------------------
+        - Block dangerous callables (eval, exec, os.system, subprocess.Popen)
+        - Validate that callables match expected signatures
+        - Inspect arguments for malicious payloads
+        - Return pre-constructed safe objects to skip callable invocation
+        - Log reduce operations for auditing
+
+        Args:
+            callable_obj (callable): The callable that will be invoked to reconstruct
+                                    the object (typically a class or factory function).
+            args (tuple): The arguments that will be passed to the callable.
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Return None to proceed with normal callable invocation (callable_obj(*args)).
+            object: Return an object to use directly, skipping the callable invocation.
+                   This allows you to construct safe replacement objects. State and
+                   list/dict items, if any, are still applied to the returned object.
+
+        Raises:
+            Exception: Raise any exception to reject the callable and abort deserialization.
+
+        Example:
+            >>> class ReduceChecker(DeserializationPolicy):
+            ...     def intercept_reduce_call(self, callable_obj, args, **kwargs):
+            ...         # Block subprocess.Popen
+            ...         if getattr(callable_obj, '__name__', None) == 'Popen':
+            ...             raise ValueError("Popen not allowed")
+            ...
+            ...         # Audit all reduce operations
+            ...         import logging
+            ...         logging.info(f"Reducing with {callable_obj}({args})")
+            ...
+            ...         return None  # Proceed normally
+
+        Note:
+            This is one of the most critical security hooks, as __reduce__ is the primary
+            vector for arbitrary code execution in pickle-based attacks.
+
+            This method was previously named check_reduce_callable. That name is kept
+            as an alias for backward compatibility.
+        """
+        pass
+
+    # Backward compatibility aliases
+    def check_reduce_callable(self, callable_obj, args, **kwargs):
+        """Deprecated: Use intercept_reduce_call instead.
+
+        This method is kept for backward compatibility. New code should use
+        intercept_reduce_call for clarity.
+        """
+        return self.intercept_reduce_call(callable_obj, args, **kwargs)
+
+    def inspect_reduced_object(self, obj, **kwargs):
+        """Inspect and validate an object after __reduce__ protocol reconstruction.
+
+        This hook is called after an object has been reconstructed using the __reduce__
+        protocol, allowing final inspection, modification, or replacement.
+
+        When Called
+        -----------
+        - After callable_obj(*args) has been executed
+        - After state has been restored (if applicable)
+        - After list/dict items have been added (if applicable)
+        - Before the object is returned to the deserializer
+
+        Security Use Cases
+        ------------------
+        - Validate reconstructed object's state
+        - Replace objects that pass callable checks but are still unsafe
+        - Sanitize object attributes
+        - Audit reconstructed objects for security logging
+
+        Args:
+            obj (object): The reconstructed object.
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Return None to accept the object as-is.
+            object: Return a different object to replace the original.
+
+        Raises:
+            Exception: Raise any exception to reject the object.
+
+        Example:
+            >>> import io
+            >>> class PostReduceChecker(DeserializationPolicy):
+            ...     def inspect_reduced_object(self, obj, **kwargs):
+            ...         # Validate that file handles are read-only
+            ...         if isinstance(obj, io.IOBase) and obj.writable():
+            ...             raise ValueError("Writable file handles not allowed")
+            ...         return None
+
+        Note:
+            This hook provides a last line of defense after reduce reconstruction.
+
+            This method was previously named check_restored_reduced_object. That name
+            is kept as an alias for backward compatibility.
+        """
+        pass
+
+    # Backward compatibility aliases
+    def check_restored_reduced_object(self, obj, **kwargs):
+        """Deprecated: Use inspect_reduced_object instead.
+
+        This method is kept for backward compatibility. New code should use
+        inspect_reduced_object for clarity.
+        """
+        return self.inspect_reduced_object(obj, **kwargs)
+
+    def intercept_setstate(self, obj, state, **kwargs):
+        """Intercept and validate __setstate__ protocol before state restoration.
+
+        This hook is called when deserializing an object that implements __setstate__,
+        right before the state is restored to the object. It allows inspection and
+        modification of the state dictionary.
+
+        When Called
+        -----------
+        - Before obj.__setstate__(state) is called
+        - After the object has been instantiated (via __new__)
+        - After the state dict has been deserialized
+
+        Security Use Cases
+        ------------------
+        - Inspect state for malicious values
+        - Sanitize or filter dangerous state attributes
+        - Validate state against expected schema
+        - Modify state to enforce security policies
+        - Audit state restoration for logging
+
+        Args:
+            obj (object): The object whose state is about to be restored.
+            state (dict or other): The state to be restored (typically a dict, but can
+                                  be any object depending on __setstate__ implementation).
+            **kwargs: Reserved for future extensions.
+
+        Returns:
+            None: Always return None. Modify the state dict in-place if needed.
+
+        Raises:
+            Exception: Raise any exception to reject the state and abort deserialization.
+
+        Example:
+            >>> class SetStateChecker(DeserializationPolicy):
+            ...     def intercept_setstate(self, obj, state, **kwargs):
+            ...         # Block if state contains dangerous attributes
+            ...         if isinstance(state, dict):
+            ...             dangerous_attrs = {'__code__', '__globals__', '_eval'}
+            ...             if any(attr in state for attr in dangerous_attrs):
+            ...                 raise ValueError("State contains dangerous attributes")
+            ...
+            ...             # Sanitize: remove private attributes
+            ...             safe = {k: v for k, v in state.items() if not k.startswith('_')}
+            ...             state.clear()
+            ...             state.update(safe)
+
+        Note:
+            This hook can modify the state dict in-place. Changes will be reflected
+            when __setstate__ is called.
+
+            This method was previously named check_setstate. That name is kept as an
+            alias for backward compatibility.
+        """
+        pass
+
+    # Backward compatibility alias
+    def check_setstate(self, obj, state, **kwargs):
+        """Deprecated: Use intercept_setstate instead.
+
+        This method is kept for backward compatibility. New code should use
+        intercept_setstate for clarity.
+        """
+        return self.intercept_setstate(obj, state, **kwargs)
+
+
+DEFAULT_POLICY = DeserializationPolicy()
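The docstring examples above match reduce callables on `__name__` alone, which any unrelated class can reuse. A minimal sketch of a stricter policy, assuming only the hook names, signatures and return conventions documented in this diff (`None` accepts, raising rejects): it allow-lists module roots and deny-lists reduce callables by fully qualified name. `AllowListPolicy` and its constructor arguments are hypothetical, and the fallback base class exists only so the sketch runs without pyfory installed.

```python
import collections

try:
    # Exported at package level by this commit (test_policy.py imports it this way).
    from pyfory import DeserializationPolicy
except ImportError:
    # Stand-in with default "accept everything" hooks, for running the sketch alone.
    class DeserializationPolicy:
        def validate_module(self, module_name, **kwargs):
            return None

        def intercept_reduce_call(self, callable_obj, args, **kwargs):
            return None


class AllowListPolicy(DeserializationPolicy):
    """Hypothetical policy: allow-list module roots, deny-list reduce callables."""

    def __init__(self, allowed_roots, blocked_callables):
        self.allowed_roots = set(allowed_roots)
        # Fully qualified names, e.g. "builtins.eval" or "subprocess.Popen".
        self.blocked_callables = set(blocked_callables)

    def validate_module(self, module_name, **kwargs):
        if module_name.split(".")[0] not in self.allowed_roots:
            raise ValueError(f"module {module_name!r} is not allow-listed")
        return None  # accept the module as-is

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        # Callables such as functools.partial objects may lack __qualname__,
        # so fall back to repr() instead of raising AttributeError.
        module = getattr(callable_obj, "__module__", None) or ""
        qualname = getattr(callable_obj, "__qualname__", None) or repr(callable_obj)
        full_name = f"{module}.{qualname}" if module else qualname
        if full_name in self.blocked_callables:
            raise ValueError(f"callable {full_name} is blocked")
        return None  # proceed with callable_obj(*args)


policy = AllowListPolicy({"builtins", "collections"}, {"builtins.eval", "builtins.exec"})
assert policy.validate_module("collections.abc") is None
assert policy.intercept_reduce_call(collections.OrderedDict, ()) is None
try:
    policy.intercept_reduce_call(eval, ("1 + 1",))
except ValueError as exc:
    print(exc)  # callable builtins.eval is blocked
```

With pyfory, such an instance would be passed as `policy=` to `Fory(...)`, as the tests in this commit do. Matching on module plus qualname is still a deny-list for callables; an allow-list of callables would be the stricter choice.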
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 893aa12c1..14f4175bc 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -28,7 +28,6 @@ import pickle
 import types
 import typing
 from typing import List, Dict
-import warnings
 
 from pyfory.buffer import Buffer
 from pyfory.codegen import (
@@ -76,7 +75,6 @@ if ENABLE_FORY_CYTHON_SERIALIZATION:
         StringArraySerializer,
         SetSerializer,
         MapSerializer,
-        SubMapSerializer,
         EnumSerializer,
         SliceSerializer,
     )
@@ -100,7 +98,6 @@ else:
         StringArraySerializer,
         SetSerializer,
         MapSerializer,
-        SubMapSerializer,
         EnumSerializer,
         SliceSerializer,
     )
@@ -176,6 +173,9 @@ class TypeSerializer(Serializer):
             cls = importlib.import_module(module_name)
             for name in qualname.split("."):
                 cls = getattr(cls, name)
+            result = self.fory.policy.validate_class(cls, is_local=False)
+            if result is not None:
+                cls = result
             return cls
 
     def _serialize_local_class(self, buffer, cls):
@@ -252,6 +252,9 @@ class TypeSerializer(Serializer):
         # Set module and qualname
         cls.__module__ = module
         cls.__qualname__ = qualname
+        result = fory.policy.validate_class(cls, is_local=True)
+        if result is not None:
+            cls = result
         return cls
 
 
@@ -266,7 +269,11 @@ class ModuleSerializer(Serializer):
 
     def read(self, buffer):
         mod = buffer.read_string()
-        return importlib.import_module(mod)
+        mod = importlib.import_module(mod)
+        result = self.fory.policy.validate_module(mod.__name__)
+        if result is not None:
+            mod = result
+        return mod
 
 
 class MappingProxySerializer(Serializer):
@@ -665,6 +672,9 @@ class DataClassSerializer(Serializer):
         stmts = [
             f'"""read method for {self.type_}"""',
         ]
+        if not self.fory.strict:
+            context["checker"] = self.fory.policy
+            stmts.append(f"checker.authorize_instantiation({obj_class})")
 
         # Read hash only in non-compatible mode; in compatible mode, read field count
         if not self.fory.compatible:
@@ -815,6 +825,9 @@ class DataClassSerializer(Serializer):
         stmts = [
             f'"""xread method for {self.type_}"""',
         ]
+        if not self.fory.strict:
+            context["checker"] = self.fory.policy
+            stmts.append(f"checker.authorize_instantiation({obj_class})")
         if not self.fory.compatible:
             stmts.extend(
                 [
@@ -1333,9 +1346,10 @@ class StatefulSerializer(CrossLanguageCompatibleSerializer):
         self.fory.serialize_ref(buffer, state)
 
     def read(self, buffer):
-        args = self.fory.read_ref(buffer)
-        kwargs = self.fory.read_ref(buffer)
-        state = self.fory.read_ref(buffer)
+        fory = self.fory
+        args = fory.read_ref(buffer)
+        kwargs = fory.read_ref(buffer)
+        state = fory.read_ref(buffer)
 
         if args or kwargs:
             # Case 1: __getnewargs__ was used. Re-create by calling __init__.
@@ -1345,6 +1359,7 @@ class StatefulSerializer(CrossLanguageCompatibleSerializer):
             obj = self.cls.__new__(self.cls)
 
         if state:
+            fory.policy.intercept_setstate(obj, state)
             obj.__setstate__(state)
         return obj
 
@@ -1449,8 +1464,10 @@ class ReduceSerializer(CrossLanguageCompatibleSerializer):
             listitems = reduce_data[4]
             dictitems = reduce_data[5] if len(reduce_data) > 5 else None
 
-            # Create the object using the callable and args
-            obj = callable_obj(*args)
+            obj = fory.policy.intercept_reduce_call(callable_obj, args)
+            if obj is None:
+                # Create the object using the callable and args
+                obj = callable_obj(*args)
 
             # Restore state if present
             if state is not None:
@@ -1470,6 +1487,9 @@ class ReduceSerializer(CrossLanguageCompatibleSerializer):
                 for key, value in dictitems:
                     obj[key] = value
 
+            result = fory.policy.inspect_reduced_object(obj)
+            if result is not None:
+                obj = result
             return obj
         else:
             raise ValueError(f"Invalid reduce data format: {reduce_data[0]}")
@@ -1615,7 +1635,11 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer):
             # Handle bound methods
             self_obj = self.fory.read_ref(buffer)
             method_name = buffer.read_string()
-            return getattr(self_obj, method_name)
+            func = getattr(self_obj, method_name)
+            result = self.fory.policy.validate_function(func, is_local=False)
+            if result is not None:
+                func = result
+            return func
 
         if func_type_id == 1:
             module = buffer.read_string()
@@ -1623,6 +1647,9 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer):
             mod = importlib.import_module(module)
             for name in qualname.split("."):
                 mod = getattr(mod, name)
+            result = self.fory.policy.validate_function(mod, is_local=False)
+            if result is not None:
+                mod = result
             return mod
 
         # Regular function or lambda
@@ -1697,6 +1724,9 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer):
         for attr_name, attr_value in attrs.items():
             setattr(func, attr_name, attr_value)
 
+        result = self.fory.policy.validate_function(func, is_local=True)
+        if result is not None:
+            func = result
         return func
 
     def xwrite(self, buffer, value):
@@ -1732,10 +1762,14 @@ class NativeFuncMethodSerializer(Serializer):
         if buffer.read_bool():
             module = buffer.read_string()
             mod = importlib.import_module(module)
-            return getattr(mod, name)
+            func = getattr(mod, name)
         else:
             obj = self.fory.read_ref(buffer)
-            return getattr(obj, name)
+            func = getattr(obj, name)
+        result = self.fory.policy.validate_function(func, is_local=False)
+        if result is not None:
+            func = result
+        return func
 
 
 class MethodSerializer(Serializer):
@@ -1757,7 +1791,13 @@ class MethodSerializer(Serializer):
         instance = self.fory.read_ref(buffer)
         method_name = buffer.read_string()
 
-        return getattr(instance, method_name)
+        method = getattr(instance, method_name)
+        cls = method.__self__.__class__
+        is_local = cls.__module__ == "__main__" or "<locals>" in cls.__qualname__
+        result = self.fory.policy.validate_method(method, is_local=is_local)
+        if result is not None:
+            method = result
+        return method
 
     def xwrite(self, buffer, value):
         return self.write(buffer, value)
@@ -1796,12 +1836,14 @@ class ObjectSerializer(Serializer):
             self.fory.serialize_ref(buffer, field_value)
 
     def read(self, buffer):
+        fory = self.fory
+        fory.policy.authorize_instantiation(self.type_)
         obj = self.type_.__new__(self.type_)
-        self.fory.ref_resolver.reference(obj)
+        fory.ref_resolver.reference(obj)
         num_fields = buffer.read_varuint32()
         for _ in range(num_fields):
             field_name = buffer.read_string()
-            field_value = self.fory.read_ref(buffer)
+            field_value = fory.read_ref(buffer)
             setattr(obj, field_name, field_value)
         return obj
 
@@ -1814,17 +1856,6 @@ class ObjectSerializer(Serializer):
         return self.read(buffer)
 
 
-class ComplexObjectSerializer(DataClassSerializer):
-    def __new__(cls, fory, clz):
-        warnings.warn(
-            "`ComplexObjectSerializer` is deprecated and will be removed in a future version. "
-            "Use `DataClassSerializer(fory, clz, xlang=True)` instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return DataClassSerializer(fory, clz, xlang=True)
-
-
 @dataclasses.dataclass
 class NonExistEnum:
     value: int = -1
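The `ReduceSerializer.read` hunk above fixes the order in which the two reduce hooks run. The sketch below restates that order as a standalone function so it can be traced without pyfory. `reconstruct` and `RecordingPolicy` are hypothetical names; list items are assumed to be added with `append` (only the `dictitems` loop is visible in the hunk), and state restoration is omitted.

```python
class RecordingPolicy:
    """Hypothetical duck-typed policy that records each hook call."""

    def __init__(self, replacement=None):
        self.calls = []
        self.replacement = replacement

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        self.calls.append(("intercept_reduce_call", callable_obj.__name__))
        return self.replacement  # None means "call callable_obj(*args)"

    def inspect_reduced_object(self, obj, **kwargs):
        self.calls.append(("inspect_reduced_object", type(obj).__name__))
        return None  # accept the reconstructed object


def reconstruct(policy, callable_obj, args, listitems=None, dictitems=None):
    """Mirrors the hook order in ReduceSerializer.read (state handling omitted)."""
    obj = policy.intercept_reduce_call(callable_obj, args)
    if obj is None:
        obj = callable_obj(*args)  # runs only when the policy did not substitute
    if listitems is not None:
        for item in listitems:
            obj.append(item)  # assumption: list items are appended
    if dictitems is not None:
        for key, value in dictitems:
            obj[key] = value
    result = policy.inspect_reduced_object(obj)
    if result is not None:
        obj = result
    return obj


def must_not_run(*args):
    raise RuntimeError("callable was invoked")


policy = RecordingPolicy()
print(reconstruct(policy, dict, (), dictitems=[("a", 1)]))  # {'a': 1}
print(policy.calls)
# [('intercept_reduce_call', 'dict'), ('inspect_reduced_object', 'dict')]

# A substituted object skips the callable but still reaches inspect_reduced_object.
print(reconstruct(RecordingPolicy(replacement="SAFE"), must_not_run, ()))  # SAFE
```

One consequence visible in the hunk: an object returned from `intercept_reduce_call` is not final. State and list/dict items are still applied to it, and `inspect_reduced_object` still sees it.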
diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py
new file mode 100644
index 000000000..3a727534e
--- /dev/null
+++ b/python/pyfory/tests/test_policy.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from pyfory import Fory, DeserializationPolicy
+
+
+class BlockClassPolicy(DeserializationPolicy):
+    """Policy that blocks specific class names from deserialization."""
+
+    def __init__(self, blocked_class_names):
+        self.blocked_class_names = blocked_class_names
+
+    def validate_class(self, cls, is_local, **kwargs):
+        if cls.__name__ in self.blocked_class_names:
+            raise ValueError(f"Class {cls.__name__} is blocked")
+        return None
+
+
+class ReplaceObjectPolicy(DeserializationPolicy):
+    """Policy that replaces deserialized objects from reduce."""
+
+    def __init__(self, replacement_value):
+        self.replacement_value = replacement_value
+
+    def inspect_reduced_object(self, obj, **kwargs):
+        if hasattr(obj, "value"):
+            return self.replacement_value
+        return None
+
+
+class BlockReduceCallPolicy(DeserializationPolicy):
+    """Policy that blocks specific callable invocations during reduce."""
+
+    def __init__(self, blocked_names):
+        self.blocked_names = blocked_names
+
+    def intercept_reduce_call(self, callable_obj, args, **kwargs):
+        if hasattr(callable_obj, "__name__") and callable_obj.__name__ in self.blocked_names:
+            raise ValueError(f"Callable {callable_obj.__name__} is blocked")
+        return None
+
+
+class SanitizeStatePolicy(DeserializationPolicy):
+    """Policy that sanitizes object state during setstate."""
+
+    def intercept_setstate(self, obj, state, **kwargs):
+        if isinstance(state, dict) and "password" in state:
+            state["password"] = "***REDACTED***"
+        return None
+
+
+def test_block_class_type_deserialization():
+    """Test blocking class type (not instance) deserialization."""
+
+    class SafeClass:
+        pass
+
+    class UnsafeClass:
+        pass
+
+    policy = BlockClassPolicy(blocked_class_names=["UnsafeClass"])
+    fory = Fory(ref=True, strict=False, policy=policy)
+
+    # Serialize and deserialize the class type itself (not an instance)
+    safe_data = fory.serialize(SafeClass)
+    result = fory.deserialize(safe_data)
+    assert result.__name__ == "SafeClass"
+
+    # Now test blocking
+    unsafe_data = fory.serialize(UnsafeClass)
+    with pytest.raises(ValueError, match="UnsafeClass is blocked"):
+        fory.deserialize(unsafe_data)
+
+
+def test_block_reduce_call():
+    """Test blocking callable invocations during reduce."""
+
+    class ReducibleClass:
+        def __init__(self, value):
+            self.value = value
+
+        def __reduce__(self):
+            return (ReducibleClass, (self.value,))
+
+    policy = BlockReduceCallPolicy(blocked_names=["ReducibleClass"])
+    fory = Fory(ref=True, strict=False, policy=policy)
+    data = fory.serialize(ReducibleClass(42))
+
+    with pytest.raises(ValueError, match="ReducibleClass is blocked"):
+        fory.deserialize(data)
+
+
+def test_replace_reduced_object():
+    """Test replacing objects created via __reduce__."""
+
+    class ReducibleClass:
+        def __init__(self, value):
+            self.value = value
+
+        def __reduce__(self):
+            return (ReducibleClass, (self.value,))
+
+    policy = ReplaceObjectPolicy(replacement_value="REPLACED")
+    fory = Fory(ref=True, strict=False, policy=policy)
+    data = fory.serialize(ReducibleClass(42))
+
+    result = fory.deserialize(data)
+    assert result == "REPLACED"
+
+
+def test_sanitize_state():
+    """Test sanitizing object state during setstate."""
+
+    class SecretHolder:
+        def __init__(self, username, password):
+            self.username = username
+            self.password = password
+
+        def __getstate__(self):
+            return {"username": self.username, "password": self.password}
+
+        def __setstate__(self, state):
+            self.__dict__.update(state)
+
+    policy = SanitizeStatePolicy()
+    fory = Fory(ref=False, strict=False, policy=policy)
+    data = fory.serialize(SecretHolder("admin", "secret123"))
+
+    result = fory.deserialize(data)
+    assert result.username == "admin"
+    assert result.password == "***REDACTED***"
+
+
+def test_policy_with_local_class():
+    """Test policy intercepts local class deserialization."""
+
+    def make_local_class():
+        class LocalClass:
+            pass
+
+        return LocalClass
+
+    LocalCls = make_local_class()
+
+    policy = BlockClassPolicy(blocked_class_names=["LocalClass"])
+    fory = Fory(ref=True, strict=False, policy=policy)
+
+    # Serialize the local class type
+    data = fory.serialize(LocalCls)
+
+    with pytest.raises(ValueError, match="LocalClass is blocked"):
+        fory.deserialize(data)
+
+
+def test_policy_with_ref_tracking():
+    """Test policy works with reference tracking."""
+
+    class ReducibleClass:
+        def __init__(self, value):
+            self.value = value
+
+        def __reduce__(self):
+            return (ReducibleClass, (self.value,))
+
+    policy = BlockReduceCallPolicy(blocked_names=["ReducibleClass"])
+    fory = Fory(ref=True, strict=False, policy=policy)
+
+    data = fory.serialize(ReducibleClass(42))
+
+    with pytest.raises(ValueError, match="ReducibleClass is blocked"):
+        fory.deserialize(data)
+
+
+def test_policy_allows_safe_operations():
+    """Test that policy doesn't interfere with safe built-in types."""
+    policy = BlockClassPolicy(blocked_class_names=[])
+    fory = Fory(ref=False, strict=False, policy=policy)
+
+    assert fory.deserialize(fory.serialize(42)) == 42
+    assert fory.deserialize(fory.serialize("test")) == "test"
+    assert fory.deserialize(fory.serialize([1, 2, 3])) == [1, 2, 3]
+
+
+def test_multiple_policy_hooks():
+    """Test policy with multiple hooks working together."""
+
+    class MultiHookPolicy(DeserializationPolicy):
+        def __init__(self):
+            self.hooks_called = []
+
+        def validate_class(self, cls, is_local, **kwargs):
+            self.hooks_called.append(("validate_class", cls.__name__))
+            return None
+
+        def intercept_reduce_call(self, callable_obj, args, **kwargs):
+            if hasattr(callable_obj, "__name__"):
+                self.hooks_called.append(("intercept_reduce_call", callable_obj.__name__))
+            return None
+
+        def inspect_reduced_object(self, obj, **kwargs):
+            self.hooks_called.append(("inspect_reduced_object", type(obj).__name__))
+            return None
+
+    class TestClass:
+        def __init__(self, value):
+            self.value = value
+
+        def __reduce__(self):
+            return (TestClass, (self.value,))
+
+    policy = MultiHookPolicy()
+    fory = Fory(ref=True, strict=False, policy=policy)
+
+    data = fory.serialize(TestClass(42))
+    result = fory.deserialize(data)
+
+    # All hooks should have been called
+    assert ("intercept_reduce_call", "TestClass") in policy.hooks_called
+    assert ("inspect_reduced_object", "TestClass") in policy.hooks_called
+    assert result.value == 42
+
+
+def test_policy_with_nested_reduce():
+    """Test policy handles nested objects with __reduce__."""
+
+    class Inner:
+        def __init__(self, value):
+            self.value = value
+
+        def __reduce__(self):
+            return (Inner, (self.value,))
+
+    class Outer:
+        def __init__(self, inner):
+            self.inner = inner
+
+        def __reduce__(self):
+            return (Outer, (self.inner,))
+
+    policy = BlockReduceCallPolicy(blocked_names=["Inner"])
+    fory = Fory(ref=True, strict=False, policy=policy)
+
+    data = fory.serialize(Outer(Inner(42)))
+
+    with pytest.raises(ValueError, match="Inner is blocked"):
+        fory.deserialize(data)
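The tests cover blocking, replacement and sanitizing. For the audit side named in the commit title, a policy can accept everything and only record what it saw. A minimal sketch, assuming the hook signatures from policy.py in this diff; `AuditPolicy`, its `events` list and the logger name are hypothetical, and the fallback base class is there only so the sketch runs without pyfory.

```python
import logging

try:
    from pyfory import DeserializationPolicy
except ImportError:
    class DeserializationPolicy:  # stand-in; every hook used below is overridden
        pass

logger = logging.getLogger("deserialization.audit")  # hypothetical logger name


class AuditPolicy(DeserializationPolicy):
    """Hypothetical audit-only policy: records every hook call, never rejects."""

    def __init__(self):
        self.events = []

    def _record(self, hook, detail):
        self.events.append((hook, detail))
        logger.info("%s: %s", hook, detail)
        return None  # None means "accept as-is" for every hook

    def validate_class(self, cls, is_local, **kwargs):
        detail = f"{cls.__module__}.{cls.__qualname__} local={is_local}"
        return self._record("validate_class", detail)

    def validate_module(self, module_name, **kwargs):
        return self._record("validate_module", module_name)

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        name = getattr(callable_obj, "__qualname__", None) or repr(callable_obj)
        return self._record("intercept_reduce_call", f"{name} nargs={len(args)}")

    def intercept_setstate(self, obj, state, **kwargs):
        # Record keys only, so secret values in the state never reach the log.
        keys = sorted(state) if isinstance(state, dict) else type(state).__name__
        return self._record("intercept_setstate", f"{type(obj).__name__} keys={keys}")


audit = AuditPolicy()
audit.validate_module("collections")
audit.intercept_reduce_call(dict, ())
print(audit.events)
# [('validate_module', 'collections'), ('intercept_reduce_call', 'dict nargs=0')]
# Assumed wiring, following the tests above:
#   fory = Fory(ref=True, strict=False, policy=audit)
```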


