This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 1dc733a64 feat(python): add deserialization policy for more
fine-grained control and audit deserialization behaviour (#2811)
1dc733a64 is described below
commit 1dc733a647e119e699004b9bf0727211d9f48e9e
Author: Shawn Yang <[email protected]>
AuthorDate: Wed Oct 22 23:14:55 2025 +0800
feat(python): add deserialization policy for more fine-grained control and
audit deserialization behaviour (#2811)
## Why?
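For context, the `pickle.Unpickler.find_class()` baseline that the new policy is compared against can be sketched with the standard library alone. It follows the "Restricting Globals" pattern from the `pickle` documentation. The names `RestrictedUnpickler`, `restricted_loads` and `Payload` are illustrative.

`find_class` is consulted for every global in the stream, including the callable named by `__reduce__`. However, it only ever sees a `(module, name)` pair. It cannot inspect the arguments passed to that callable, or the object the call returns. Hooks such as `intercept_reduce_call` and `inspect_reduced_object` are meant to cover that gap.

```python
import io
import pickle


class RestrictedUnpickler(pickle.Unpickler):
    """Allowlist-based unpickler: only a few builtins may be resolved."""

    ALLOWED = {("builtins", "list"), ("builtins", "dict"), ("builtins", "set")}

    def find_class(self, module, name):
        # Called for every global reference in the stream, including the
        # callable that an object's __reduce__ asks the unpickler to invoke.
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()


class Payload:
    # On load, this asks the unpickler to call print("pwned").
    def __reduce__(self):
        return (print, ("pwned",))


print(restricted_loads(pickle.dumps([1, 2, 3])))  # [1, 2, 3]
try:
    restricted_loads(pickle.dumps(Payload()))
except pickle.UnpicklingError as exc:
    print(exc)  # global 'builtins.print' is forbidden
```

The payload is rejected before `print` is ever called, because resolving the global fails. A name-based check like this still says nothing about what the allowed callables receive.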
## What does this PR do?
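Add a `DeserializationPolicy` that gives finer-grained control over, and
auditing of, deserialization behaviour. It is exported as
`pyfory.DeserializationPolicy` and passed to `Fory` via the new `policy`
argument.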
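The validator hooks share a three-way contract: return None to accept, return an object to replace, or raise to block. That contract can be illustrated without pyfory.

This is a minimal sketch under stated assumptions. `PolicyBase` is a hypothetical stand-in that mirrors the `validate_class` signature in `policy.py`. `resolve_class` is a hypothetical call site, not pyfory's actual resolver code.

```python
import subprocess


class PolicyBase:
    """Hypothetical stand-in with the same validate_class signature as the
    DeserializationPolicy base class: accept everything by default."""

    def validate_class(self, cls, *, is_local, **kwargs):
        return None


def resolve_class(policy, cls, *, is_local=False):
    """Hypothetical call site applying the validator contract."""
    # Any exception raised by the hook propagates and aborts resolution.
    replacement = policy.validate_class(cls, is_local=is_local)
    # None keeps the original class; any other value replaces it.
    return cls if replacement is None else replacement


class OldUser:
    pass


class NewUser:
    pass


class MigratingPolicy(PolicyBase):
    def validate_class(self, cls, *, is_local, **kwargs):
        if cls.__module__ == "subprocess":
            raise ValueError(f"Blocked class: {cls.__module__}.{cls.__qualname__}")
        if cls is OldUser:
            return NewUser  # replace with the migrated class
        return None  # accept as-is


policy = MigratingPolicy()
print(resolve_class(policy, OldUser).__name__)  # NewUser
try:
    resolve_class(policy, subprocess.Popen)
except ValueError as exc:
    print(exc)  # Blocked class: subprocess.Popen
```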
## Related issues
## Does this PR introduce any user-facing change?
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
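The user-facing additions visible in this diff are the exported `pyfory.DeserializationPolicy` class and the new `policy=` argument on `Fory`. The hypothetical `AuditPolicy` below sketches the audit use case: it records hook calls and never blocks.

It assumes pyfory calls the hooks with the keyword conventions documented in `policy.py`. When pyfory is not installed, it falls back to an empty stand-in base class so that it runs on its own.

```python
import logging

try:
    from pyfory import DeserializationPolicy
except ImportError:  # stand-in so this sketch runs without pyfory installed
    class DeserializationPolicy:
        pass


class AuditPolicy(DeserializationPolicy):
    """Audit-only policy: records every hook invocation and never blocks."""

    def __init__(self):
        super().__init__()
        self.events = []  # (hook_name, detail) tuples, in call order

    def _record(self, hook, detail):
        self.events.append((hook, detail))
        logging.getLogger("fory.audit").info("%s: %s", hook, detail)

    def validate_class(self, cls, *, is_local, **kwargs):
        self._record("validate_class", f"{cls.__module__}.{cls.__qualname__}")
        return None  # accept unchanged

    def validate_module(self, module_name, **kwargs):
        self._record("validate_module", module_name)
        return None  # accept unchanged

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        name = getattr(callable_obj, "__qualname__", repr(callable_obj))
        self._record("intercept_reduce_call", f"{name}{args!r}")
        return None  # proceed with callable_obj(*args)
```

With pyfory installed, this would be passed as `pyfory.Fory(strict=False, policy=AuditPolicy())`, and `events` could be inspected after `deserialize()`.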
## Benchmark
---
python/README.md | 63 +++-
python/pyfory/__init__.py | 1 +
python/pyfory/_fory.py | 8 +
python/pyfory/_serialization.pyx | 99 +------
python/pyfory/policy.py | 580 +++++++++++++++++++++++++++++++++++++
python/pyfory/serializer.py | 83 ++++--
python/pyfory/tests/test_policy.py | 261 +++++++++++++++++
7 files changed, 977 insertions(+), 118 deletions(-)
diff --git a/python/README.md b/python/README.md
index f0bd62a32..3426a4545 100644
--- a/python/README.md
+++ b/python/README.md
@@ -133,7 +133,7 @@ print(result) # Person(name='Bob', age=25, ...)
- **For circular references**: Set `ref=True` to enable reference tracking
- **For functions/classes**: Set `strict=False` to allow deserialization of
dynamic types
-**⚠️ Security Warning**: When `strict=False`, Fory will deserialize arbitrary
types, which can pose security risks if data comes from untrusted sources. Only
use `strict=False` in controlled environments where you trust the data source
completely.
+**⚠️ Security Warning**: When `strict=False`, Fory will deserialize arbitrary
types, which can pose security risks if data comes from untrusted sources. Only
use `strict=False` in controlled environments where you trust the data source
completely. If you do need to use `strict=False`, configure a
`DeserializationPolicy` by passing `policy=your_policy` when creating `Fory` to
control deserialization behavior.
#### Common Usage
@@ -1141,6 +1141,67 @@ else:
fory.register(model_class, type_id=100 + idx)
```
+### DeserializationPolicy
+
+When `strict=False` is necessary (e.g., deserializing functions/lambdas), use
`DeserializationPolicy` to implement fine-grained security controls during
deserialization. This provides protection similar to
`pickle.Unpickler.find_class()` but with more comprehensive hooks.
+
+**Why use DeserializationPolicy?**
+
+- Block dangerous classes/modules (e.g., `subprocess.Popen`)
+- Intercept and validate `__reduce__` callables before invocation
+- Sanitize sensitive data during `__setstate__`
+- Replace or reject deserialized objects based on custom rules
+
+**Example: Blocking Dangerous Classes**
+
+```python
+import pyfory
+from pyfory import DeserializationPolicy
+
+dangerous_modules = {'subprocess', 'os', '__builtin__'}
+
+class SafeDeserializationPolicy(DeserializationPolicy):
+ """Block potentially dangerous classes during deserialization."""
+
+ def validate_class(self, cls, is_local, **kwargs):
+ # Block dangerous modules
+ if cls.__module__ in dangerous_modules:
+ raise ValueError(f"Blocked dangerous class: {cls.__module__}.{cls.__name__}")
+ return None
+
+ def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ # Block specific callable invocations during __reduce__
+ if getattr(callable_obj, '__name__', "") == 'Popen':
+ raise ValueError("Blocked attempt to invoke subprocess.Popen")
+ return None
+
+ def intercept_setstate(self, obj, state, **kwargs):
+ # Sanitize sensitive data
+ if isinstance(state, dict) and 'password' in state:
+ state['password'] = '***REDACTED***'
+ return None
+
+# Create Fory with custom security policy
+policy = SafeDeserializationPolicy()
+fory = pyfory.Fory(xlang=False, ref=True, strict=False, policy=policy)
+
+# Now deserialization is protected by your custom policy
+data = fory.serialize(my_object)
+result = fory.deserialize(data) # Policy hooks will be invoked
+```
+
+**Available Policy Hooks:**
+
+- `validate_class(cls, is_local)` - Validate/block class types during deserialization
+- `validate_module(module_name)` - Validate/block module imports
+- `validate_function(func, is_local)` - Validate/block function references
+- `validate_method(method, is_local)` - Validate/block bound-method references
+- `intercept_reduce_call(callable_obj, args)` - Intercept `__reduce__` invocations
+- `inspect_reduced_object(obj)` - Inspect/replace objects created via `__reduce__`
+- `intercept_setstate(obj, state)` - Sanitize state before `__setstate__`
+- `authorize_instantiation(cls)` - Control class instantiation
+
+**See also:** `pyfory/policy.py` contains detailed documentation and examples
for each hook.
+
## 🐛 Troubleshooting
### Common Issues
diff --git a/python/pyfory/__init__.py b/python/pyfory/__init__.py
index bf070bea4..0ff316d60 100644
--- a/python/pyfory/__init__.py
+++ b/python/pyfory/__init__.py
@@ -54,6 +54,7 @@ from pyfory.type import ( # noqa: F401 # pylint: disable=unused-import
Float64ArrayType,
dataslots,
)
+from pyfory.policy import DeserializationPolicy # noqa: F401 # pylint: disable=unused-import
from pyfory._util import Buffer # noqa: F401 # pylint: disable=unused-import
import warnings
diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py
index f283bc4ea..9398a3b78 100644
--- a/python/pyfory/_fory.py
+++ b/python/pyfory/_fory.py
@@ -30,6 +30,7 @@ from pyfory.resolver import (
)
from pyfory.util import is_little_endian, set_bit, get_bit, clear_bit
from pyfory.type import TypeId
+from pyfory.policy import DeserializationPolicy, DEFAULT_POLICY
try:
import numpy as np
@@ -122,6 +123,7 @@ class Fory:
"max_depth",
"depth",
"field_nullable",
+ "policy",
)
def __init__(
@@ -131,6 +133,7 @@ class Fory:
strict: bool = True,
compatible: bool = False,
max_depth: int = 50,
+ policy: DeserializationPolicy = None,
field_nullable: bool = False,
**kwargs,
):
@@ -157,6 +160,10 @@ class Fory:
Do not disable strict mode if you can't ensure your environment are
*indeed secure*. We are not responsible for security risks if
you disable this option.
+ :param policy:
+ A custom `DeserializationPolicy` for deserialization security checks.
+ If provided, its hooks are used to validate, audit, or block what gets
+ deserialized instead of the default policy (`DEFAULT_POLICY`).
:param compatible:
Whether to enable compatible mode for cross-language serialization.
When enabled, type forward/backward compatibility for dataclass
fields will be enabled.
@@ -182,6 +189,7 @@ class Fory:
if kwargs.get("require_type_registration") is not None:
strict = kwargs.get("require_type_registration")
self.strict = _ENABLE_TYPE_REGISTRATION_FORCIBLY or strict
+ self.policy = policy or DEFAULT_POLICY
self.compatible = compatible
self.field_nullable = field_nullable if self.is_py else False
from pyfory._serialization import MetaStringResolver, SerializationContext
diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx
index 188c48320..9f246ff08 100644
--- a/python/pyfory/_serialization.pyx
+++ b/python/pyfory/_serialization.pyx
@@ -34,6 +34,7 @@ from pyfory._fory import _ENABLE_TYPE_REGISTRATION_FORCIBLY
from pyfory.lib import mmh3
from pyfory.meta.metastring import Encoding
from pyfory.type import is_primitive_type
+from pyfory.policy import DeserializationPolicy, DEFAULT_POLICY
from pyfory.util import is_little_endian
from pyfory.includes.libserialization cimport \
(TypeId, IsNamespacedType, IsTypeShareMeta,
Fory_PyBooleanSequenceWriteToBuffer, Fory_PyFloatSequenceWriteToBuffer)
@@ -805,6 +806,7 @@ cdef class Fory:
cdef readonly c_bool is_py
cdef readonly c_bool compatible
cdef readonly c_bool field_nullable
+ cdef readonly object policy
cdef readonly MapRefResolver ref_resolver
cdef readonly TypeResolver type_resolver
cdef readonly MetaStringResolver metastring_resolver
@@ -823,6 +825,7 @@ cdef class Fory:
xlang: bool = False,
ref: bool = False,
strict: bool = True,
+ policy: DeserializationPolicy = None,
compatible: bool = False,
max_depth: int = 50,
field_nullable: bool = False,
@@ -851,6 +854,10 @@ cdef class Fory:
Do not disable strict mode if you can't ensure your environment are
*indeed secure*. We are not responsible for security risks if
you disable this option.
+ :param policy:
+ A custom `DeserializationPolicy` for deserialization security checks.
+ If provided, its hooks are used to validate, audit, or block what gets
+ deserialized instead of the default policy (`DEFAULT_POLICY`).
:param compatible:
Whether to enable compatible mode for cross-language serialization.
When enabled, type forward/backward compatibility for struct fields
will be enabled.
@@ -873,6 +880,7 @@ cdef class Fory:
self.strict = True
else:
self.strict = False
+ self.policy = policy or DEFAULT_POLICY
self.compatible = compatible
self.ref_tracking = ref
self.ref_resolver = MapRefResolver(ref)
@@ -2368,97 +2376,6 @@ cdef class MapSerializer(Serializer):
return self.read(buffer)
-@cython.final
-cdef class SubMapSerializer(Serializer):
- cdef TypeResolver type_resolver
- cdef MapRefResolver ref_resolver
- cdef Serializer key_serializer
- cdef Serializer value_serializer
-
- def __init__(self, fory, type_, key_serializer=None,
value_serializer=None):
- super().__init__(fory, type_)
- self.type_resolver = fory.type_resolver
- self.ref_resolver = fory.ref_resolver
- self.key_serializer = key_serializer
- self.value_serializer = value_serializer
-
- cpdef inline write(self, Buffer buffer, value):
- buffer.write_varuint32(len(value))
- cdef TypeInfo key_typeinfo
- cdef TypeInfo value_typeinfo
- for k, v in value.items():
- key_cls = type(k)
- if key_cls is str:
- buffer.write_int16(NOT_NULL_STRING_FLAG)
- buffer.write_string(k)
- else:
- if not self.ref_resolver.write_ref_or_null(buffer, k):
- key_typeinfo = self.type_resolver.get_typeinfo(key_cls)
- self.type_resolver.write_typeinfo(buffer, key_typeinfo)
- key_typeinfo.serializer.write(buffer, k)
- value_cls = type(v)
- if value_cls is str:
- buffer.write_int16(NOT_NULL_STRING_FLAG)
- buffer.write_string(v)
- elif value_cls is int:
- buffer.write_int16(NOT_NULL_INT64_FLAG)
- buffer.write_varint64(v)
- elif value_cls is bool:
- buffer.write_int16(NOT_NULL_BOOL_FLAG)
- buffer.write_bool(v)
- elif value_cls is float:
- buffer.write_int16(NOT_NULL_FLOAT64_FLAG)
- buffer.write_double(v)
- else:
- if not self.ref_resolver.write_ref_or_null(buffer, v):
- value_typeinfo = self.type_resolver. \
- get_typeinfo(value_cls)
- self.type_resolver.write_typeinfo(buffer, value_typeinfo)
- value_typeinfo.serializer.write(buffer, v)
-
- cpdef inline read(self, Buffer buffer):
- cdef MapRefResolver ref_resolver = self.fory.ref_resolver
- cdef TypeResolver type_resolver = self.fory.type_resolver
- map_ = self.type_()
- ref_resolver.reference(map_)
- cdef int32_t len_ = buffer.read_varuint32()
- cdef int32_t ref_id
- cdef TypeInfo key_typeinfo
- cdef TypeInfo value_typeinfo
- self.fory.inc_depth()
- for i in range(len_):
- ref_id = ref_resolver.try_preserve_ref_id(buffer)
- if ref_id < NOT_NULL_VALUE_FLAG:
- key = ref_resolver.get_read_object()
- else:
- key_typeinfo = type_resolver.read_typeinfo(buffer)
- if key_typeinfo.cls is str:
- key = buffer.read_string()
- else:
- key = key_typeinfo.serializer.read(buffer)
- ref_resolver.set_read_object(ref_id, key)
- ref_id = ref_resolver.try_preserve_ref_id(buffer)
- if ref_id < NOT_NULL_VALUE_FLAG:
- value = ref_resolver.get_read_object()
- else:
- value_typeinfo = type_resolver.read_typeinfo(buffer)
- cls = value_typeinfo.cls
- if cls is str:
- value = buffer.read_string()
- elif cls is int:
- value = buffer.read_varint64()
- elif cls is bool:
- value = buffer.read_bool()
- elif cls is float:
- value = buffer.read_double()
- else:
- value = value_typeinfo.serializer.read(buffer)
- ref_resolver.set_read_object(ref_id, value)
- map_[key] = value
- self.fory.dec_depth()
- return map_
-
-
@cython.final
cdef class EnumSerializer(Serializer):
@classmethod
diff --git a/python/pyfory/policy.py b/python/pyfory/policy.py
new file mode 100644
index 000000000..a47f5d8ae
--- /dev/null
+++ b/python/pyfory/policy.py
@@ -0,0 +1,580 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+class DeserializationPolicy:
+ """Deserialization Security Policy for PyFory.
+
+ DeserializationPolicy provides a comprehensive security layer for
controlling deserialization
+ behavior, similar to how pickle.Unpickler can be customized but with
finer-grained
+ control over the deserialization process.
+
+ Comparison with pickle.Unpickler
+ --------------------------------
+ Python's pickle.Unpickler provides basic security through the find_class()
method,
+ which can be overridden to control class imports:
+
+ >>> class SafeUnpickler(pickle.Unpickler):
+ ... def find_class(self, module, name):
+ ... # Only allow safe modules
+ ... if module in ('builtins', 'datetime'):
+ ... return super().find_class(module, name)
+ ... raise ValueError(f"Unsafe module: {module}")
+
+ Fory's DeserializationPolicy provides MORE granular control:
+
+
+ +---------------------------+----------------------+-----------------------------+
+ | Security Feature          | pickle.Unpickler     | Fory DeserializationPolicy  |
+ +---------------------------+----------------------+-----------------------------+
+ | Class import control      | ✓ find_class()       | ✓ validate_class()          |
+ | Function import control   | ✗ (via find_class)   | ✓ validate_function()       |
+ | Method validation         | ✗                    | ✓ validate_method()         |
+ | Module import control     | ✗                    | ✓ validate_module()         |
+ | Instantiation control     | ✗                    | ✓ authorize_instantiation() |
+ | __reduce__ interception   | ✗                    | ✓ intercept_reduce_call()   |
+ | Post-reduce inspection    | ✗                    | ✓ inspect_reduced_object()  |
+ | __setstate__ interception | ✗                    | ✓ intercept_setstate()      |
+ | Object replacement        | ✗                    | ✓ (return from validators)  |
+ | State sanitization        | ✗                    | ✓ (modify in-place)         |
+ | Local class/function      | ✗                    | ✓ (is_local flag)           |
+ +---------------------------+----------------------+-----------------------------+
+
+ Example: Blocking subprocess.Popen with pickle vs Fory:
+
+ # pickle.Unpickler - only catches class imports
+ class SafeUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'subprocess' and name == 'Popen':
+ raise ValueError("Blocked")
+ return super().find_class(module, name)
+
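+ # Limitation: find_class only sees the (module, name) pair being resolved.
+ # It cannot inspect the arguments a __reduce__ payload such as
+ # (subprocess.Popen, (["rm", "-rf", "/"],)) passes to the callable,
+ # nor the object that the call returns.
+
+ # Fory DeserializationPolicy - checks both imports AND each reduce invocation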
+ class SafeChecker(DeserializationPolicy):
+ def validate_class(self, cls, is_local, **kwargs):
+ if cls.__module__ == 'subprocess' and cls.__name__ == 'Popen':
+ raise ValueError("Blocked")
+ return None
+
+ def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ if getattr(callable_obj, '__name__', '') == 'Popen':
+ raise ValueError("Blocked at invocation!")
+ return None
+
+ Security Context
+ ----------------
+ Deserialization of untrusted data is inherently dangerous. Malicious
payloads can:
+ - Import and instantiate arbitrary classes (e.g., subprocess.Popen)
+ - Execute arbitrary code through __reduce__ or __setstate__
+ - Access sensitive modules or perform unauthorized operations
+ - Cause denial of service through resource exhaustion
+
+ This DeserializationPolicy interface allows users to implement custom
security policies by
+ subclassing and overriding specific hook methods. Each hook is called at a
critical
+ point during deserialization, allowing inspection, replacement, or
rejection of
+ dangerous constructs.
+
+ Hook Categories
+ ---------------
+ 1. **Instantiation Authorization Hooks** (Guards)
+ - Control which classes can be instantiated
+ - Raise an exception to block; return None to allow
+
+ 2. **Reference Validation Hooks** (Validators)
+ - Validate deserialized class/function/method/module references
+ - Return None to accept the original, return an object to replace it,
+ or raise an exception to block
+
+ 3. **Protocol Interception Hooks** (Interceptors)
+ - Intercept pickle protocol operations (__reduce__, __setstate__)
+ - Return None to continue, return an object to replace the result (where
+ the hook supports it), modify arguments in place, or raise an exception
+ to block
+
+ Usage Example
+ -------------
+ >>> class SafeDeserializationPolicy(DeserializationPolicy):
+ ... ALLOWED_MODULES = {'builtins', 'datetime', 'decimal'}
+ ...
+ ... def validate_module(self, module_name, **kwargs):
+ ... # Reject imports from disallowed modules
+ ... if module_name.split('.')[0] not in self.ALLOWED_MODULES:
+ ... raise ValueError(f"Module {module_name} is not allowed")
+ ... return None # Accept
+ ...
+ ... def validate_function(self, func, is_local, **kwargs):
+ ... # Reject dangerous built-in functions
+ ... if func.__name__ in ('eval', 'exec', 'compile'):
+ ... raise ValueError(f"Function {func.__name__} is forbidden")
+ ... return None # Accept
+ ...
+ ... def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ ... # Log all __reduce__ callables for audit
+ ... print(f"Reducing with {callable_obj.__name__}({args})")
+ ... return None # Proceed normally
+ ...
+ >>> fory = Fory(policy=SafeDeserializationPolicy())
+
+ Thread Safety
+ -------------
+ DeserializationPolicy instances should be thread-safe if shared across
multiple Fory instances.
+ The default implementation is stateless and thread-safe.
+
+ Performance Considerations
+ --------------------------
+ - Hooks are called frequently during deserialization
+ - Keep validation logic fast to avoid performance degradation
+ - Cache validation results when possible (e.g., maintain allowed/blocked
sets)
+ - Avoid I/O operations in hooks unless necessary
+
+ See Also
+ --------
+ - Python's pickle module security warnings:
https://docs.python.org/3/library/pickle.html
+ - Fory documentation on secure deserialization: docs/guide/security.md
+ """
+
+ # ============================================================================
+ # Instantiation Authorization Hooks (Guards)
+ # ============================================================================
+
+ def authorize_instantiation(self, cls, **kwargs):
+ """Authorize instantiation of a class during deserialization.
+
+ This hook is called before creating an instance of any class during
deserialization.
+ It acts as a security gate to prevent instantiation of dangerous
classes.
+
+ When Called
+ -----------
+ - Before creating instances via cls.__new__(cls) in deserializers
+ - For both dataclass and regular object deserialization
+
+ Security Use Cases
+ ------------------
+ - Whitelist/blacklist specific classes by name or module
+ - Reject classes that could execute code in __init__ or __new__
+ - Prevent resource-exhausting classes (e.g., large buffers, threads)
+ - Log instantiation attempts for security auditing
+
+ Args:
+ cls (type): The class about to be instantiated.
+ **kwargs: Reserved for future extensions.
+
+ Raises:
+ Exception: Raise any exception to block instantiation. The
exception
+ will propagate to the caller of Fory.deserialize().
+
+ Returns:
+ None: Always return None to authorize. This method is a guard, not
a transformer.
+
+ Example:
+ >>> class WhitelistChecker(DeserializationPolicy):
+ ... ALLOWED = {'MyClass', 'SafeDataClass'}
+ ...
+ ... def authorize_instantiation(self, cls, **kwargs):
+ ... if cls.__name__ not in self.ALLOWED:
+ ... raise ValueError(f"Class {cls.__name__} not whitelisted")
+
+ Note:
+ This method was previously named check_read_allowed and
check_create_object.
+ Those names are kept as aliases for backward compatibility.
+ """
+ pass
+
+ # ============================================================================
+ # Reference Validation Hooks (Validators)
+ # ============================================================================
+
+ def validate_class(self, cls, *, is_local: bool, **kwargs):
+ """Validate a deserialized class reference.
+
+ This hook is called after a class reference has been deserialized
(either by
+ importing from a module or reconstructing a local class), but before
it is used.
+ It allows inspection, replacement, or rejection of class references.
+
+ When Called
+ -----------
+ - After importing global classes via importlib
+ - After reconstructing local classes from serialized code
+ - Before the class is stored or used in further deserialization
+
+ Security Use Cases
+ ------------------
+ - Block dangerous classes (e.g., subprocess.Popen)
+ - Replace untrusted classes with safe alternatives
+ - Validate that local classes match expected signatures
+ - Implement class versioning/migration logic
+
+ Args:
+ cls (type): The deserialized class object.
+ is_local (bool): True if the class is a local class (defined in
__main__
+ or within a function/method scope), False if it's a
global
+ class from an importable module.
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Return None to accept the class as-is.
+ type: Return a different class to replace the original. The
replacement
+ class will be used instead for deserialization.
+
+ Raises:
+ Exception: Raise any exception to reject the class and abort
deserialization.
+
+ Example:
+ >>> class MigrationChecker(DeserializationPolicy):
+ ... def validate_class(self, cls, is_local, **kwargs):
+ ... # Migrate old class to new class
+ ... if cls.__name__ == 'OldUserClass':
+ ... return NewUserClass
+ ... # Block dangerous classes
+ ... if cls.__module__ == 'subprocess':
+ ... raise ValueError("subprocess classes not allowed")
+ ... return None # Accept
+
+ Note:
+ This method was previously named check_class. That name is kept as
an
+ alias for backward compatibility.
+ """
+ pass
+
+ def validate_function(self, func, is_local: bool, **kwargs):
+ """Validate a deserialized function reference.
+
+ This hook is called after a function has been deserialized (either by
importing
+ from a module or reconstructing from serialized code), but before it
is used.
+
+ When Called
+ -----------
+ - After importing global functions via importlib
+ - After reconstructing local functions/lambdas from marshalled code
+ - Before the function is stored or called
+
+ Security Use Cases
+ ------------------
+ - Block dangerous built-in functions (eval, exec, compile, __import__)
+ - Validate that reconstructed functions have expected signatures
+ - Replace untrusted functions with safe stubs
+ - Audit function imports for security logging
+
+ Args:
+ func (function): The deserialized function object.
+ is_local (bool): True if the function is local (defined in
__main__ or
+ within a function scope), False if it's a global
function.
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Return None to accept the function as-is.
+ function: Return a different function to replace the original.
+
+ Raises:
+ Exception: Raise any exception to reject the function.
+
+ Example:
+ >>> class SafeFunctionChecker(DeserializationPolicy):
+ ... BLOCKED = {'eval', 'exec', 'compile', '__import__'}
+ ...
+ ... def validate_function(self, func, is_local, **kwargs):
+ ... if func.__name__ in self.BLOCKED:
+ ... raise ValueError(f"Function {func.__name__} is forbidden")
+ ... return None
+
+ Note:
+ This method was previously named check_function. That name is kept
as an
+ alias for backward compatibility.
+ """
+ pass
+
+ def validate_method(self, method, is_local: bool, **kwargs):
+ """Validate a deserialized method reference.
+
+ This hook is called after a method has been deserialized (either by
importing
+ or reconstructing), but before it is used.
+
+ When Called
+ -----------
+ - After deserializing bound methods
+ - After reconstructing local methods from serialized code
+ - Before the method is stored or called
+
+ Security Use Cases
+ ------------------
+ - Validate that methods belong to expected classes
+ - Block methods that could perform dangerous operations
+ - Replace methods with safer alternatives
+
+ Args:
+ method (method): The deserialized bound method object.
+ is_local (bool): True if the method's class is local, False if
global.
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Return None to accept the method as-is.
+ method: Return a different method to replace the original.
+
+ Raises:
+ Exception: Raise any exception to reject the method.
+
+ Example:
+ >>> class MethodChecker(DeserializationPolicy):
+ ... def validate_method(self, method, is_local, **kwargs):
+ ... # Block methods from dangerous classes
+ ... if method.__self__.__class__.__name__ == 'FileRemover':
+ ... raise ValueError("FileRemover methods not allowed")
+ ... return None
+
+ Note:
+ This method was previously named check_method. That name is kept
as an
+ alias for backward compatibility.
+ """
+ pass
+
+ def validate_module(self, module_name: str, **kwargs):
+ """Validate a deserialized module reference.
+
+ This hook is called after a module has been imported during
deserialization,
+ but before it is used.
+
+ When Called
+ -----------
+ - After importing modules via importlib.import_module()
+ - Before the module is stored or its contents accessed
+
+ Security Use Cases
+ ------------------
+ - Whitelist/blacklist modules by name or prefix
+ - Prevent imports of system modules (os, subprocess, sys, etc.)
+ - Replace modules with safe alternatives or mocks
+ - Audit module imports for security logging
+
+ Args:
+ module_name (str): The name of the imported module (e.g.,
'os.path').
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Return None to accept the module as-is.
+ module: Return a different module object to replace the original.
+
+ Raises:
+ Exception: Raise any exception to reject the module import.
+
+ Example:
+ >>> class ModuleWhitelistChecker(DeserializationPolicy):
+ ... ALLOWED = {'builtins', 'datetime', 'decimal', 'collections'}
+ ...
+ ... def validate_module(self, module_name, **kwargs):
+ ... root = module_name.split('.')[0]
+ ... if root not in self.ALLOWED:
+ ... raise ValueError(f"Module {module_name} not whitelisted")
+ ... return None
+
+ Note:
+ This method was previously named check_module. That name is kept
as an
+ alias for backward compatibility.
+ """
+ pass
+
+ # ============================================================================
+ # Protocol Interception Hooks (Interceptors)
+ # ============================================================================
+
+ def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ """Intercept and validate __reduce__ protocol callable invocation.
+
+ This hook is called when deserializing an object that was serialized
using the
+ __reduce__ or __reduce_ex__ protocol, right before the callable is
invoked
+ to reconstruct the object.
+
+ When Called
+ -----------
+ - During deserialization of objects using __reduce__/__reduce_ex__
+ - Before callable_obj(*args) is executed
+ - After the callable and args have been deserialized
+
+ Security Use Cases
+ ------------------
+ - Block dangerous callables (eval, exec, os.system, subprocess.Popen)
+ - Validate that callables match expected signatures
+ - Inspect arguments for malicious payloads
+ - Return pre-constructed safe objects to skip callable invocation
+ - Log reduce operations for auditing
+
+ Args:
+ callable_obj (callable): The callable that will be invoked to
reconstruct
+ the object (typically a class or factory
function).
+ args (tuple): The arguments that will be passed to the callable.
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Return None to proceed with normal callable invocation
(callable_obj(*args)).
+ object: Return an object to use directly, skipping the callable
invocation.
+ This allows you to construct safe replacement objects.
+
+ Raises:
+ Exception: Raise any exception to reject the callable and abort
deserialization.
+
+ Example:
+ >>> class ReduceChecker(DeserializationPolicy):
+ ... def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ ... # Block subprocess.Popen
+ ... if getattr(callable_obj, '__name__', '') == 'Popen':
+ ... raise ValueError("Popen not allowed")
+ ...
+ ... # Audit all reduce operations
+ ... import logging
+ ... logging.info(f"Reducing with {callable_obj}({args})")
+ ...
+ ... return None # Proceed normally
+
+ Note:
+ This is one of the most critical security hooks, as __reduce__ is
the primary
+ vector for arbitrary code execution in pickle-based attacks.
+
+ This method was previously named check_reduce_callable. That name
is kept
+ as an alias for backward compatibility.
+ """
+ pass
+
+ # Backward compatibility aliases
+ def check_reduce_callable(self, callable_obj, args, **kwargs):
+ """Deprecated: Use intercept_reduce_call instead.
+
+ This method is kept for backward compatibility. New code should use
+ intercept_reduce_call for clarity.
+ """
+ return self.intercept_reduce_call(callable_obj, args, **kwargs)
+
+ def inspect_reduced_object(self, obj, **kwargs):
+ """Inspect and validate an object after __reduce__ protocol
reconstruction.
+
+ This hook is called after an object has been reconstructed using the
__reduce__
+ protocol, allowing final inspection, modification, or replacement.
+
+ When Called
+ -----------
+ - After callable_obj(*args) has been executed
+ - After state has been restored (if applicable)
+ - After list/dict items have been added (if applicable)
+ - Before the object is returned to the deserializer
+
+ Security Use Cases
+ ------------------
+ - Validate reconstructed object's state
+ - Replace objects that pass callable checks but are still unsafe
+ - Sanitize object attributes
+ - Audit reconstructed objects for security logging
+
+ Args:
+ obj (object): The reconstructed object.
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Return None to accept the object as-is.
+ object: Return a different object to replace the original.
+
+ Raises:
+ Exception: Raise any exception to reject the object.
+
+ Example:
+ >>> class PostReduceChecker(DeserializationPolicy):
+ ... def inspect_reduced_object(self, obj, **kwargs):
+ ... # Validate that file handles are read-only
+ ... if isinstance(obj, io.IOBase) and obj.writable():
+ ... raise ValueError("Writable file handles not allowed")
+ ... return None
+
+ Note:
+ This hook provides a last line of defense after reduce
reconstruction.
+
+ This method was previously named check_restored_reduced_object.
That name
+ is kept as an alias for backward compatibility.
+ """
+ pass
+
+ # Backward compatibility alias
+ def check_restored_reduced_object(self, obj, **kwargs):
+ """Deprecated: Use inspect_reduced_object instead.
+
+ This method is kept for backward compatibility. New code should use
+ inspect_reduced_object for clarity.
+ """
+ return self.inspect_reduced_object(obj, **kwargs)
+
+ def intercept_setstate(self, obj, state, **kwargs):
+ """Intercept and validate __setstate__ protocol before state restoration.
+
+ This hook is called when deserializing an object that implements __setstate__,
+ right before the state is restored to the object. It allows inspection and
+ modification of the state dictionary.
+
+ When Called
+ -----------
+ - Before obj.__setstate__(state) is called
+ - After the object has been instantiated (via __new__)
+ - After the state dict has been deserialized
+
+ Security Use Cases
+ ------------------
+ - Inspect state for malicious values
+ - Sanitize or filter dangerous state attributes
+ - Validate state against expected schema
+ - Modify state to enforce security policies
+ - Audit state restoration for logging
+
+ Args:
+ obj (object): The object whose state is about to be restored.
+ state (dict or other): The state to be restored (typically a dict, but can
+ be any object depending on the __setstate__ implementation).
+ **kwargs: Reserved for future extensions.
+
+ Returns:
+ None: Always return None. Modify the state dict in-place if needed.
+
+ Raises:
+ Exception: Raise any exception to reject the state and abort deserialization.
+
+ Example:
+ >>> class SetStateChecker(DeserializationPolicy):
+ ... def intercept_setstate(self, obj, state, **kwargs):
+ ... # Block if state contains dangerous attributes
+ ... if isinstance(state, dict):
+ ... dangerous_attrs = {'__code__', '__globals__', '_eval'}
+ ... if any(attr in state for attr in dangerous_attrs):
+ ... raise ValueError("State contains dangerous attributes")
+ ...
+ ... # Sanitize: remove private attributes. Build the filtered
+ ... # copy first; clearing first would leave nothing to filter.
+ ... kept = {k: v for k, v in state.items()
+ ... if not k.startswith('_')}
+ ... state.clear()
+ ... state.update(kept)
+
+ Note:
+ This hook can modify the state dict in place. Changes will be reflected
+ when __setstate__ is called, since it receives the same state object.
+
+ This method was previously named check_setstate. That name is kept as an
+ alias for backward compatibility.
+ """
+ pass
+
+ # Backward compatibility alias
+ def check_setstate(self, obj, state, **kwargs):
+ """Deprecated: Use intercept_setstate instead.
+
+ This method is kept for backward compatibility. New code should use
+ intercept_setstate for clarity.
+ """
+ return self.intercept_setstate(obj, state, **kwargs)
+
+
+DEFAULT_POLICY = DeserializationPolicy()
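The two reduce hooks documented above bracket object reconstruction: `intercept_reduce_call` runs before `callable_obj(*args)` and may raise or return a substitute object, and `inspect_reduced_object` runs afterwards and may replace the result. The sketch below mirrors that order without importing pyfory. `PolicyStub`, `reconstruct_from_reduce` and `AllowlistReducePolicy` are hypothetical names for illustration only, and state/listitems/dictitems restoration is omitted.

```python
class PolicyStub:
    """Hypothetical stand-in with the same hook names as the diff (not pyfory)."""

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        return None  # None means "let the deserializer call callable_obj(*args)"

    def inspect_reduced_object(self, obj, **kwargs):
        return None  # None means "keep obj as-is"


def reconstruct_from_reduce(policy, callable_obj, args):
    """Hypothetical helper following the hook order in ReduceSerializer.read.

    State, listitems and dictitems restoration are left out for brevity.
    """
    obj = policy.intercept_reduce_call(callable_obj, args)
    if obj is None:
        obj = callable_obj(*args)
    result = policy.inspect_reduced_object(obj)
    return obj if result is None else result


class AllowlistReducePolicy(PolicyStub):
    """Hypothetical policy: only allowlisted callables (by __qualname__) may run."""

    def __init__(self, allowed):
        self.allowed = set(allowed)

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        name = getattr(callable_obj, "__qualname__", repr(callable_obj))
        if name not in self.allowed:
            # Raising aborts deserialization before the callable executes.
            raise ValueError(f"reduce callable {name} is not allowed")
        return None


policy = AllowlistReducePolicy({"complex"})
print(reconstruct_from_reduce(policy, complex, (1, 2)))  # (1+2j)
try:
    reconstruct_from_reduce(policy, eval, ("1 + 1",))
except ValueError as exc:
    print(exc)  # reduce callable eval is not allowed
```

Matching on `__qualname__` alone is kept deliberately simple here; a production allowlist would more likely key on `(__module__, __qualname__)` so a same-named callable from an unexpected module is not accepted.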
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 893aa12c1..14f4175bc 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -28,7 +28,6 @@ import pickle
import types
import typing
from typing import List, Dict
-import warnings
from pyfory.buffer import Buffer
from pyfory.codegen import (
@@ -76,7 +75,6 @@ if ENABLE_FORY_CYTHON_SERIALIZATION:
StringArraySerializer,
SetSerializer,
MapSerializer,
- SubMapSerializer,
EnumSerializer,
SliceSerializer,
)
@@ -100,7 +98,6 @@ else:
StringArraySerializer,
SetSerializer,
MapSerializer,
- SubMapSerializer,
EnumSerializer,
SliceSerializer,
)
@@ -176,6 +173,9 @@ class TypeSerializer(Serializer):
cls = importlib.import_module(module_name)
for name in qualname.split("."):
cls = getattr(cls, name)
+ result = self.fory.policy.validate_class(cls, is_local=False)
+ if result is not None:
+ cls = result
return cls
def _serialize_local_class(self, buffer, cls):
@@ -252,6 +252,9 @@ class TypeSerializer(Serializer):
# Set module and qualname
cls.__module__ = module
cls.__qualname__ = qualname
+ result = fory.policy.validate_class(cls, is_local=True)
+ if result is not None:
+ cls = result
return cls
@@ -266,7 +269,11 @@ class ModuleSerializer(Serializer):
def read(self, buffer):
mod = buffer.read_string()
- return importlib.import_module(mod)
+ mod = importlib.import_module(mod)
+ result = self.fory.policy.validate_module(mod.__name__)
+ if result is not None:
+ mod = result
+ return mod
class MappingProxySerializer(Serializer):
@@ -665,6 +672,9 @@ class DataClassSerializer(Serializer):
stmts = [
f'"""read method for {self.type_}"""',
]
+ if not self.fory.strict:
+ context["checker"] = self.fory.policy
+ stmts.append(f"checker.authorize_instantiation({obj_class})")
# Read hash only in non-compatible mode; in compatible mode, read field count
if not self.fory.compatible:
@@ -815,6 +825,9 @@ class DataClassSerializer(Serializer):
stmts = [
f'"""xread method for {self.type_}"""',
]
+ if not self.fory.strict:
+ context["checker"] = self.fory.policy
+ stmts.append(f"checker.authorize_instantiation({obj_class})")
if not self.fory.compatible:
stmts.extend(
[
@@ -1333,9 +1346,10 @@ class StatefulSerializer(CrossLanguageCompatibleSerializer):
self.fory.serialize_ref(buffer, state)
def read(self, buffer):
- args = self.fory.read_ref(buffer)
- kwargs = self.fory.read_ref(buffer)
- state = self.fory.read_ref(buffer)
+ fory = self.fory
+ args = fory.read_ref(buffer)
+ kwargs = fory.read_ref(buffer)
+ state = fory.read_ref(buffer)
if args or kwargs:
# Case 1: __getnewargs__ was used. Re-create by calling __init__.
@@ -1345,6 +1359,7 @@ class StatefulSerializer(CrossLanguageCompatibleSerializer):
obj = self.cls.__new__(self.cls)
if state:
+ fory.policy.intercept_setstate(obj, state)
obj.__setstate__(state)
return obj
@@ -1449,8 +1464,10 @@ class ReduceSerializer(CrossLanguageCompatibleSerializer):
listitems = reduce_data[4]
dictitems = reduce_data[5] if len(reduce_data) > 5 else None
- # Create the object using the callable and args
- obj = callable_obj(*args)
+ obj = fory.policy.intercept_reduce_call(callable_obj, args)
+ if obj is None:
+ # Create the object using the callable and args
+ obj = callable_obj(*args)
# Restore state if present
if state is not None:
@@ -1470,6 +1487,9 @@ class ReduceSerializer(CrossLanguageCompatibleSerializer):
for key, value in dictitems:
obj[key] = value
+ result = fory.policy.inspect_reduced_object(obj)
+ if result is not None:
+ obj = result
return obj
else:
raise ValueError(f"Invalid reduce data format: {reduce_data[0]}")
@@ -1615,7 +1635,11 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer):
# Handle bound methods
self_obj = self.fory.read_ref(buffer)
method_name = buffer.read_string()
- return getattr(self_obj, method_name)
+ func = getattr(self_obj, method_name)
+ result = self.fory.policy.validate_function(func, is_local=False)
+ if result is not None:
+ func = result
+ return func
if func_type_id == 1:
module = buffer.read_string()
@@ -1623,6 +1647,9 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer):
mod = importlib.import_module(module)
for name in qualname.split("."):
mod = getattr(mod, name)
+ result = self.fory.policy.validate_function(mod, is_local=False)
+ if result is not None:
+ mod = result
return mod
# Regular function or lambda
@@ -1697,6 +1724,9 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer):
for attr_name, attr_value in attrs.items():
setattr(func, attr_name, attr_value)
+ result = self.fory.policy.validate_function(func, is_local=True)
+ if result is not None:
+ func = result
return func
def xwrite(self, buffer, value):
@@ -1732,10 +1762,14 @@ class NativeFuncMethodSerializer(Serializer):
if buffer.read_bool():
module = buffer.read_string()
mod = importlib.import_module(module)
- return getattr(mod, name)
+ func = getattr(mod, name)
else:
obj = self.fory.read_ref(buffer)
- return getattr(obj, name)
+ func = getattr(obj, name)
+ result = self.fory.policy.validate_function(func, is_local=False)
+ if result is not None:
+ func = result
+ return func
class MethodSerializer(Serializer):
@@ -1757,7 +1791,13 @@ class MethodSerializer(Serializer):
instance = self.fory.read_ref(buffer)
method_name = buffer.read_string()
- return getattr(instance, method_name)
+ method = getattr(instance, method_name)
+ cls = method.__self__.__class__
+ is_local = cls.__module__ == "__main__" or "<locals>" in cls.__qualname__
+ result = self.fory.policy.validate_method(method, is_local=is_local)
+ if result is not None:
+ method = result
+ return method
def xwrite(self, buffer, value):
return self.write(buffer, value)
@@ -1796,12 +1836,14 @@ class ObjectSerializer(Serializer):
self.fory.serialize_ref(buffer, field_value)
def read(self, buffer):
+ fory = self.fory
+ fory.policy.authorize_instantiation(self.type_)
obj = self.type_.__new__(self.type_)
- self.fory.ref_resolver.reference(obj)
+ fory.ref_resolver.reference(obj)
num_fields = buffer.read_varuint32()
for _ in range(num_fields):
field_name = buffer.read_string()
- field_value = self.fory.read_ref(buffer)
+ field_value = fory.read_ref(buffer)
setattr(obj, field_name, field_value)
return obj
@@ -1814,17 +1856,6 @@ class ObjectSerializer(Serializer):
return self.read(buffer)
-class ComplexObjectSerializer(DataClassSerializer):
- def __new__(cls, fory, clz):
- warnings.warn(
- "`ComplexObjectSerializer` is deprecated and will be removed in a future version. "
- "Use `DataClassSerializer(fory, clz, xlang=True)` instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- return DataClassSerializer(fory, clz, xlang=True)
-
-
@dataclasses.dataclass
class NonExistEnum:
value: int = -1
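Every read path patched in serializer.py above follows the same contract: call a `validate_*` hook, let an exception reject the value, keep the value when the hook returns `None`, and otherwise use the returned object. `MethodSerializer.read` additionally derives `is_local` from the owning class. The sketch below factors both into helpers; `apply_hook`, `is_local_class`, `AuditPolicy` and `make_local` are hypothetical names rather than pyfory APIs, while the `is_local` expression is copied from the diff.

```python
def apply_hook(hook, value, **kwargs):
    """Hypothetical helper for the 'None keeps, non-None replaces' contract.

    Any exception raised by the hook propagates and rejects the value.
    """
    result = hook(value, **kwargs)
    return value if result is None else result


def is_local_class(cls):
    """Same test MethodSerializer.read applies to method.__self__.__class__."""
    return cls.__module__ == "__main__" or "<locals>" in cls.__qualname__


class AuditPolicy:
    """Hypothetical audit-only policy: records what it sees, never replaces."""

    def __init__(self):
        self.seen = []

    def validate_class(self, cls, is_local, **kwargs):
        self.seen.append((cls.__qualname__, is_local))
        return None


def make_local():
    class Local:
        pass

    return Local


policy = AuditPolicy()
Local = make_local()
cls = apply_hook(policy.validate_class, Local, is_local=is_local_class(Local))
print(cls is Local)  # True
print(policy.seen)   # [('make_local.<locals>.Local', True)]
```

Note that with this heuristic any class whose `__module__` is `"__main__"` also counts as local, so a policy that is stricter about local classes will apply the same treatment to classes defined at the top level of a script.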
diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py
new file mode 100644
index 000000000..3a727534e
--- /dev/null
+++ b/python/pyfory/tests/test_policy.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from pyfory import Fory, DeserializationPolicy
+
+
+class BlockClassPolicy(DeserializationPolicy):
+ """Policy that blocks specific class names from deserialization."""
+
+ def __init__(self, blocked_class_names):
+ self.blocked_class_names = blocked_class_names
+
+ def validate_class(self, cls, is_local, **kwargs):
+ if cls.__name__ in self.blocked_class_names:
+ raise ValueError(f"Class {cls.__name__} is blocked")
+ return None
+
+
+class ReplaceObjectPolicy(DeserializationPolicy):
+ """Policy that replaces deserialized objects from reduce."""
+
+ def __init__(self, replacement_value):
+ self.replacement_value = replacement_value
+
+ def inspect_reduced_object(self, obj, **kwargs):
+ if hasattr(obj, "value"):
+ return self.replacement_value
+ return None
+
+
+class BlockReduceCallPolicy(DeserializationPolicy):
+ """Policy that blocks specific callable invocations during reduce."""
+
+ def __init__(self, blocked_names):
+ self.blocked_names = blocked_names
+
+ def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ if hasattr(callable_obj, "__name__") and callable_obj.__name__ in self.blocked_names:
+ raise ValueError(f"Callable {callable_obj.__name__} is blocked")
+ return None
+
+
+class SanitizeStatePolicy(DeserializationPolicy):
+ """Policy that sanitizes object state during setstate."""
+
+ def intercept_setstate(self, obj, state, **kwargs):
+ if isinstance(state, dict) and "password" in state:
+ state["password"] = "***REDACTED***"
+ return None
+
+
+def test_block_class_type_deserialization():
+ """Test blocking class type (not instance) deserialization."""
+
+ class SafeClass:
+ pass
+
+ class UnsafeClass:
+ pass
+
+ policy = BlockClassPolicy(blocked_class_names=["UnsafeClass"])
+ fory = Fory(ref=True, strict=False, policy=policy)
+
+ # Serialize and deserialize the class type itself (not an instance)
+ safe_data = fory.serialize(SafeClass)
+ result = fory.deserialize(safe_data)
+ assert result.__name__ == "SafeClass"
+
+ # Now test blocking
+ unsafe_data = fory.serialize(UnsafeClass)
+ with pytest.raises(ValueError, match="UnsafeClass is blocked"):
+ fory.deserialize(unsafe_data)
+
+
+def test_block_reduce_call():
+ """Test blocking callable invocations during reduce."""
+
+ class ReducibleClass:
+ def __init__(self, value):
+ self.value = value
+
+ def __reduce__(self):
+ return (ReducibleClass, (self.value,))
+
+ policy = BlockReduceCallPolicy(blocked_names=["ReducibleClass"])
+ fory = Fory(ref=True, strict=False, policy=policy)
+ data = fory.serialize(ReducibleClass(42))
+
+ with pytest.raises(ValueError, match="ReducibleClass is blocked"):
+ fory.deserialize(data)
+
+
+def test_replace_reduced_object():
+ """Test replacing objects created via __reduce__."""
+
+ class ReducibleClass:
+ def __init__(self, value):
+ self.value = value
+
+ def __reduce__(self):
+ return (ReducibleClass, (self.value,))
+
+ policy = ReplaceObjectPolicy(replacement_value="REPLACED")
+ fory = Fory(ref=True, strict=False, policy=policy)
+ data = fory.serialize(ReducibleClass(42))
+
+ result = fory.deserialize(data)
+ assert result == "REPLACED"
+
+
+def test_sanitize_state():
+ """Test sanitizing object state during setstate."""
+
+ class SecretHolder:
+ def __init__(self, username, password):
+ self.username = username
+ self.password = password
+
+ def __getstate__(self):
+ return {"username": self.username, "password": self.password}
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+
+ policy = SanitizeStatePolicy()
+ fory = Fory(ref=False, strict=False, policy=policy)
+ data = fory.serialize(SecretHolder("admin", "secret123"))
+
+ result = fory.deserialize(data)
+ assert result.username == "admin"
+ assert result.password == "***REDACTED***"
+
+
+def test_policy_with_local_class():
+ """Test policy intercepts local class deserialization."""
+
+ def make_local_class():
+ class LocalClass:
+ pass
+
+ return LocalClass
+
+ LocalCls = make_local_class()
+
+ policy = BlockClassPolicy(blocked_class_names=["LocalClass"])
+ fory = Fory(ref=True, strict=False, policy=policy)
+
+ # Serialize the local class type
+ data = fory.serialize(LocalCls)
+
+ with pytest.raises(ValueError, match="LocalClass is blocked"):
+ fory.deserialize(data)
+
+
+def test_policy_with_ref_tracking():
+ """Test policy works with reference tracking."""
+
+ class ReducibleClass:
+ def __init__(self, value):
+ self.value = value
+
+ def __reduce__(self):
+ return (ReducibleClass, (self.value,))
+
+ policy = BlockReduceCallPolicy(blocked_names=["ReducibleClass"])
+ fory = Fory(ref=True, strict=False, policy=policy)
+
+ data = fory.serialize(ReducibleClass(42))
+
+ with pytest.raises(ValueError, match="ReducibleClass is blocked"):
+ fory.deserialize(data)
+
+
+def test_policy_allows_safe_operations():
+ """Test that policy doesn't interfere with safe built-in types."""
+ policy = BlockClassPolicy(blocked_class_names=[])
+ fory = Fory(ref=False, strict=False, policy=policy)
+
+ assert fory.deserialize(fory.serialize(42)) == 42
+ assert fory.deserialize(fory.serialize("test")) == "test"
+ assert fory.deserialize(fory.serialize([1, 2, 3])) == [1, 2, 3]
+
+
+def test_multiple_policy_hooks():
+ """Test policy with multiple hooks working together."""
+
+ class MultiHookPolicy(DeserializationPolicy):
+ def __init__(self):
+ self.hooks_called = []
+
+ def validate_class(self, cls, is_local, **kwargs):
+ self.hooks_called.append(("validate_class", cls.__name__))
+ return None
+
+ def intercept_reduce_call(self, callable_obj, args, **kwargs):
+ if hasattr(callable_obj, "__name__"):
+ self.hooks_called.append(("intercept_reduce_call", callable_obj.__name__))
+ return None
+
+ def inspect_reduced_object(self, obj, **kwargs):
+ self.hooks_called.append(("inspect_reduced_object", type(obj).__name__))
+ return None
+
+ class TestClass:
+ def __init__(self, value):
+ self.value = value
+
+ def __reduce__(self):
+ return (TestClass, (self.value,))
+
+ policy = MultiHookPolicy()
+ fory = Fory(ref=True, strict=False, policy=policy)
+
+ data = fory.serialize(TestClass(42))
+ result = fory.deserialize(data)
+
+ # All hooks should have been called
+ assert ("intercept_reduce_call", "TestClass") in policy.hooks_called
+ assert ("inspect_reduced_object", "TestClass") in policy.hooks_called
+ assert result.value == 42
+
+
+def test_policy_with_nested_reduce():
+ """Test policy handles nested objects with __reduce__."""
+
+ class Inner:
+ def __init__(self, value):
+ self.value = value
+
+ def __reduce__(self):
+ return (Inner, (self.value,))
+
+ class Outer:
+ def __init__(self, inner):
+ self.inner = inner
+
+ def __reduce__(self):
+ return (Outer, (self.inner,))
+
+ policy = BlockReduceCallPolicy(blocked_names=["Inner"])
+ fory = Fory(ref=True, strict=False, policy=policy)
+
+ data = fory.serialize(Outer(Inner(42)))
+
+ with pytest.raises(ValueError, match="Inner is blocked"):
+ fory.deserialize(data)
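`SanitizeStatePolicy` in the tests works because `intercept_setstate` edits the state dict in place, and `StatefulSerializer.read` then passes that same dict to `__setstate__` (the hook's return value is discarded). The sketch below reproduces that order without pyfory and shows the copy-then-clear pattern needed when filtering keys, since rebinding `state` inside the hook would not be seen by the caller. `PrivateKeyFilter`, `Account` and `restore` are hypothetical names for illustration.

```python
class PrivateKeyFilter:
    """Hypothetical policy stand-in: drops keys starting with '_' from dict states."""

    def intercept_setstate(self, obj, state, **kwargs):
        if isinstance(state, dict):
            # Build the filtered copy first; clearing first would leave
            # nothing to filter.
            kept = {k: v for k, v in state.items() if not k.startswith("_")}
            state.clear()
            state.update(kept)
        return None  # only the in-place mutation matters to the caller


class Account:
    def __setstate__(self, state):
        self.__dict__.update(state)


def restore(policy, cls, state):
    """Hypothetical helper following the __new__ branch of StatefulSerializer.read."""
    obj = cls.__new__(cls)
    if state:
        policy.intercept_setstate(obj, state)
        obj.__setstate__(state)
    return obj


acct = restore(PrivateKeyFilter(), Account, {"name": "alice", "_token": "t0k3n"})
print(vars(acct))  # {'name': 'alice'}
```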