This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 0048399d2 fix(python): make policy validators validation-only (#3708)
0048399d2 is described below

commit 0048399d2b8658eb479488b7c2d1638785b03cab
Author: Shawn Yang <[email protected]>
AuthorDate: Wed May 27 08:34:05 2026 +0800

    fix(python): make policy validators validation-only (#3708)
    
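    ## Why?
    
    Policy validators (`validate_class`, `validate_function`, `validate_method`,
    `validate_module`) previously doubled as a substitution mechanism: a non-None
    return value silently replaced the deserialized class/function/method, and
    `validate_module` could return a module object or a module name to redirect
    the import. Validators should only accept or reject references.
    
    ## What does this PR do?
    
    - Ignore validator return values in the serializer and the typedef decoder; a
      validator rejects a reference only by raising an exception.
    - Pass an `is_local` keyword argument to `validate_module`, and correct its
      docstring: it is called before the module is imported, not after.
    - Update the policy docstrings, the Python README, and the configuration guide
      (including the previously missing `validate_method` entry).
    - Add tests showing that validator return values are ignored and checking the
      `is_local` values passed to `validate_module`.
    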
    ## Related issues
    
    
    
    ## AI Contribution Checklist
    
    
    
    - [ ] Substantial AI assistance was used in this PR: `yes` / `no`
    - [ ] If `yes`, I included a completed [AI Contribution
    Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
    in this PR description and the required `AI Usage Disclosure`.
    - [ ] If `yes`, my PR description includes the required `ai_review`
    summary and screenshot evidence of the final clean AI review results
    from both fresh reviewers on the current PR diff or current HEAD after
    the latest code changes.
    
    
    
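    ## Does this PR introduce any user-facing change?
    
    Yes. Values returned from `validate_*` hooks are now ignored. Policies that
    relied on returning a replacement class, function, method, or module must use
    another mechanism, such as `inspect_reduced_object()` for objects created via
    `__reduce__`. `validate_module` now also receives an `is_local` keyword
    argument.
    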
    - [ ] Does this PR introduce any public API change?
    - [ ] Does this PR introduce any binary protocol compatibility change?
    
    ## Benchmark
---
 docs/guide/python/configuration.md    |   7 +-
 python/README.md                      |   5 +-
 python/pyfory/meta/typedef_decoder.py |   4 +-
 python/pyfory/policy.py               |  71 +++++++--------------
 python/pyfory/serializer.py           |  70 +++++++++-----------
 python/pyfory/tests/test_policy.py    | 116 +++++++++++++++++++++++++++++-----
 python/pyfory/type_util.py            |  29 ++-------
 7 files changed, 166 insertions(+), 136 deletions(-)

diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md
index a90219238..7a16bbee0 100644
--- a/docs/guide/python/configuration.md
+++ b/docs/guide/python/configuration.md
@@ -211,7 +211,6 @@ class SafeDeserializationPolicy(DeserializationPolicy):
     def validate_class(self, cls, is_local, **kwargs):
         if cls.__module__ in dangerous_modules:
             raise ValueError(f"Blocked dangerous class: 
{cls.__module__}.{cls.__name__}")
-        return None
 
     def intercept_reduce_call(self, callable_obj, args, **kwargs):
         if getattr(callable_obj, "__name__", "") == "Popen":
@@ -229,11 +228,15 @@ fory = pyfory.Fory(xlang=False, ref=True, strict=False, policy=policy)
 
 Available policy hooks include:
 
+Reference validation hooks (`validate_*`) reject a reference by raising an exception. Their return
+values are ignored, so they cannot replace the deserialized reference.
+
 | Hook                                         | Description                   
                      |
 | -------------------------------------------- | 
--------------------------------------------------- |
 | `validate_class(cls, is_local)`              | Validate or block class types 
                      |
-| `validate_module(module, is_local)`          | Validate or block module 
imports                    |
+| `validate_module(module_name, is_local)`     | Validate or block module 
imports                    |
 | `validate_function(func, is_local)`          | Validate or block function 
references               |
+| `validate_method(method, is_local)`          | Validate or block method 
references                 |
 | `intercept_reduce_call(callable_obj, args)`  | Intercept `__reduce__` 
invocations                  |
 | `inspect_reduced_object(obj)`                | Inspect or replace objects 
created via `__reduce__` |
 | `intercept_setstate(obj, state)`             | Sanitize state before 
`__setstate__`                |
diff --git a/python/README.md b/python/README.md
index d9ba9c707..011e78b0c 100644
--- a/python/README.md
+++ b/python/README.md
@@ -1119,7 +1119,6 @@ class SafeDeserializationPolicy(DeserializationPolicy):
         # Block dangerous modules
         if cls.__module__ in dangerous_modules:
             raise ValueError(f"Blocked dangerous class: 
{cls.__module__}.{cls.__name__}")
-        return None
 
     def intercept_reduce_call(self, callable_obj, args, **kwargs):
         # Block specific callable invocations during __reduce__
@@ -1144,9 +1143,11 @@ result = fory.deserialize(data)  # Policy hooks will be invoked
 
 **Available Policy Hooks:**
 
+Reference validation hooks (`validate_*`) reject a reference by raising an exception; their return values are ignored, so they cannot replace the deserialized reference.
 - `validate_class(cls, is_local)` - Validate/block class types during 
deserialization
-- `validate_module(module, is_local)` - Validate/block module imports
+- `validate_module(module_name, is_local)` - Validate/block module imports
 - `validate_function(func, is_local)` - Validate/block function references
+- `validate_method(method, is_local)` - Validate/block method references
 - `intercept_reduce_call(callable_obj, args)` - Intercept `__reduce__` 
invocations
 - `inspect_reduced_object(obj)` - Inspect/replace objects created via 
`__reduce__`
 - `intercept_setstate(obj, state)` - Sanitize state before `__setstate__`
diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py
index 3f2b224c1..2c6a89d60 100644
--- a/python/pyfory/meta/typedef_decoder.py
+++ b/python/pyfory/meta/typedef_decoder.py
@@ -184,9 +184,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef:
         type_cls = make_dataclass(class_name, field_definitions)
         policy = getattr(resolver, "policy", None)
         if policy is not None:
-            result = policy.validate_class(type_cls, is_local=True)
-            if result is not None:
-                type_cls = result
+            policy.validate_class(type_cls, is_local=True)
     elif type_cls is None:
         raise ValueError(f"TypeDef {name} is not registered")
 
diff --git a/python/pyfory/policy.py b/python/pyfory/policy.py
index 5070821c0..b8f8ea2e4 100644
--- a/python/pyfory/policy.py
+++ b/python/pyfory/policy.py
@@ -48,7 +48,7 @@ class DeserializationPolicy:
     | __reduce__ interception   | no                   | 
intercept_reduce_call()    |
     | Post-reduce inspection    | no                   | 
inspect_reduced_object()   |
     | __setstate__ interception | no                   | intercept_setstate()  
     |
-    | Object replacement        | no                   | return from 
validators     |
+    | Object replacement        | no                   | 
inspect_reduced_object()   |
     | State sanitization        | no                   | modify in-place       
     |
     | Local class/function      | no                   | is_local flag         
     |
     
+---------------------------+----------------------+----------------------------+
@@ -87,8 +87,8 @@ class DeserializationPolicy:
 
     This DeserializationPolicy interface allows users to implement custom 
security policies by
     subclassing and overriding specific hook methods. Each hook is called at a 
critical
-    point during deserialization, allowing inspection, replacement, or 
rejection of
-    dangerous constructs.
+    point during deserialization, allowing validation hooks to inspect or 
reject
+    dangerous constructs and interceptor hooks to control protocol operations.
 
     Hook Categories
     ---------------
@@ -98,7 +98,7 @@ class DeserializationPolicy:
 
     2. **Reference Validation Hooks** (Validators)
        - Validate deserialized type/function/module references
-       - Return None to accept original, return object to replace, raise 
exception to block,
+       - Raise exception to block, otherwise return normally
 
     3. **Protocol Interception Hooks** (Interceptors)
        - Intercept pickle protocol operations (__reduce__, __setstate__)
@@ -109,17 +109,15 @@ class DeserializationPolicy:
     >>> class SafeDeserializationPolicy(DeserializationPolicy):
     ...     ALLOWED_MODULES = {'builtins', 'datetime', 'decimal'}
     ...
-    ...     def validate_module(self, module_name, **kwargs):
+    ...     def validate_module(self, module_name, is_local, **kwargs):
     ...         # Reject imports from disallowed modules
     ...         if module_name.split('.')[0] not in self.ALLOWED_MODULES:
     ...             raise ValueError(f"Module {module_name} is not allowed")
-    ...         return None  # Accept
     ...
     ...     def validate_class(self, cls, is_local, **kwargs):
     ...         # Reject dangerous built-in classes
     ...         if cls.__name__ in ('eval', 'exec', 'compile'):
     ...             raise ValueError(f"Class {cls} is forbidden")
-    ...         return None  # Accept
     ...
     ...     def intercept_reduce_call(self, callable_obj, args, **kwargs):
     ...         # Log all __reduce__ callables for audit
@@ -201,7 +199,7 @@ class DeserializationPolicy:
 
         This hook is called after a class reference has been deserialized 
(either by
         importing from a module or reconstructing a local class), but before 
it is used.
-        It allows inspection, replacement, or rejection of class references.
+        It allows inspection or rejection of class references.
 
         When Called
         -----------
@@ -212,9 +210,8 @@ class DeserializationPolicy:
         Security Use Cases
         ------------------
         - Block dangerous classes (subprocess.Popen, os.system, etc.)
-        - Replace untrusted classes with safe alternatives
         - Validate that local classes match expected signatures
-        - Implement class versioning or adaptation logic
+        - Audit class imports for security logging
 
         Args:
             cls (type): The deserialized class object.
@@ -223,29 +220,20 @@ class DeserializationPolicy:
                            class from an importable module.
             **kwargs: Reserved for future extensions.
 
-        Returns:
-            None: Return None to accept the class as-is.
-            type: Return a different class to replace the original. The 
replacement
-                 class will be used instead for deserialization.
-
         Raises:
             Exception: Raise any exception to reject the class and abort 
deserialization.
 
         Example:
-            >>> class ClassAdapter(DeserializationPolicy):
+            >>> class ClassChecker(DeserializationPolicy):
             ...     def validate_class(self, cls, is_local, **kwargs):
-            ...         # Map a serialized class name to the current class.
-            ...         if cls.__name__ == 'ArchivedUserClass':
-            ...             return NewUserClass
             ...         # Block dangerous classes
             ...         if cls.__module__ == 'subprocess':
             ...             raise ValueError("subprocess classes not allowed")
-            ...         return None  # Accept
 
         Note:
             `check_class` is an alias for this hook.
         """
-        pass
+        return None
 
     def validate_function(self, func, is_local: bool, **kwargs):
         """Validate a deserialized function reference.
@@ -263,7 +251,6 @@ class DeserializationPolicy:
         ------------------
         - Block dangerous built-in functions (eval, exec, compile, __import__)
         - Validate that reconstructed functions have expected signatures
-        - Replace untrusted functions with safe alternatives
         - Audit function imports for security logging
 
         Args:
@@ -272,10 +259,6 @@ class DeserializationPolicy:
                            within a function scope), False if it's a global 
function.
             **kwargs: Reserved for future extensions.
 
-        Returns:
-            None: Return None to accept the function as-is.
-            function: Return a different function to replace the original.
-
         Raises:
             Exception: Raise any exception to reject the function.
 
@@ -286,12 +269,11 @@ class DeserializationPolicy:
             ...     def validate_function(self, func, is_local, **kwargs):
             ...         if func.__name__ in self.BLOCKED:
             ...             raise ValueError(f"Function {func.__name__} is 
forbidden")
-            ...         return None
 
         Note:
             `check_function` is an alias for this hook.
         """
-        pass
+        return None
 
     def validate_method(self, method, is_local: bool, **kwargs):
         """Validate a deserialized method reference.
@@ -309,17 +291,13 @@ class DeserializationPolicy:
         ------------------
         - Validate that methods belong to expected classes
         - Block methods that could perform dangerous operations
-        - Replace methods with safer alternatives
+        - Audit method references for security logging
 
         Args:
             method (method): The deserialized bound method object.
             is_local (bool): True if the method's class is local, False if 
global.
             **kwargs: Reserved for future extensions.
 
-        Returns:
-            None: Return None to accept the method as-is.
-            method: Return a different method to replace the original.
-
         Raises:
             Exception: Raise any exception to reject the method.
 
@@ -329,39 +307,35 @@ class DeserializationPolicy:
             ...         # Block methods from dangerous classes
             ...         if method.__self__.__class__.__name__ == 'FileRemover':
             ...             raise ValueError("FileRemover methods not allowed")
-            ...         return None
 
         Note:
             `check_method` is an alias for this hook.
         """
-        pass
+        return None
 
-    def validate_module(self, module_name: str, **kwargs):
+    def validate_module(self, module_name: str, *, is_local: bool, **kwargs):
         """Validate a deserialized module reference.
 
-        This hook is called after a module has been imported during 
deserialization,
-        but before it is used.
+        This hook is called before a module is imported during deserialization.
 
         When Called
         -----------
-        - After importing modules via importlib.import_module()
-        - Before the module is stored or its contents accessed
+        - Before importing modules via importlib.import_module()
+        - Before the module is stored or its contents are accessed
 
         Security Use Cases
         ------------------
         - Whitelist/blacklist modules by name or prefix
         - Prevent imports of system modules (os, subprocess, sys, etc.)
-        - Replace modules with safe alternatives or mocks
         - Audit module imports for security logging
 
         Args:
-            module_name (str): The name of the imported module (e.g., 
'os.path').
+            module_name (str): The name of the module to import (e.g., 
'os.path').
+            is_local (bool): True if the reference being resolved is local 
(defined
+                           in __main__ or within a function/method scope), 
False
+                           otherwise.
             **kwargs: Reserved for future extensions.
 
-        Returns:
-            None: Return None to accept the module as-is.
-            module: Return a different module object to replace the original.
-
         Raises:
             Exception: Raise any exception to reject the module import.
 
@@ -369,16 +343,15 @@ class DeserializationPolicy:
             >>> class ModuleWhitelistChecker(DeserializationPolicy):
             ...     ALLOWED = {'builtins', 'datetime', 'decimal', 
'collections'}
             ...
-            ...     def validate_module(self, module_name, **kwargs):
+            ...     def validate_module(self, module_name, is_local, **kwargs):
             ...         root = module_name.split('.')[0]
             ...         if root not in self.ALLOWED:
             ...             raise ValueError(f"Module {module_name} not 
whitelisted")
-            ...         return None
 
         Note:
             `check_module` is an alias for this hook.
         """
-        pass
+        return None
 
     # 
============================================================================
     # Protocol Interception Hooks (Interceptors)
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 2736599b5..e2cdf7994 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -46,23 +46,18 @@ _WINDOWS = os.name == "nt"
 from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION
 
 
-def _import_validated_module(policy, module_name):
-    result = policy.validate_module(module_name)
-    if result is not None:
-        if isinstance(result, types.ModuleType):
-            return result
-        assert isinstance(result, str), f"validate_module must return module, 
str, or None, got {type(result)}"
-        module_name = result
+def _import_validated_module(policy, module_name, is_local=False):
+    policy.validate_module(module_name, is_local=is_local)
     return importlib.import_module(module_name)
 
 
-def _resolve_validated_module_attr(policy, module_name, attr_name):
-    module = _import_validated_module(policy, module_name)
+def _resolve_validated_module_attr(policy, module_name, attr_name, 
is_local=False):
+    module = _import_validated_module(policy, module_name, is_local=is_local)
     return getattr(module, attr_name)
 
 
 def _resolve_validated_module_qualname(policy, module_name, qualname):
-    obj = _import_validated_module(policy, module_name)
+    obj = _import_validated_module(policy, module_name, 
is_local=_is_local_qualname(module_name, qualname))
     for name in qualname.split("."):
         obj = getattr(obj, name)
     return obj
@@ -111,21 +106,15 @@ def _is_bound_method_value(obj):
 
 def _validate_function_value(policy, func, is_local):
     if isinstance(func, type):
-        result = policy.validate_class(func, is_local=is_local)
-        if result is not None:
-            func = result
+        policy.validate_class(func, is_local=is_local)
         if isinstance(func, type):
             raise TypeError(f"Function serializer resolved class 
{func.__module__}.{func.__qualname__}")
     if _is_bound_method_value(func):
-        result = policy.validate_method(func, is_local=is_local)
-        if result is not None:
-            func = result
+        policy.validate_method(func, is_local=is_local)
         return func
     if not callable(func):
         raise TypeError(f"Function serializer resolved non-callable object 
{func!r}")
-    result = policy.validate_function(func, is_local=is_local)
-    if result is not None:
-        func = result
+    policy.validate_function(func, is_local=is_local)
     return func
 
 
@@ -159,9 +148,7 @@ def _resolve_validated_bound_method(policy, obj, method_name, is_local):
     if policy is DEFAULT_POLICY:
         return getattr(obj, method_name)
     method = _bind_static_method(obj, method_name)
-    result = policy.validate_method(method, is_local=is_local)
-    if result is not None:
-        method = result
+    policy.validate_method(method, is_local=is_local)
     return method
 
 
@@ -1214,15 +1201,12 @@ class ReduceSerializer(Serializer):
         self._getnewargs = getattr(cls, "__getnewargs__", None)
 
     def _validate_global_object(self, policy, obj):
-        result = None
         if isinstance(obj, type):
-            result = policy.validate_class(obj, is_local=_is_local_class(obj))
+            policy.validate_class(obj, is_local=_is_local_class(obj))
         elif _is_bound_method_value(obj):
-            result = policy.validate_method(obj, 
is_local=_is_local_callable(obj))
+            policy.validate_method(obj, is_local=_is_local_callable(obj))
         elif isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
-            result = policy.validate_function(obj, 
is_local=_is_local_callable(obj))
-        if result is not None:
-            obj = result
+            policy.validate_function(obj, is_local=_is_local_callable(obj))
         return obj
 
     def _resolve_global_name(self, read_context, global_name):
@@ -1232,7 +1216,12 @@ class ReduceSerializer(Serializer):
         else:
             module_name, obj_name = "builtins", global_name
         try:
-            obj = _resolve_validated_module_attr(policy, module_name, obj_name)
+            obj = _resolve_validated_module_attr(
+                policy,
+                module_name,
+                obj_name,
+                is_local=_is_local_qualname(module_name, obj_name),
+            )
         except AttributeError:
             raise ValueError(f"Cannot resolve global name: {global_name}")
         return self._validate_global_object(policy, obj)
@@ -1370,9 +1359,7 @@ class TypeSerializer(Serializer):
         module_name = read_context.read_string()
         qualname = read_context.read_string()
         cls = _resolve_validated_module_qualname(read_context.policy, 
module_name, qualname)
-        result = read_context.policy.validate_class(cls, 
is_local=_is_local_class(cls))
-        if result is not None:
-            cls = result
+        read_context.policy.validate_class(cls, is_local=_is_local_class(cls))
         return cls
 
     def _serialize_local_class(self, write_context, cls):
@@ -1422,9 +1409,7 @@ class TypeSerializer(Serializer):
         read_context.policy.authorize_instantiation(type, module=module, 
qualname=qualname, bases=bases)
         cls = type(name, bases, {})
         read_context.set_read_ref(ref_id, cls)
-        result = read_context.policy.validate_class(cls, is_local=True)
-        if result is not None:
-            cls = result
+        read_context.policy.validate_class(cls, is_local=True)
 
         num_class_methods = read_context.read_var_uint32()
         _check_collection_size(read_context, num_class_methods, "local class 
method")
@@ -1440,9 +1425,7 @@ class TypeSerializer(Serializer):
         # Set module and qualname
         cls.__module__ = module
         cls.__qualname__ = qualname
-        result = read_context.policy.validate_class(cls, is_local=True)
-        if result is not None:
-            cls = result
+        read_context.policy.validate_class(cls, is_local=True)
         return cls
 
 
@@ -1457,7 +1440,7 @@ class ModuleSerializer(Serializer):
 
     def read(self, read_context):
         mod_name = read_context.read_string()
-        return _import_validated_module(read_context.policy, mod_name)
+        return _import_validated_module(read_context.policy, mod_name, 
is_local=_is_local_qualname(mod_name, ""))
 
 
 class MappingProxySerializer(Serializer):
@@ -1617,7 +1600,7 @@ class FunctionSerializer(Serializer):
 
         module = read_context.read_string()
         qualname = read_context.read_string()
-        mod = _import_validated_module(read_context.policy, module)
+        mod = _import_validated_module(read_context.policy, module, 
is_local=_is_local_qualname(module, qualname))
         name = qualname.rsplit(".")[-1]
 
         marshalled_code = read_context.read_bytes_and_size()
@@ -1699,7 +1682,12 @@ class NativeFuncMethodSerializer(Serializer):
         name = read_context.read_string()
         if read_context.read_bool():
             module = read_context.read_string()
-            func = _resolve_validated_module_attr(read_context.policy, module, 
name)
+            func = _resolve_validated_module_attr(
+                read_context.policy,
+                module,
+                name,
+                is_local=_is_local_qualname(module, name),
+            )
             func = _validate_function_value(read_context.policy, func, 
is_local=_is_local_callable(func))
         else:
             obj = read_context.read_ref()
diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py
index 9eb7d55f0..dc0737369 100644
--- a/python/pyfory/tests/test_policy.py
+++ b/python/pyfory/tests/test_policy.py
@@ -30,6 +30,10 @@ def policy_global_function():
     return "safe"
 
 
+def policy_replacement_function():
+    return "replacement"
+
+
 class PolicyMethodHolder:
     def run(self):
         return "safe"
@@ -491,26 +495,26 @@ def test_validate_module():
     import json
     import collections
 
-    # Test 1: Return module object directly
     class ReturnModulePolicy(DeserializationPolicy):
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
+            assert not is_local
             return collections
 
     fory1 = Fory(xlang=False, ref=True, strict=False, 
policy=ReturnModulePolicy())
     data = fory1.serialize(json)
-    assert fory1.deserialize(data) is collections
+    assert fory1.deserialize(data) is json
 
-    # Test 2: Return string to redirect import
     class RedirectPolicy(DeserializationPolicy):
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
+            assert not is_local
             return "collections" if module_name == "json" else None
 
     fory2 = Fory(xlang=False, ref=True, strict=False, policy=RedirectPolicy())
-    assert fory2.deserialize(fory2.serialize(json)).__name__ == "collections"
+    assert fory2.deserialize(fory2.serialize(json)).__name__ == "json"
 
-    # Test 3: Raise to block module
     class BlockPolicy(DeserializationPolicy):
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
+            assert not is_local
             raise ValueError(f"Module {module_name} blocked")
 
     fory3 = Fory(xlang=False, ref=True, strict=False, policy=BlockPolicy())
@@ -518,6 +522,63 @@ def test_validate_module():
         fory3.deserialize(fory3.serialize(json))
 
 
+def test_validator_returns_ignored():
+    import json
+    import collections
+
+    class ReplacementClass:
+        pass
+
+    class ReturnPolicy(DeserializationPolicy):
+        def validate_module(self, module_name, is_local, **kwargs):
+            assert not is_local
+            return collections
+
+        def validate_class(self, cls, is_local, **kwargs):
+            return ReplacementClass
+
+        def validate_function(self, func, is_local, **kwargs):
+            return policy_replacement_function
+
+        def validate_method(self, method, is_local, **kwargs):
+            return policy_replacement_function
+
+    policy = ReturnPolicy()
+    fory = Fory(xlang=False, ref=True, strict=False, policy=policy)
+    assert fory.deserialize(fory.serialize(json)) is json
+    assert fory.deserialize(fory.serialize(PolicyGlobalClass)) is 
PolicyGlobalClass
+    assert fory.deserialize(fory.serialize(policy_global_function)) is 
policy_global_function
+
+    serializer = FunctionSerializer(fory.type_resolver, 
type(policy_global_function))
+    read_context = FakeReadContext(policy, [1, __name__, 
"policy_global_bound_method"])
+    assert serializer._deserialize_function(read_context) is 
policy_global_bound_method
+
+
+def test_local_class_return_ignored():
+    class SafeClass:
+        @classmethod
+        def run(cls):
+            return "safe"
+
+    def make_payload_class():
+        class PayloadClass:
+            @classmethod
+            def run(cls):
+                return "payload"
+
+        return PayloadClass
+
+    class ReturnClassPolicy(DeserializationPolicy):
+        def validate_class(self, cls, is_local, **kwargs):
+            return SafeClass if is_local else None
+
+    fory = Fory(xlang=False, ref=True, strict=False, 
policy=ReturnClassPolicy())
+    decoded = fory.deserialize(fory.serialize(make_payload_class()))
+    assert decoded is not SafeClass
+    assert decoded.run() == "payload"
+    assert SafeClass.run() == "safe"
+
+
 def test_type_deserialization_validates_module():
     """Test validate_module policy hook for global class deserialization."""
     import subprocess
@@ -525,9 +586,11 @@ def test_type_deserialization_validates_module():
     class BlockModulePolicy(DeserializationPolicy):
         def __init__(self):
             self.validate_module_calls = 0
+            self.is_local_values = []
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            self.is_local_values.append(is_local)
             if module_name == "subprocess":
                 raise ValueError("subprocess blocked")
             return None
@@ -537,6 +600,7 @@ def test_type_deserialization_validates_module():
     with pytest.raises(ValueError, match="subprocess blocked"):
         fory.deserialize(fory.serialize(subprocess.Popen))
     assert policy.validate_module_calls == 1
+    assert policy.is_local_values == [False]
 
 
 def test_native_bound_method_uses_validate_method():
@@ -818,9 +882,11 @@ def test_global_function_deserialization_validates_module():
     class BlockModulePolicy(DeserializationPolicy):
         def __init__(self):
             self.validate_module_calls = 0
+            self.is_local_values = []
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            self.is_local_values.append(is_local)
             if module_name == policy_global_function.__module__:
                 raise ValueError("function module blocked")
             return None
@@ -830,6 +896,7 @@ def test_global_function_deserialization_validates_module():
     with pytest.raises(ValueError, match="function module blocked"):
         fory.deserialize(fory.serialize(policy_global_function))
     assert policy.validate_module_calls == 1
+    assert policy.is_local_values == [False]
 
 
 def test_local_function_deserialization_validates_module():
@@ -841,9 +908,11 @@ def test_local_function_deserialization_validates_module():
     class BlockModulePolicy(DeserializationPolicy):
         def __init__(self):
             self.validate_module_calls = 0
+            self.is_local_values = []
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            self.is_local_values.append(is_local)
             if module_name == local_function.__module__:
                 raise ValueError("local function module blocked")
             return None
@@ -853,6 +922,7 @@ def test_local_function_deserialization_validates_module():
     with pytest.raises(ValueError, match="local function module blocked"):
         fory.deserialize(fory.serialize(local_function))
     assert policy.validate_module_calls == 1
+    assert policy.is_local_values == [True]
 
 
 def test_native_function_deserialization_validates_module():
@@ -862,9 +932,11 @@ def test_native_function_deserialization_validates_module():
     class BlockModulePolicy(DeserializationPolicy):
         def __init__(self):
             self.validate_module_calls = 0
+            self.is_local_values = []
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            self.is_local_values.append(is_local)
             if module_name == "time":
                 raise ValueError("time blocked")
             return None
@@ -874,6 +946,7 @@ def test_native_function_deserialization_validates_module():
     with pytest.raises(ValueError, match="time blocked"):
         fory.deserialize(fory.serialize(time.time))
     assert policy.validate_module_calls == 1
+    assert policy.is_local_values == [False]
 
 
 def test_type_metadata_load_validates_module():
@@ -882,9 +955,11 @@ def test_type_metadata_load_validates_module():
     class BlockModulePolicy(DeserializationPolicy):
         def __init__(self):
             self.validate_module_calls = 0
+            self.is_local_values = []
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            self.is_local_values.append(is_local)
             if module_name == "subprocess":
                 raise ValueError("subprocess blocked")
             return None
@@ -902,6 +977,7 @@ def test_type_metadata_load_validates_module():
     with pytest.raises(ValueError, match="subprocess blocked"):
         resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes)
     assert policy.validate_module_calls == 1
+    assert policy.is_local_values == [False]
 
 
 def test_type_metadata_load_validates_class():
@@ -942,9 +1018,11 @@ def test_reduce_global_name_validates_module():
     class BlockModulePolicy(DeserializationPolicy):
         def __init__(self):
             self.validate_module_calls = 0
+            self.is_local_values = []
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            self.is_local_values.append(is_local)
             if module_name == "subprocess":
                 raise ValueError(f"Module {module_name} blocked")
             return None
@@ -954,6 +1032,7 @@ def test_reduce_global_name_validates_module():
     with pytest.raises(ValueError, match="subprocess blocked"):
         fory.deserialize(fory.serialize(GlobalNamePayload()))
     assert policy.validate_module_calls == 1
+    assert policy.is_local_values == [False]
 
 
 def test_reduce_global_name_validates_class():
@@ -968,8 +1047,9 @@ def test_reduce_global_name_validates_class():
             self.validate_module_calls = 0
             self.validate_class_calls = 0
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            assert not is_local
             return None
 
         def validate_class(self, cls, is_local, **kwargs):
@@ -998,8 +1078,9 @@ def test_reduce_global_name_validates_function():
             self.validate_module_calls = 0
             self.validate_function_calls = 0
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            assert not is_local
             return None
 
         def validate_function(self, func, is_local, **kwargs):
@@ -1029,8 +1110,9 @@ def test_reduce_global_method_resolution_uses_validate_method():
             self.validate_method_calls = 0
             self.validate_function_calls = 0
 
-        def validate_module(self, module_name, **kwargs):
+        def validate_module(self, module_name, is_local, **kwargs):
             self.validate_module_calls += 1
+            assert not is_local
             return None
 
         def validate_method(self, method, is_local, **kwargs):
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index d0e8b02a9..ca1a750dc 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -18,7 +18,6 @@
 import dataclasses
 import importlib
 import inspect
-import types
 
 import typing
 from typing import TypeVar
@@ -367,34 +366,20 @@ def qualified_class_name(cls):
 
 def load_class(classname: str, policy=None):
     mod_name, cls_name = classname.rsplit("#", 1)
+    is_local = mod_name == "__main__" or "<locals>" in cls_name
     if policy is not None:
-        result = policy.validate_module(mod_name)
-        if result is not None:
-            if isinstance(result, str):
-                mod_name = result
-                mod = None
-            else:
-                assert isinstance(result, types.ModuleType), f"validate_module must return module, str, or None, got {type(result)}"
-                mod = result
-        else:
-            mod = None
-    else:
-        mod = None
-    if mod is None:
-        try:
-            mod = importlib.import_module(mod_name)
-        except ImportError as ex:
-            raise Exception(f"Can't import module {mod_name}") from ex
+        policy.validate_module(mod_name, is_local=is_local)
+    try:
+        mod = importlib.import_module(mod_name)
+    except ImportError as ex:
+        raise Exception(f"Can't import module {mod_name}") from ex
     try:
         classes = cls_name.split(".")
         cls = getattr(mod, classes.pop(0))
         while classes:
             cls = getattr(cls, classes.pop(0))
         if policy is not None:
-            is_local = cls.__module__ == "__main__" or "<locals>" in cls.__qualname__
-            result = policy.validate_class(cls, is_local=is_local)
-            if result is not None:
-                cls = result
+            policy.validate_class(cls, is_local=is_local)
         return cls
     except AttributeError as ex:
         raise Exception(f"Can't import class {cls_name} from module {mod_name}") from ex


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to