This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 0048399d2 fix(python): make policy validators validation-only (#3708)
0048399d2 is described below
commit 0048399d2b8658eb479488b7c2d1638785b03cab
Author: Shawn Yang <[email protected]>
AuthorDate: Wed May 27 08:34:05 2026 +0800
fix(python): make policy validators validation-only (#3708)
## Why?
## What does this PR do?
## Related issues
## AI Contribution Checklist
- [ ] Substantial AI assistance was used in this PR: `yes` / `no`
- [ ] If `yes`, I included a completed [AI Contribution
Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
in this PR description and the required `AI Usage Disclosure`.
- [ ] If `yes`, my PR description includes the required `ai_review`
summary and screenshot evidence of the final clean AI review results
from both fresh reviewers on the current PR diff or current HEAD after
the latest code changes.
## Does this PR introduce any user-facing change?
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
---
docs/guide/python/configuration.md | 7 +-
python/README.md | 5 +-
python/pyfory/meta/typedef_decoder.py | 4 +-
python/pyfory/policy.py | 71 +++++++--------------
python/pyfory/serializer.py | 70 +++++++++-----------
python/pyfory/tests/test_policy.py | 116 +++++++++++++++++++++++++++++-----
python/pyfory/type_util.py | 29 ++-------
7 files changed, 166 insertions(+), 136 deletions(-)
diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md
index a90219238..7a16bbee0 100644
--- a/docs/guide/python/configuration.md
+++ b/docs/guide/python/configuration.md
@@ -211,7 +211,6 @@ class SafeDeserializationPolicy(DeserializationPolicy):
def validate_class(self, cls, is_local, **kwargs):
if cls.__module__ in dangerous_modules:
raise ValueError(f"Blocked dangerous class:
{cls.__module__}.{cls.__name__}")
- return None
def intercept_reduce_call(self, callable_obj, args, **kwargs):
if getattr(callable_obj, "__name__", "") == "Popen":
@@ -229,11 +228,15 @@ fory = pyfory.Fory(xlang=False, ref=True, strict=False, policy=policy)
Available policy hooks include:
+Reference validation hooks reject by raising exceptions and otherwise leave deserialized references
+unchanged.
+
| Hook | Description
|
| -------------------------------------------- |
--------------------------------------------------- |
| `validate_class(cls, is_local)` | Validate or block class types
|
-| `validate_module(module, is_local)` | Validate or block module
imports |
+| `validate_module(module_name, is_local)` | Validate or block module
imports |
| `validate_function(func, is_local)` | Validate or block function
references |
+| `validate_method(method, is_local)` | Validate or block method
references |
| `intercept_reduce_call(callable_obj, args)` | Intercept `__reduce__`
invocations |
| `inspect_reduced_object(obj)` | Inspect or replace objects
created via `__reduce__` |
| `intercept_setstate(obj, state)` | Sanitize state before
`__setstate__` |
diff --git a/python/README.md b/python/README.md
index d9ba9c707..011e78b0c 100644
--- a/python/README.md
+++ b/python/README.md
@@ -1119,7 +1119,6 @@ class SafeDeserializationPolicy(DeserializationPolicy):
# Block dangerous modules
if cls.__module__ in dangerous_modules:
raise ValueError(f"Blocked dangerous class:
{cls.__module__}.{cls.__name__}")
- return None
def intercept_reduce_call(self, callable_obj, args, **kwargs):
# Block specific callable invocations during __reduce__
@@ -1144,9 +1143,11 @@ result = fory.deserialize(data) # Policy hooks will be invoked
**Available Policy Hooks:**
+- Reference validation hooks reject by raising exceptions and otherwise leave deserialized references unchanged.
- `validate_class(cls, is_local)` - Validate/block class types during
deserialization
-- `validate_module(module, is_local)` - Validate/block module imports
+- `validate_module(module_name, is_local)` - Validate/block module imports
- `validate_function(func, is_local)` - Validate/block function references
+- `validate_method(method, is_local)` - Validate/block method references
- `intercept_reduce_call(callable_obj, args)` - Intercept `__reduce__`
invocations
- `inspect_reduced_object(obj)` - Inspect/replace objects created via
`__reduce__`
- `intercept_setstate(obj, state)` - Sanitize state before `__setstate__`
diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py
index 3f2b224c1..2c6a89d60 100644
--- a/python/pyfory/meta/typedef_decoder.py
+++ b/python/pyfory/meta/typedef_decoder.py
@@ -184,9 +184,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef:
type_cls = make_dataclass(class_name, field_definitions)
policy = getattr(resolver, "policy", None)
if policy is not None:
- result = policy.validate_class(type_cls, is_local=True)
- if result is not None:
- type_cls = result
+ policy.validate_class(type_cls, is_local=True)
elif type_cls is None:
raise ValueError(f"TypeDef {name} is not registered")
diff --git a/python/pyfory/policy.py b/python/pyfory/policy.py
index 5070821c0..b8f8ea2e4 100644
--- a/python/pyfory/policy.py
+++ b/python/pyfory/policy.py
@@ -48,7 +48,7 @@ class DeserializationPolicy:
| __reduce__ interception | no |
intercept_reduce_call() |
| Post-reduce inspection | no |
inspect_reduced_object() |
| __setstate__ interception | no | intercept_setstate()
|
- | Object replacement | no | return from
validators |
+ | Object replacement | no |
inspect_reduced_object() |
| State sanitization | no | modify in-place
|
| Local class/function | no | is_local flag
|
+---------------------------+----------------------+----------------------------+
@@ -87,8 +87,8 @@ class DeserializationPolicy:
This DeserializationPolicy interface allows users to implement custom
security policies by
subclassing and overriding specific hook methods. Each hook is called at a
critical
- point during deserialization, allowing inspection, replacement, or
rejection of
- dangerous constructs.
+ point during deserialization, allowing validation hooks to inspect or
reject
+ dangerous constructs and interceptor hooks to control protocol operations.
Hook Categories
---------------
@@ -98,7 +98,7 @@ class DeserializationPolicy:
2. **Reference Validation Hooks** (Validators)
- Validate deserialized type/function/module references
- - Return None to accept original, return object to replace, raise
exception to block,
+ - Raise exception to block, otherwise return normally
3. **Protocol Interception Hooks** (Interceptors)
- Intercept pickle protocol operations (__reduce__, __setstate__)
@@ -109,17 +109,15 @@ class DeserializationPolicy:
>>> class SafeDeserializationPolicy(DeserializationPolicy):
... ALLOWED_MODULES = {'builtins', 'datetime', 'decimal'}
...
- ... def validate_module(self, module_name, **kwargs):
+ ... def validate_module(self, module_name, is_local, **kwargs):
... # Reject imports from disallowed modules
... if module_name.split('.')[0] not in self.ALLOWED_MODULES:
... raise ValueError(f"Module {module_name} is not allowed")
- ... return None # Accept
...
... def validate_class(self, cls, is_local, **kwargs):
... # Reject dangerous built-in classes
... if cls.__name__ in ('eval', 'exec', 'compile'):
... raise ValueError(f"Class {cls} is forbidden")
- ... return None # Accept
...
... def intercept_reduce_call(self, callable_obj, args, **kwargs):
... # Log all __reduce__ callables for audit
@@ -201,7 +199,7 @@ class DeserializationPolicy:
This hook is called after a class reference has been deserialized
(either by
importing from a module or reconstructing a local class), but before
it is used.
- It allows inspection, replacement, or rejection of class references.
+ It allows inspection or rejection of class references.
When Called
-----------
@@ -212,9 +210,8 @@ class DeserializationPolicy:
Security Use Cases
------------------
- Block dangerous classes (subprocess.Popen, os.system, etc.)
- - Replace untrusted classes with safe alternatives
- Validate that local classes match expected signatures
- - Implement class versioning or adaptation logic
+ - Audit class imports for security logging
Args:
cls (type): The deserialized class object.
@@ -223,29 +220,20 @@ class DeserializationPolicy:
class from an importable module.
**kwargs: Reserved for future extensions.
- Returns:
- None: Return None to accept the class as-is.
- type: Return a different class to replace the original. The
replacement
- class will be used instead for deserialization.
-
Raises:
Exception: Raise any exception to reject the class and abort
deserialization.
Example:
- >>> class ClassAdapter(DeserializationPolicy):
+ >>> class ClassChecker(DeserializationPolicy):
... def validate_class(self, cls, is_local, **kwargs):
- ... # Map a serialized class name to the current class.
- ... if cls.__name__ == 'ArchivedUserClass':
- ... return NewUserClass
... # Block dangerous classes
... if cls.__module__ == 'subprocess':
... raise ValueError("subprocess classes not allowed")
- ... return None # Accept
Note:
`check_class` is an alias for this hook.
"""
- pass
+ return None
def validate_function(self, func, is_local: bool, **kwargs):
"""Validate a deserialized function reference.
@@ -263,7 +251,6 @@ class DeserializationPolicy:
------------------
- Block dangerous built-in functions (eval, exec, compile, __import__)
- Validate that reconstructed functions have expected signatures
- - Replace untrusted functions with safe alternatives
- Audit function imports for security logging
Args:
@@ -272,10 +259,6 @@ class DeserializationPolicy:
within a function scope), False if it's a global
function.
**kwargs: Reserved for future extensions.
- Returns:
- None: Return None to accept the function as-is.
- function: Return a different function to replace the original.
-
Raises:
Exception: Raise any exception to reject the function.
@@ -286,12 +269,11 @@ class DeserializationPolicy:
... def validate_function(self, func, is_local, **kwargs):
... if func.__name__ in self.BLOCKED:
... raise ValueError(f"Function {func.__name__} is
forbidden")
- ... return None
Note:
`check_function` is an alias for this hook.
"""
- pass
+ return None
def validate_method(self, method, is_local: bool, **kwargs):
"""Validate a deserialized method reference.
@@ -309,17 +291,13 @@ class DeserializationPolicy:
------------------
- Validate that methods belong to expected classes
- Block methods that could perform dangerous operations
- - Replace methods with safer alternatives
+ - Audit method references for security logging
Args:
method (method): The deserialized bound method object.
is_local (bool): True if the method's class is local, False if
global.
**kwargs: Reserved for future extensions.
- Returns:
- None: Return None to accept the method as-is.
- method: Return a different method to replace the original.
-
Raises:
Exception: Raise any exception to reject the method.
@@ -329,39 +307,35 @@ class DeserializationPolicy:
... # Block methods from dangerous classes
... if method.__self__.__class__.__name__ == 'FileRemover':
... raise ValueError("FileRemover methods not allowed")
- ... return None
Note:
`check_method` is an alias for this hook.
"""
- pass
+ return None
- def validate_module(self, module_name: str, **kwargs):
+ def validate_module(self, module_name: str, *, is_local: bool, **kwargs):
"""Validate a deserialized module reference.
- This hook is called after a module has been imported during
deserialization,
- but before it is used.
+ This hook is called before a module is imported during deserialization.
When Called
-----------
- - After importing modules via importlib.import_module()
- - Before the module is stored or its contents accessed
+ - Before importing modules via importlib.import_module()
+ - Before the module is stored or its contents are accessed
Security Use Cases
------------------
- Whitelist/blacklist modules by name or prefix
- Prevent imports of system modules (os, subprocess, sys, etc.)
- - Replace modules with safe alternatives or mocks
- Audit module imports for security logging
Args:
- module_name (str): The name of the imported module (e.g.,
'os.path').
+ module_name (str): The name of the module to import (e.g.,
'os.path').
+ is_local (bool): True if the reference being resolved is local
(defined
+ in __main__ or within a function/method scope),
False
+ otherwise.
**kwargs: Reserved for future extensions.
- Returns:
- None: Return None to accept the module as-is.
- module: Return a different module object to replace the original.
-
Raises:
Exception: Raise any exception to reject the module import.
@@ -369,16 +343,15 @@ class DeserializationPolicy:
>>> class ModuleWhitelistChecker(DeserializationPolicy):
... ALLOWED = {'builtins', 'datetime', 'decimal',
'collections'}
...
- ... def validate_module(self, module_name, **kwargs):
+ ... def validate_module(self, module_name, is_local, **kwargs):
... root = module_name.split('.')[0]
... if root not in self.ALLOWED:
... raise ValueError(f"Module {module_name} not
whitelisted")
- ... return None
Note:
`check_module` is an alias for this hook.
"""
- pass
+ return None
#
============================================================================
# Protocol Interception Hooks (Interceptors)
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 2736599b5..e2cdf7994 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -46,23 +46,18 @@ _WINDOWS = os.name == "nt"
from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION
-def _import_validated_module(policy, module_name):
- result = policy.validate_module(module_name)
- if result is not None:
- if isinstance(result, types.ModuleType):
- return result
- assert isinstance(result, str), f"validate_module must return module,
str, or None, got {type(result)}"
- module_name = result
+def _import_validated_module(policy, module_name, is_local=False):
+ policy.validate_module(module_name, is_local=is_local)
return importlib.import_module(module_name)
-def _resolve_validated_module_attr(policy, module_name, attr_name):
- module = _import_validated_module(policy, module_name)
+def _resolve_validated_module_attr(policy, module_name, attr_name,
is_local=False):
+ module = _import_validated_module(policy, module_name, is_local=is_local)
return getattr(module, attr_name)
def _resolve_validated_module_qualname(policy, module_name, qualname):
- obj = _import_validated_module(policy, module_name)
+ obj = _import_validated_module(policy, module_name,
is_local=_is_local_qualname(module_name, qualname))
for name in qualname.split("."):
obj = getattr(obj, name)
return obj
@@ -111,21 +106,15 @@ def _is_bound_method_value(obj):
def _validate_function_value(policy, func, is_local):
if isinstance(func, type):
- result = policy.validate_class(func, is_local=is_local)
- if result is not None:
- func = result
+ policy.validate_class(func, is_local=is_local)
if isinstance(func, type):
raise TypeError(f"Function serializer resolved class
{func.__module__}.{func.__qualname__}")
if _is_bound_method_value(func):
- result = policy.validate_method(func, is_local=is_local)
- if result is not None:
- func = result
+ policy.validate_method(func, is_local=is_local)
return func
if not callable(func):
raise TypeError(f"Function serializer resolved non-callable object
{func!r}")
- result = policy.validate_function(func, is_local=is_local)
- if result is not None:
- func = result
+ policy.validate_function(func, is_local=is_local)
return func
@@ -159,9 +148,7 @@ def _resolve_validated_bound_method(policy, obj, method_name, is_local):
if policy is DEFAULT_POLICY:
return getattr(obj, method_name)
method = _bind_static_method(obj, method_name)
- result = policy.validate_method(method, is_local=is_local)
- if result is not None:
- method = result
+ policy.validate_method(method, is_local=is_local)
return method
@@ -1214,15 +1201,12 @@ class ReduceSerializer(Serializer):
self._getnewargs = getattr(cls, "__getnewargs__", None)
def _validate_global_object(self, policy, obj):
- result = None
if isinstance(obj, type):
- result = policy.validate_class(obj, is_local=_is_local_class(obj))
+ policy.validate_class(obj, is_local=_is_local_class(obj))
elif _is_bound_method_value(obj):
- result = policy.validate_method(obj,
is_local=_is_local_callable(obj))
+ policy.validate_method(obj, is_local=_is_local_callable(obj))
elif isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
- result = policy.validate_function(obj,
is_local=_is_local_callable(obj))
- if result is not None:
- obj = result
+ policy.validate_function(obj, is_local=_is_local_callable(obj))
return obj
def _resolve_global_name(self, read_context, global_name):
@@ -1232,7 +1216,12 @@ class ReduceSerializer(Serializer):
else:
module_name, obj_name = "builtins", global_name
try:
- obj = _resolve_validated_module_attr(policy, module_name, obj_name)
+ obj = _resolve_validated_module_attr(
+ policy,
+ module_name,
+ obj_name,
+ is_local=_is_local_qualname(module_name, obj_name),
+ )
except AttributeError:
raise ValueError(f"Cannot resolve global name: {global_name}")
return self._validate_global_object(policy, obj)
@@ -1370,9 +1359,7 @@ class TypeSerializer(Serializer):
module_name = read_context.read_string()
qualname = read_context.read_string()
cls = _resolve_validated_module_qualname(read_context.policy,
module_name, qualname)
- result = read_context.policy.validate_class(cls,
is_local=_is_local_class(cls))
- if result is not None:
- cls = result
+ read_context.policy.validate_class(cls, is_local=_is_local_class(cls))
return cls
def _serialize_local_class(self, write_context, cls):
@@ -1422,9 +1409,7 @@ class TypeSerializer(Serializer):
read_context.policy.authorize_instantiation(type, module=module,
qualname=qualname, bases=bases)
cls = type(name, bases, {})
read_context.set_read_ref(ref_id, cls)
- result = read_context.policy.validate_class(cls, is_local=True)
- if result is not None:
- cls = result
+ read_context.policy.validate_class(cls, is_local=True)
num_class_methods = read_context.read_var_uint32()
_check_collection_size(read_context, num_class_methods, "local class
method")
@@ -1440,9 +1425,7 @@ class TypeSerializer(Serializer):
# Set module and qualname
cls.__module__ = module
cls.__qualname__ = qualname
- result = read_context.policy.validate_class(cls, is_local=True)
- if result is not None:
- cls = result
+ read_context.policy.validate_class(cls, is_local=True)
return cls
@@ -1457,7 +1440,7 @@ class ModuleSerializer(Serializer):
def read(self, read_context):
mod_name = read_context.read_string()
- return _import_validated_module(read_context.policy, mod_name)
+ return _import_validated_module(read_context.policy, mod_name,
is_local=_is_local_qualname(mod_name, ""))
class MappingProxySerializer(Serializer):
@@ -1617,7 +1600,7 @@ class FunctionSerializer(Serializer):
module = read_context.read_string()
qualname = read_context.read_string()
- mod = _import_validated_module(read_context.policy, module)
+ mod = _import_validated_module(read_context.policy, module,
is_local=_is_local_qualname(module, qualname))
name = qualname.rsplit(".")[-1]
marshalled_code = read_context.read_bytes_and_size()
@@ -1699,7 +1682,12 @@ class NativeFuncMethodSerializer(Serializer):
name = read_context.read_string()
if read_context.read_bool():
module = read_context.read_string()
- func = _resolve_validated_module_attr(read_context.policy, module,
name)
+ func = _resolve_validated_module_attr(
+ read_context.policy,
+ module,
+ name,
+ is_local=_is_local_qualname(module, name),
+ )
func = _validate_function_value(read_context.policy, func,
is_local=_is_local_callable(func))
else:
obj = read_context.read_ref()
diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py
index 9eb7d55f0..dc0737369 100644
--- a/python/pyfory/tests/test_policy.py
+++ b/python/pyfory/tests/test_policy.py
@@ -30,6 +30,10 @@ def policy_global_function():
return "safe"
+def policy_replacement_function():
+ return "replacement"
+
+
class PolicyMethodHolder:
def run(self):
return "safe"
@@ -491,26 +495,26 @@ def test_validate_module():
import json
import collections
- # Test 1: Return module object directly
class ReturnModulePolicy(DeserializationPolicy):
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
+ assert not is_local
return collections
fory1 = Fory(xlang=False, ref=True, strict=False,
policy=ReturnModulePolicy())
data = fory1.serialize(json)
- assert fory1.deserialize(data) is collections
+ assert fory1.deserialize(data) is json
- # Test 2: Return string to redirect import
class RedirectPolicy(DeserializationPolicy):
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
+ assert not is_local
return "collections" if module_name == "json" else None
fory2 = Fory(xlang=False, ref=True, strict=False, policy=RedirectPolicy())
- assert fory2.deserialize(fory2.serialize(json)).__name__ == "collections"
+ assert fory2.deserialize(fory2.serialize(json)).__name__ == "json"
- # Test 3: Raise to block module
class BlockPolicy(DeserializationPolicy):
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
+ assert not is_local
raise ValueError(f"Module {module_name} blocked")
fory3 = Fory(xlang=False, ref=True, strict=False, policy=BlockPolicy())
@@ -518,6 +522,63 @@ def test_validate_module():
fory3.deserialize(fory3.serialize(json))
+def test_validator_returns_ignored():
+ import json
+ import collections
+
+ class ReplacementClass:
+ pass
+
+ class ReturnPolicy(DeserializationPolicy):
+ def validate_module(self, module_name, is_local, **kwargs):
+ assert not is_local
+ return collections
+
+ def validate_class(self, cls, is_local, **kwargs):
+ return ReplacementClass
+
+ def validate_function(self, func, is_local, **kwargs):
+ return policy_replacement_function
+
+ def validate_method(self, method, is_local, **kwargs):
+ return policy_replacement_function
+
+ policy = ReturnPolicy()
+ fory = Fory(xlang=False, ref=True, strict=False, policy=policy)
+ assert fory.deserialize(fory.serialize(json)) is json
+ assert fory.deserialize(fory.serialize(PolicyGlobalClass)) is
PolicyGlobalClass
+ assert fory.deserialize(fory.serialize(policy_global_function)) is
policy_global_function
+
+ serializer = FunctionSerializer(fory.type_resolver,
type(policy_global_function))
+ read_context = FakeReadContext(policy, [1, __name__,
"policy_global_bound_method"])
+ assert serializer._deserialize_function(read_context) is
policy_global_bound_method
+
+
+def test_local_class_return_ignored():
+ class SafeClass:
+ @classmethod
+ def run(cls):
+ return "safe"
+
+ def make_payload_class():
+ class PayloadClass:
+ @classmethod
+ def run(cls):
+ return "payload"
+
+ return PayloadClass
+
+ class ReturnClassPolicy(DeserializationPolicy):
+ def validate_class(self, cls, is_local, **kwargs):
+ return SafeClass if is_local else None
+
+ fory = Fory(xlang=False, ref=True, strict=False,
policy=ReturnClassPolicy())
+ decoded = fory.deserialize(fory.serialize(make_payload_class()))
+ assert decoded is not SafeClass
+ assert decoded.run() == "payload"
+ assert SafeClass.run() == "safe"
+
+
def test_type_deserialization_validates_module():
"""Test validate_module policy hook for global class deserialization."""
import subprocess
@@ -525,9 +586,11 @@ def test_type_deserialization_validates_module():
class BlockModulePolicy(DeserializationPolicy):
def __init__(self):
self.validate_module_calls = 0
+ self.is_local_values = []
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ self.is_local_values.append(is_local)
if module_name == "subprocess":
raise ValueError("subprocess blocked")
return None
@@ -537,6 +600,7 @@ def test_type_deserialization_validates_module():
with pytest.raises(ValueError, match="subprocess blocked"):
fory.deserialize(fory.serialize(subprocess.Popen))
assert policy.validate_module_calls == 1
+ assert policy.is_local_values == [False]
def test_native_bound_method_uses_validate_method():
@@ -818,9 +882,11 @@ def test_global_function_deserialization_validates_module():
class BlockModulePolicy(DeserializationPolicy):
def __init__(self):
self.validate_module_calls = 0
+ self.is_local_values = []
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ self.is_local_values.append(is_local)
if module_name == policy_global_function.__module__:
raise ValueError("function module blocked")
return None
@@ -830,6 +896,7 @@ def test_global_function_deserialization_validates_module():
with pytest.raises(ValueError, match="function module blocked"):
fory.deserialize(fory.serialize(policy_global_function))
assert policy.validate_module_calls == 1
+ assert policy.is_local_values == [False]
def test_local_function_deserialization_validates_module():
@@ -841,9 +908,11 @@ def test_local_function_deserialization_validates_module():
class BlockModulePolicy(DeserializationPolicy):
def __init__(self):
self.validate_module_calls = 0
+ self.is_local_values = []
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ self.is_local_values.append(is_local)
if module_name == local_function.__module__:
raise ValueError("local function module blocked")
return None
@@ -853,6 +922,7 @@ def test_local_function_deserialization_validates_module():
with pytest.raises(ValueError, match="local function module blocked"):
fory.deserialize(fory.serialize(local_function))
assert policy.validate_module_calls == 1
+ assert policy.is_local_values == [True]
def test_native_function_deserialization_validates_module():
@@ -862,9 +932,11 @@ def test_native_function_deserialization_validates_module():
class BlockModulePolicy(DeserializationPolicy):
def __init__(self):
self.validate_module_calls = 0
+ self.is_local_values = []
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ self.is_local_values.append(is_local)
if module_name == "time":
raise ValueError("time blocked")
return None
@@ -874,6 +946,7 @@ def test_native_function_deserialization_validates_module():
with pytest.raises(ValueError, match="time blocked"):
fory.deserialize(fory.serialize(time.time))
assert policy.validate_module_calls == 1
+ assert policy.is_local_values == [False]
def test_type_metadata_load_validates_module():
@@ -882,9 +955,11 @@ def test_type_metadata_load_validates_module():
class BlockModulePolicy(DeserializationPolicy):
def __init__(self):
self.validate_module_calls = 0
+ self.is_local_values = []
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ self.is_local_values.append(is_local)
if module_name == "subprocess":
raise ValueError("subprocess blocked")
return None
@@ -902,6 +977,7 @@ def test_type_metadata_load_validates_module():
with pytest.raises(ValueError, match="subprocess blocked"):
resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes)
assert policy.validate_module_calls == 1
+ assert policy.is_local_values == [False]
def test_type_metadata_load_validates_class():
@@ -942,9 +1018,11 @@ def test_reduce_global_name_validates_module():
class BlockModulePolicy(DeserializationPolicy):
def __init__(self):
self.validate_module_calls = 0
+ self.is_local_values = []
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ self.is_local_values.append(is_local)
if module_name == "subprocess":
raise ValueError(f"Module {module_name} blocked")
return None
@@ -954,6 +1032,7 @@ def test_reduce_global_name_validates_module():
with pytest.raises(ValueError, match="subprocess blocked"):
fory.deserialize(fory.serialize(GlobalNamePayload()))
assert policy.validate_module_calls == 1
+ assert policy.is_local_values == [False]
def test_reduce_global_name_validates_class():
@@ -968,8 +1047,9 @@ def test_reduce_global_name_validates_class():
self.validate_module_calls = 0
self.validate_class_calls = 0
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ assert not is_local
return None
def validate_class(self, cls, is_local, **kwargs):
@@ -998,8 +1078,9 @@ def test_reduce_global_name_validates_function():
self.validate_module_calls = 0
self.validate_function_calls = 0
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ assert not is_local
return None
def validate_function(self, func, is_local, **kwargs):
@@ -1029,8 +1110,9 @@ def test_reduce_global_method_resolution_uses_validate_method():
self.validate_method_calls = 0
self.validate_function_calls = 0
- def validate_module(self, module_name, **kwargs):
+ def validate_module(self, module_name, is_local, **kwargs):
self.validate_module_calls += 1
+ assert not is_local
return None
def validate_method(self, method, is_local, **kwargs):
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index d0e8b02a9..ca1a750dc 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -18,7 +18,6 @@
import dataclasses
import importlib
import inspect
-import types
import typing
from typing import TypeVar
@@ -367,34 +366,20 @@ def qualified_class_name(cls):
def load_class(classname: str, policy=None):
mod_name, cls_name = classname.rsplit("#", 1)
+ is_local = mod_name == "__main__" or "<locals>" in cls_name
if policy is not None:
- result = policy.validate_module(mod_name)
- if result is not None:
- if isinstance(result, str):
- mod_name = result
- mod = None
- else:
- assert isinstance(result, types.ModuleType), f"validate_module
must return module, str, or None, got {type(result)}"
- mod = result
- else:
- mod = None
- else:
- mod = None
- if mod is None:
- try:
- mod = importlib.import_module(mod_name)
- except ImportError as ex:
- raise Exception(f"Can't import module {mod_name}") from ex
+ policy.validate_module(mod_name, is_local=is_local)
+ try:
+ mod = importlib.import_module(mod_name)
+ except ImportError as ex:
+ raise Exception(f"Can't import module {mod_name}") from ex
try:
classes = cls_name.split(".")
cls = getattr(mod, classes.pop(0))
while classes:
cls = getattr(cls, classes.pop(0))
if policy is not None:
- is_local = cls.__module__ == "__main__" or "<locals>" in
cls.__qualname__
- result = policy.validate_class(cls, is_local=is_local)
- if result is not None:
- cls = result
+ policy.validate_class(cls, is_local=is_local)
return cls
except AttributeError as ex:
raise Exception(f"Can't import class {cls_name} from module
{mod_name}") from ex
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]