This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 880ecc47e feat(python): support class methods serialization (#2670)
880ecc47e is described below
commit 880ecc47e5045632ea6a568fd9ea1d9b78f0ae38
Author: Shawn Yang <[email protected]>
AuthorDate: Fri Sep 26 23:59:11 2025 +0800
feat(python): support class methods serialization (#2670)
## Why?
<!-- Describe the purpose of this PR. -->
## What does this PR do?
1. support serialization of class methods (`@classmethod`) for global/local classes
2. support serialization of static methods (`@staticmethod`) for global/local classes
3. support serialization of bound (instance) methods for global/local classes
## Related issues
<!--
Is there any related issue? If this PR closes them, you can say
fix/closes:
- #xxxx0
- #xxxx1
- Fixes #xxxx2
-->
## Does this PR introduce any user-facing change?
<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fory/issues/new/choose) describing the
need for the change and update the documentation if necessary.
Delete section if not applicable.
-->
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
<!--
When the PR has an impact on performance (if you don't know whether the
PR will have an impact on performance, you can submit the PR first, and
if it does, the code reviewer will explain the impact), be sure to attach
benchmark data here.
Delete section if not applicable.
-->
---
python/pyfory/serializer.py | 48 +++--
python/pyfory/tests/test_method.py | 398 +++++++++++++++++++++++++++++++++++++
2 files changed, 432 insertions(+), 14 deletions(-)
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index 26dc92d1b..984d73e47 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -183,46 +183,66 @@ class TypeSerializer(Serializer):
qualname = cls.__qualname__
buffer.write_string(module)
buffer.write_string(qualname)
+ fory = self.fory
# Serialize base classes
# Let Fory's normal serialization handle bases (including other local
classes)
bases = cls.__bases__
buffer.write_varuint32(len(bases))
for base in bases:
- self.fory.serialize_ref(buffer, base)
+ fory.serialize_ref(buffer, base)
# Serialize class dictionary (excluding special attributes)
# FunctionSerializer will automatically handle methods with closures
class_dict = {}
+ attr_names, class_methods = [], []
for attr_name, attr_value in cls.__dict__.items():
# Skip special attributes that are handled by type() constructor
if attr_name in __skip_class_attr_names__:
continue
- class_dict[attr_name] = attr_value
+ if isinstance(attr_value, classmethod):
+ attr_names.append(attr_name)
+ class_methods.append(attr_value)
+ else:
+ class_dict[attr_name] = attr_value
+ # serialize method specially to avoid circular deps in method
deserialization
+ buffer.write_varuint32(len(class_methods))
+ for i in range(len(class_methods)):
+ buffer.write_string(attr_names[i])
+ class_method = class_methods[i]
+ fory.serialize_ref(buffer, class_method.__func__)
# Let Fory's normal serialization handle the class dict
# This will use FunctionSerializer for methods, which handles closures
properly
- self.fory.serialize_ref(buffer, class_dict)
+ fory.serialize_ref(buffer, class_dict)
def _deserialize_local_class(self, buffer):
"""Deserialize a local class by recreating it with the captured
context."""
- assert self.fory.ref_tracking, "Reference tracking must be enabled for
local classes deserialization"
+ fory = self.fory
+ assert fory.ref_tracking, "Reference tracking must be enabled for
local classes deserialization"
# Read basic class information
module = buffer.read_string()
qualname = buffer.read_string()
name = qualname.rsplit(".", 1)[-1]
- ref_id = self.fory.ref_resolver.last_preserved_ref_id()
+ ref_id = fory.ref_resolver.last_preserved_ref_id()
# Read base classes
num_bases = buffer.read_varuint32()
- bases = tuple([self.fory.deserialize_ref(buffer) for _ in
range(num_bases)])
+ bases = tuple([fory.deserialize_ref(buffer) for _ in range(num_bases)])
# Create the class using type() constructor
cls = type(name, bases, {})
# `class_dict` may reference to `cls`, which is a circular reference
- self.fory.ref_resolver.set_read_object(ref_id, cls)
+ fory.ref_resolver.set_read_object(ref_id, cls)
+
+ # classmethods
+ for i in range(buffer.read_varuint32()):
+ attr_name = buffer.read_string()
+ func = fory.deserialize_ref(buffer)
+ method = types.MethodType(func, cls)
+ setattr(cls, attr_name, method)
# Read class dictionary
# Fory's normal deserialization will handle methods via
FunctionSerializer
- class_dict = self.fory.deserialize_ref(buffer)
+ class_dict = fory.deserialize_ref(buffer)
for k, v in class_dict.items():
setattr(cls, k, v)
@@ -1115,19 +1135,17 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
# Regular function or lambda
code = func.__code__
- name = func.__name__
module = func.__module__
qualname = func.__qualname__
if "<locals>" not in qualname and module != "__main__":
buffer.write_int8(1) # Not a method
- buffer.write_string(name)
buffer.write_string(module)
+ buffer.write_string(qualname)
return
# Serialize function metadata
buffer.write_int8(2) # Not a method
- buffer.write_string(name)
buffer.write_string(module)
buffer.write_string(qualname)
@@ -1212,15 +1230,17 @@ class
FunctionSerializer(CrossLanguageCompatibleSerializer):
return getattr(self_obj, method_name)
if func_type_id == 1:
- name = buffer.read_string()
module = buffer.read_string()
+ qualname = buffer.read_string()
mod = importlib.import_module(module)
- return getattr(mod, name)
+ for name in qualname.split("."):
+ mod = getattr(mod, name)
+ return mod
# Regular function or lambda
- name = buffer.read_string()
module = buffer.read_string()
qualname = buffer.read_string()
+ name = qualname.rsplit(".")[-1]
# Use marshal to load the code object, which handles all Python
versions correctly
marshalled_code = buffer.read_bytes_and_size()
diff --git a/python/pyfory/tests/test_method.py
b/python/pyfory/tests/test_method.py
new file mode 100644
index 000000000..a1da596ed
--- /dev/null
+++ b/python/pyfory/tests/test_method.py
@@ -0,0 +1,398 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyfory
+
+
+# Global classes for testing global class method serialization
+class GlobalTestClass:
+ """Global test class for method serialization."""
+
+ class_variable = "global_class_value"
+
+ def __init__(self, value):
+ self.instance_value = value
+
+ def instance_method(self):
+ """Instance method for testing."""
+ return f"instance_{self.instance_value}"
+
+ @classmethod
+ def class_method(cls):
+ """Class method for testing."""
+ return f"class_{cls.class_variable}"
+
+ @classmethod
+ def class_method_with_args(cls, arg1, arg2):
+ """Class method with arguments for testing."""
+ return f"class_{cls.class_variable}_{arg1}_{arg2}"
+
+ @staticmethod
+ def static_method():
+ """Static method for testing."""
+ return "static_global_result"
+
+ @staticmethod
+ def static_method_with_args(arg1, arg2):
+ """Static method with arguments for testing."""
+ return f"static_{arg1}_{arg2}"
+
+
+class AnotherGlobalClass:
+ """Another global class to test cross-class method serialization."""
+
+ @classmethod
+ def another_class_method(cls):
+ return f"another_{cls.__name__}"
+
+
+class GlobalClassWithInheritance(GlobalTestClass):
+ """Global class with inheritance."""
+
+ class_variable = "inherited_value"
+
+ @classmethod
+ def inherited_class_method(cls):
+ return f"inherited_{cls.class_variable}"
+
+
+class TestMethodSerialization:
+ """Test class for method serialization scenarios."""
+
+ def test_instance_method_serialization(self):
+ """Test serialization of instance methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class TestClass:
+ def __init__(self, value):
+ self.value = value
+
+ def instance_method(self):
+ return self.value * 2
+
+ obj = TestClass(5)
+ method = obj.instance_method
+
+ # Test serialization/deserialization
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert method() == deserialized()
+ assert method() == 10
+
+ def test_classmethod_serialization(self):
+ """Test serialization of class methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class TestClass:
+ class_var = 42
+
+ @classmethod
+ def class_method(cls):
+ return cls.class_var
+
+ method = TestClass.class_method
+
+ # Test serialization/deserialization
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert method() == deserialized()
+ assert method() == 42
+
+ def test_staticmethod_serialization(self):
+ """Test serialization of static methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class TestClass:
+ @staticmethod
+ def static_method():
+ return "static_result"
+
+ method = TestClass.static_method
+
+ # Test serialization/deserialization
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert method() == deserialized()
+ assert method() == "static_result"
+
+ def test_method_with_args_serialization(self):
+ """Test serialization of methods with arguments."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class TestClass:
+ def __init__(self, base):
+ self.base = base
+
+ def add(self, x):
+ return self.base + x
+
+ @classmethod
+ def multiply(cls, a, b):
+ return a * b
+
+ @staticmethod
+ def subtract(a, b):
+ return a - b
+
+ obj = TestClass(10)
+
+ # Test instance method
+ instance_method = obj.add
+ serialized = fory.serialize(instance_method)
+ deserialized = fory.deserialize(serialized)
+ assert instance_method(5) == deserialized(5)
+ assert instance_method(5) == 15
+
+ # Test classmethod
+ class_method = TestClass.multiply
+ serialized = fory.serialize(class_method)
+ deserialized = fory.deserialize(serialized)
+ assert class_method(3, 4) == deserialized(3, 4)
+ assert class_method(3, 4) == 12
+
+ # Test staticmethod
+ static_method = TestClass.subtract
+ serialized = fory.serialize(static_method)
+ deserialized = fory.deserialize(serialized)
+ assert static_method(10, 3) == deserialized(10, 3)
+ assert static_method(10, 3) == 7
+
+ def test_nested_class_method_serialization(self):
+ """Test serialization of methods from nested classes."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class OuterClass:
+ class InnerClass:
+ @classmethod
+ def inner_class_method(cls):
+ return "inner_result"
+
+ method = OuterClass.InnerClass.inner_class_method
+
+ # Test serialization/deserialization
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert method() == deserialized()
+ assert method() == "inner_result"
+
+
+def test_classmethod_serialization():
+ """Standalone test for classmethod serialization - reproduces the original
error."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class A:
+ @classmethod
+ def f(cls):
+ pass
+
+ @staticmethod
+ def g():
+ return A
+
+ method = A.f
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert isinstance(deserialized, type(method))
+ # Check that the class names are the same (the classes might be different
instances due to deserialization)
+ assert deserialized.__self__.__name__ == method.__self__.__name__
+ assert deserialized.__func__.__name__ == method.__func__.__name__
+
+ # Most importantly, check that the deserialized method is callable and has
the same behavior
+ # Both should return None for this test case
+ original_result = method()
+ deserialized_result = deserialized()
+ assert original_result == deserialized_result
+
+
+def test_staticmethod_serialization():
+ """Standalone test for staticmethod serialization."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ class A:
+ @staticmethod
+ def g():
+ return "static_result"
+
+ method = A.g
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert method() == deserialized()
+ assert method() == "static_result"
+
+
+# Global class method tests
+def test_global_classmethod_serialization():
+ """Test serialization of global class methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ method = GlobalTestClass.class_method
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert isinstance(deserialized, type(method))
+ assert deserialized() == method()
+ assert deserialized() == "class_global_class_value"
+
+
+def test_global_classmethod_with_args():
+ """Test serialization of global class methods with arguments."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ method = GlobalTestClass.class_method_with_args
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ args = ("arg1", "arg2")
+ assert deserialized(*args) == method(*args)
+ assert deserialized(*args) == "class_global_class_value_arg1_arg2"
+
+
+def test_global_staticmethod_serialization():
+ """Test serialization of global static methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ method = GlobalTestClass.static_method
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert deserialized() == method()
+ assert deserialized() == "static_global_result"
+
+
+def test_global_staticmethod_with_args():
+ """Test serialization of global static methods with arguments."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ method = GlobalTestClass.static_method_with_args
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ args = ("test1", "test2")
+ assert deserialized(*args) == method(*args)
+ assert deserialized(*args) == "static_test1_test2"
+
+
+def test_global_instance_method_serialization():
+ """Test serialization of global instance methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ obj = GlobalTestClass("test_value")
+ method = obj.instance_method
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert deserialized() == method()
+ assert deserialized() == "instance_test_value"
+
+
+def test_multiple_global_classes():
+ """Test serialization of methods from multiple global classes."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ # Test methods from different global classes
+ method1 = GlobalTestClass.class_method
+ method2 = AnotherGlobalClass.another_class_method
+
+ serialized1 = fory.serialize(method1)
+ serialized2 = fory.serialize(method2)
+
+ deserialized1 = fory.deserialize(serialized1)
+ deserialized2 = fory.deserialize(serialized2)
+
+ assert deserialized1() == method1()
+ assert deserialized2() == method2()
+ assert deserialized1() == "class_global_class_value"
+ assert deserialized2() == "another_AnotherGlobalClass"
+
+
+def test_global_class_inheritance():
+ """Test serialization of methods from global classes with inheritance."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ # Test inherited class method
+ method = GlobalClassWithInheritance.inherited_class_method
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert deserialized() == method()
+ assert deserialized() == "inherited_inherited_value"
+
+ # Test parent class method on child class
+ parent_method = GlobalClassWithInheritance.class_method
+ serialized_parent = fory.serialize(parent_method)
+ deserialized_parent = fory.deserialize(serialized_parent)
+
+ assert deserialized_parent() == parent_method()
+ assert deserialized_parent() == "class_inherited_value" # Uses child's
class_variable
+
+
+def test_global_methods_without_ref_tracking():
+ """Test serialization of global class methods without reference
tracking."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=False)
+
+ # Global classes should work even without ref_tracking
+ method = GlobalTestClass.class_method
+ serialized = fory.serialize(method)
+ deserialized = fory.deserialize(serialized)
+
+ assert deserialized() == method()
+ assert deserialized() == "class_global_class_value"
+
+
+def test_global_method_collection():
+ """Test serialization of collections containing global methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ methods = [GlobalTestClass.class_method, GlobalTestClass.static_method,
AnotherGlobalClass.another_class_method]
+
+ serialized = fory.serialize(methods)
+ deserialized = fory.deserialize(serialized)
+
+ assert len(deserialized) == len(methods)
+ for original, restored in zip(methods, deserialized):
+ assert original() == restored()
+
+
+def test_global_method_in_dict():
+ """Test serialization of dictionaries containing global methods."""
+ fory = pyfory.Fory(require_type_registration=False, ref_tracking=True)
+
+ method_dict = {
+ "class_method": GlobalTestClass.class_method,
+ "static_method": GlobalTestClass.static_method,
+ "another_method": AnotherGlobalClass.another_class_method,
+ }
+
+ serialized = fory.serialize(method_dict)
+ deserialized = fory.deserialize(serialized)
+
+ assert len(deserialized) == len(method_dict)
+ for key in method_dict:
+ assert method_dict[key]() == deserialized[key]()
+
+
+if __name__ == "__main__":
+ # Run tests
+ import pytest
+
+ pytest.main([__file__, "-v"])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]