https://github.com/python/cpython/commit/cb35c11d82efd2959bda0397abcc1719bf6bb0cb
commit: cb35c11d82efd2959bda0397abcc1719bf6bb0cb
branch: main
author: Eric Snow <ericsnowcurren...@gmail.com>
committer: ericsnowcurrently <ericsnowcurren...@gmail.com>
date: 2025-04-30T17:34:05-06:00
summary:

gh-132775: Add _PyPickle_GetXIData() (gh-133107)

There's some extra complexity due to making sure we get things right when 
handling functions and classes defined in the __main__ module.  This is also 
reflected in the tests, including the addition of extra functions in 
test.support.import_helper.

files:
M Include/internal/pycore_crossinterp.h
M Lib/test/support/import_helper.py
M Lib/test/test_crossinterp.py
M Modules/_testinternalcapi.c
M Python/crossinterp.c

diff --git a/Include/internal/pycore_crossinterp.h 
b/Include/internal/pycore_crossinterp.h
index 4b7446a1f40ccf..4b4617fdbcb2ad 100644
--- a/Include/internal/pycore_crossinterp.h
+++ b/Include/internal/pycore_crossinterp.h
@@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped(
         xid_newobjfunc,
         _PyXIData_t *);
 
+// _PyObject_GetXIData() for pickle
+PyAPI_DATA(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *);
+PyAPI_FUNC(int) _PyPickle_GetXIData(
+        PyThreadState *,
+        PyObject *,
+        _PyXIData_t *);
+
 // _PyObject_GetXIData() for marshal
 PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *);
 PyAPI_FUNC(int) _PyMarshal_GetXIData(
diff --git a/Lib/test/support/import_helper.py 
b/Lib/test/support/import_helper.py
index 42cfe9cfa8cb72..edb734d294f287 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -1,6 +1,7 @@
 import contextlib
 import _imp
 import importlib
+import importlib.machinery
 import importlib.util
 import os
 import shutil
@@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, 
modules_to_block):
     )
     from .script_helper import assert_python_ok
     assert_python_ok("-S", "-c", script)
+
+
+@contextlib.contextmanager
+def module_restored(name):
+    """A context manager that restores a module to the original state."""
+    missing = object()
+    orig = sys.modules.get(name, missing)
+    if orig is None:
+        mod = importlib.import_module(name)
+    else:
+        mod = type(sys)(name)
+        mod.__dict__.update(orig.__dict__)
+        sys.modules[name] = mod
+    try:
+        yield mod
+    finally:
+        if orig is missing:
+            sys.modules.pop(name, None)
+        else:
+            sys.modules[name] = orig
+
+
+def create_module(name, loader=None, *, ispkg=False):
+    """Return a new, empty module."""
+    spec = importlib.machinery.ModuleSpec(
+        name,
+        loader,
+        origin='<import_helper>',
+        is_package=ispkg,
+    )
+    return importlib.util.module_from_spec(spec)
+
+
+def _ensure_module(name, ispkg, addparent, clearnone):
+    try:
+        mod = orig = sys.modules[name]
+    except KeyError:
+        mod = orig = None
+        missing = True
+    else:
+        missing = False
+        if mod is not None:
+            # It was already imported.
+            return mod, orig, missing
+        # Otherwise, None means it was explicitly disabled.
+
+    assert name != '__main__'
+    if not missing:
+        assert orig is None, (name, sys.modules[name])
+        if not clearnone:
+            raise ModuleNotFoundError(name)
+        del sys.modules[name]
+    # Try normal import, then fall back to adding the module.
+    try:
+        mod = importlib.import_module(name)
+    except ModuleNotFoundError:
+        if addparent and not clearnone:
+            addparent = None
+        mod = _add_module(name, ispkg, addparent)
+    return mod, orig, missing
+
+
+def _add_module(spec, ispkg, addparent):
+    if isinstance(spec, str):
+        name = spec
+        mod = create_module(name, ispkg=ispkg)
+        spec = mod.__spec__
+    else:
+        name = spec.name
+        mod = importlib.util.module_from_spec(spec)
+    sys.modules[name] = mod
+    if addparent is not False and spec.parent:
+        _ensure_module(spec.parent, True, addparent, bool(addparent))
+    return mod
+
+
+def add_module(spec, *, parents=True):
+    """Return the module after creating it and adding it to sys.modules.
+
+    If parents is True then also create any missing parents.
+    """
+    return _add_module(spec, False, parents)
+
+
+def add_package(spec, *, parents=True):
+    """Return the module after creating it and adding it to sys.modules.
+
+    If parents is True then also create any missing parents.
+    """
+    return _add_module(spec, True, parents)
+
+
+def ensure_module_imported(name, *, clearnone=True):
+    """Return the corresponding module.
+
+    If it was already imported then return that.  Otherwise, try
+    importing it (optionally clear it first if None).  If that fails
+    then create a new empty module.
+
+    It can be helpful to combine this with ready_to_import() and/or
+    isolated_modules().
+    """
+    if sys.modules.get(name) is not None:
+        mod = sys.modules[name]
+    else:
+        mod, _, _ = _force_import(name, False, True, clearnone)
+    return mod
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index 5ebb78b0ea9e3b..32d6fd4e94bf1b 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -1,3 +1,6 @@
+import contextlib
+import importlib
+import importlib.util
 import itertools
 import sys
 import types
@@ -9,7 +12,7 @@
 _interpreters = import_helper.import_module('_interpreters')
 from _interpreters import NotShareableError
 
-
+from test import _code_definitions as code_defs
 from test import _crossinterp_definitions as defs
 
 
@@ -21,6 +24,88 @@
                if (isinstance(o, type) and
                   n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
 
+DEFS = defs
+with open(code_defs.__file__) as infile:
+    _code_defs_text = infile.read()
+with open(DEFS.__file__) as infile:
+    _defs_text = infile.read()
+    _defs_text = _defs_text.replace('from ', '# from ')
+DEFS_TEXT = f"""
+#######################################
+# from {code_defs.__file__}
+
+{_code_defs_text}
+
+#######################################
+# from {defs.__file__}
+
+{_defs_text}
+"""
+del infile, _code_defs_text, _defs_text
+
+
+def load_defs(module=None):
+    """Return a new copy of the test._crossinterp_definitions module.
+
+    The module's __name__ matches the "module" arg, which is either
+    a str or a module.
+
+    If the "module" arg is a module then the just-loaded defs are also
+    copied into that module.
+
+    Note that the new module is not added to sys.modules.
+    """
+    if module is None:
+        modname = DEFS.__name__
+    elif isinstance(module, str):
+        modname = module
+        module = None
+    else:
+        modname = module.__name__
+    # Create the new module and populate it.
+    defs = import_helper.create_module(modname)
+    defs.__file__ = DEFS.__file__
+    exec(DEFS_TEXT, defs.__dict__)
+    # Copy the defs into the module arg, if any.
+    if module is not None:
+        for name, value in defs.__dict__.items():
+            if name.startswith('_'):
+                continue
+            assert not hasattr(module, name), (name, getattr(module, name))
+            setattr(module, name, value)
+    return defs
+
+
+@contextlib.contextmanager
+def using___main__():
+    """Make sure __main__ module exists (and clean up after)."""
+    modname = '__main__'
+    if modname not in sys.modules:
+        with import_helper.isolated_modules():
+            yield import_helper.add_module(modname)
+    else:
+        with import_helper.module_restored(modname) as mod:
+            yield mod
+
+
+@contextlib.contextmanager
+def temp_module(modname):
+    """Create the module and add to sys.modules, then remove it after."""
+    assert modname not in sys.modules, (modname,)
+    with import_helper.isolated_modules():
+        yield import_helper.add_module(modname)
+
+
+@contextlib.contextmanager
+def missing_defs_module(modname, *, prep=False):
+    assert modname not in sys.modules, (modname,)
+    if prep:
+        with import_helper.ready_to_import(modname, DEFS_TEXT):
+            yield modname
+    else:
+        with import_helper.isolated_modules():
+            yield modname
+
 
 class _GetXIDataTests(unittest.TestCase):
 
@@ -32,52 +117,49 @@ def get_xidata(self, obj, *, mode=None):
 
     def get_roundtrip(self, obj, *, mode=None):
         mode = self._resolve_mode(mode)
-        xid =_testinternalcapi.get_crossinterp_data(obj, mode)
+        return self._get_roundtrip(obj, mode)
+
+    def _get_roundtrip(self, obj, mode):
+        xid = _testinternalcapi.get_crossinterp_data(obj, mode)
         return _testinternalcapi.restore_crossinterp_data(xid)
 
-    def iter_roundtrip_values(self, values, *, mode=None):
+    def assert_roundtrip_identical(self, values, *, mode=None):
         mode = self._resolve_mode(mode)
         for obj in values:
             with self.subTest(obj):
-                xid = _testinternalcapi.get_crossinterp_data(obj, mode)
-                got = _testinternalcapi.restore_crossinterp_data(xid)
-                yield obj, got
-
-    def assert_roundtrip_identical(self, values, *, mode=None):
-        for obj, got in self.iter_roundtrip_values(values, mode=mode):
-            # XXX What about between interpreters?
-            self.assertIs(got, obj)
+                got = self._get_roundtrip(obj, mode)
+                self.assertIs(got, obj)
 
     def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
-        for obj, got in self.iter_roundtrip_values(values, mode=mode):
-            self.assertEqual(got, obj)
-            self.assertIs(type(got),
-                          type(obj) if expecttype is None else expecttype)
-
-#    def assert_roundtrip_equal_not_identical(self, values, *,
-#                                            mode=None, expecttype=None):
-#        mode = self._resolve_mode(mode)
-#        for obj in values:
-#            cls = type(obj)
-#            with self.subTest(obj):
-#                got = self._get_roundtrip(obj, mode)
-#                self.assertIsNot(got, obj)
-#                self.assertIs(type(got), type(obj))
-#                self.assertEqual(got, obj)
-#                self.assertIs(type(got),
-#                              cls if expecttype is None else expecttype)
-#
-#    def assert_roundtrip_not_equal(self, values, *, mode=None, 
expecttype=None):
-#        mode = self._resolve_mode(mode)
-#        for obj in values:
-#            cls = type(obj)
-#            with self.subTest(obj):
-#                got = self._get_roundtrip(obj, mode)
-#                self.assertIsNot(got, obj)
-#                self.assertIs(type(got), type(obj))
-#                self.assertNotEqual(got, obj)
-#                self.assertIs(type(got),
-#                              cls if expecttype is None else expecttype)
+        mode = self._resolve_mode(mode)
+        for obj in values:
+            with self.subTest(obj):
+                got = self._get_roundtrip(obj, mode)
+                self.assertEqual(got, obj)
+                self.assertIs(type(got),
+                              type(obj) if expecttype is None else expecttype)
+
+    def assert_roundtrip_equal_not_identical(self, values, *,
+                                             mode=None, expecttype=None):
+        mode = self._resolve_mode(mode)
+        for obj in values:
+            with self.subTest(obj):
+                got = self._get_roundtrip(obj, mode)
+                self.assertIsNot(got, obj)
+                self.assertIs(type(got),
+                              type(obj) if expecttype is None else expecttype)
+                self.assertEqual(got, obj)
+
+    def assert_roundtrip_not_equal(self, values, *,
+                                   mode=None, expecttype=None):
+        mode = self._resolve_mode(mode)
+        for obj in values:
+            with self.subTest(obj):
+                got = self._get_roundtrip(obj, mode)
+                self.assertIsNot(got, obj)
+                self.assertIs(type(got),
+                              type(obj) if expecttype is None else expecttype)
+                self.assertNotEqual(got, obj)
 
     def assert_not_shareable(self, values, exctype=None, *, mode=None):
         mode = self._resolve_mode(mode)
@@ -95,6 +177,363 @@ def _resolve_mode(self, mode):
         return mode
 
 
+class PickleTests(_GetXIDataTests):
+
+    MODE = 'pickle'
+
+    def test_shareable(self):
+        self.assert_roundtrip_equal([
+            # singletons
+            None,
+            True,
+            False,
+            # bytes
+            *(i.to_bytes(2, 'little', signed=True)
+              for i in range(-1, 258)),
+            # str
+            'hello world',
+            '你好世界',
+            '',
+            # int
+            sys.maxsize,
+            -sys.maxsize - 1,
+            *range(-1, 258),
+            # float
+            0.0,
+            1.1,
+            -1.0,
+            0.12345678,
+            -0.12345678,
+            # tuple
+            (),
+            (1,),
+            ("hello", "world", ),
+            (1, True, "hello"),
+            ((1,),),
+            ((1, 2), (3, 4)),
+            ((1, 2), (3, 4), (5, 6)),
+        ])
+        # not shareable using xidata
+        self.assert_roundtrip_equal([
+            # int
+            sys.maxsize + 1,
+            -sys.maxsize - 2,
+            2**1000,
+            # tuple
+            (0, 1.0, []),
+            (0, 1.0, {}),
+            (0, 1.0, ([],)),
+            (0, 1.0, ({},)),
+        ])
+
+    def test_list(self):
+        self.assert_roundtrip_equal_not_identical([
+            [],
+            [1, 2, 3],
+            [[1], (2,), {3: 4}],
+        ])
+
+    def test_dict(self):
+        self.assert_roundtrip_equal_not_identical([
+            {},
+            {1: 7, 2: 8, 3: 9},
+            {1: [1], 2: (2,), 3: {3: 4}},
+        ])
+
+    def test_set(self):
+        self.assert_roundtrip_equal_not_identical([
+            set(),
+            {1, 2, 3},
+            {frozenset({1}), (2,)},
+        ])
+
+    # classes
+
+    def assert_class_defs_same(self, defs):
+        # Unpickle relative to the unchanged original module.
+        self.assert_roundtrip_identical(defs.TOP_CLASSES)
+
+        instances = []
+        for cls, args in defs.TOP_CLASSES.items():
+            if cls in defs.CLASSES_WITHOUT_EQUALITY:
+                continue
+            instances.append(cls(*args))
+        self.assert_roundtrip_equal_not_identical(instances)
+
+        # these don't compare equal
+        instances = []
+        for cls, args in defs.TOP_CLASSES.items():
+            if cls not in defs.CLASSES_WITHOUT_EQUALITY:
+                continue
+            instances.append(cls(*args))
+        self.assert_roundtrip_not_equal(instances)
+
+    def assert_class_defs_other_pickle(self, defs, mod):
+        # Pickle relative to a different module than the original.
+        for cls in defs.TOP_CLASSES:
+            assert not hasattr(mod, cls.__name__), (cls, getattr(mod, 
cls.__name__))
+        self.assert_not_shareable(defs.TOP_CLASSES)
+
+        instances = []
+        for cls, args in defs.TOP_CLASSES.items():
+            instances.append(cls(*args))
+        self.assert_not_shareable(instances)
+
+    def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False):
+        # Unpickle relative to a different module than the original.
+        for cls in defs.TOP_CLASSES:
+            assert not hasattr(mod, cls.__name__), (cls, getattr(mod, 
cls.__name__))
+
+        instances = []
+        for cls, args in defs.TOP_CLASSES.items():
+            with self.subTest(cls):
+                setattr(mod, cls.__name__, cls)
+                xid = self.get_xidata(cls)
+                inst = cls(*args)
+                instxid = self.get_xidata(inst)
+                instances.append(
+                        (cls, xid, inst, instxid))
+
+        for cls, xid, inst, instxid in instances:
+            with self.subTest(cls):
+                delattr(mod, cls.__name__)
+                if fail:
+                    with self.assertRaises(NotShareableError):
+                        _testinternalcapi.restore_crossinterp_data(xid)
+                    continue
+                got = _testinternalcapi.restore_crossinterp_data(xid)
+                self.assertIsNot(got, cls)
+                self.assertNotEqual(got, cls)
+
+                gotcls = got
+                got = _testinternalcapi.restore_crossinterp_data(instxid)
+                self.assertIsNot(got, inst)
+                self.assertIs(type(got), gotcls)
+                if cls in defs.CLASSES_WITHOUT_EQUALITY:
+                    self.assertNotEqual(got, inst)
+                elif cls in defs.BUILTIN_SUBCLASSES:
+                    self.assertEqual(got, inst)
+                else:
+                    self.assertNotEqual(got, inst)
+
+    def assert_class_defs_not_shareable(self, defs):
+        self.assert_not_shareable(defs.TOP_CLASSES)
+
+        instances = []
+        for cls, args in defs.TOP_CLASSES.items():
+            instances.append(cls(*args))
+        self.assert_not_shareable(instances)
+
+    def test_user_class_normal(self):
+        self.assert_class_defs_same(defs)
+
+    def test_user_class_in___main__(self):
+        with using___main__() as mod:
+            defs = load_defs(mod)
+            self.assert_class_defs_same(defs)
+
+    def test_user_class_not_in___main___with_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            assert defs.__file__
+            mod.__file__ = defs.__file__
+            self.assert_class_defs_not_shareable(defs)
+
+    def test_user_class_not_in___main___without_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            defs.__file__ = None
+            mod.__file__ = None
+            self.assert_class_defs_not_shareable(defs)
+
+    def test_user_class_not_in___main___unpickle_with_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            assert defs.__file__
+            mod.__file__ = defs.__file__
+            self.assert_class_defs_other_unpickle(defs, mod)
+
+    def test_user_class_not_in___main___unpickle_without_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            defs.__file__ = None
+            mod.__file__ = None
+            self.assert_class_defs_other_unpickle(defs, mod, fail=True)
+
+    def test_user_class_in_module(self):
+        with temp_module('__spam__') as mod:
+            defs = load_defs(mod)
+            self.assert_class_defs_same(defs)
+
+    def test_user_class_not_in_module_with_filename(self):
+        with temp_module('__spam__') as mod:
+            defs = load_defs(mod.__name__)
+            assert defs.__file__
+            # For now, we only address this case for __main__.
+            self.assert_class_defs_not_shareable(defs)
+
+    def test_user_class_not_in_module_without_filename(self):
+        with temp_module('__spam__') as mod:
+            defs = load_defs(mod.__name__)
+            defs.__file__ = None
+            self.assert_class_defs_not_shareable(defs)
+
+    def test_user_class_module_missing_then_imported(self):
+        with missing_defs_module('__spam__', prep=True) as modname:
+            defs = load_defs(modname)
+            # For now, we only address this case for __main__.
+            self.assert_class_defs_not_shareable(defs)
+
+    def test_user_class_module_missing_not_available(self):
+        with missing_defs_module('__spam__') as modname:
+            defs = load_defs(modname)
+            self.assert_class_defs_not_shareable(defs)
+
+    def test_nested_class(self):
+        eggs = defs.EggsNested()
+        with self.assertRaises(NotShareableError):
+            self.get_roundtrip(eggs)
+
+    # functions
+
+    def assert_func_defs_same(self, defs):
+        # Unpickle relative to the unchanged original module.
+        self.assert_roundtrip_identical(defs.TOP_FUNCTIONS)
+
+    def assert_func_defs_other_pickle(self, defs, mod):
+        # Pickle relative to a different module than the original.
+        for func in defs.TOP_FUNCTIONS:
+            assert not hasattr(mod, func.__name__), (cls, getattr(mod, 
func.__name__))
+        self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+    def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
+        # Unpickle relative to a different module than the original.
+        for func in defs.TOP_FUNCTIONS:
+            assert not hasattr(mod, func.__name__), (cls, getattr(mod, 
func.__name__))
+
+        captured = []
+        for func in defs.TOP_FUNCTIONS:
+            with self.subTest(func):
+                setattr(mod, func.__name__, func)
+                xid = self.get_xidata(func)
+                captured.append(
+                        (func, xid))
+
+        for func, xid in captured:
+            with self.subTest(func):
+                delattr(mod, func.__name__)
+                if fail:
+                    with self.assertRaises(NotShareableError):
+                        _testinternalcapi.restore_crossinterp_data(xid)
+                    continue
+                got = _testinternalcapi.restore_crossinterp_data(xid)
+                self.assertIsNot(got, func)
+                self.assertNotEqual(got, func)
+
+    def assert_func_defs_not_shareable(self, defs):
+        self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+    def test_user_function_normal(self):
+#        self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
+        self.assert_func_defs_same(defs)
+
+    def test_user_func_in___main__(self):
+        with using___main__() as mod:
+            defs = load_defs(mod)
+            self.assert_func_defs_same(defs)
+
+    def test_user_func_not_in___main___with_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            assert defs.__file__
+            mod.__file__ = defs.__file__
+            self.assert_func_defs_not_shareable(defs)
+
+    def test_user_func_not_in___main___without_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            defs.__file__ = None
+            mod.__file__ = None
+            self.assert_func_defs_not_shareable(defs)
+
+    def test_user_func_not_in___main___unpickle_with_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            assert defs.__file__
+            mod.__file__ = defs.__file__
+            self.assert_func_defs_other_unpickle(defs, mod)
+
+    def test_user_func_not_in___main___unpickle_without_filename(self):
+        with using___main__() as mod:
+            defs = load_defs('__main__')
+            defs.__file__ = None
+            mod.__file__ = None
+            self.assert_func_defs_other_unpickle(defs, mod, fail=True)
+
+    def test_user_func_in_module(self):
+        with temp_module('__spam__') as mod:
+            defs = load_defs(mod)
+            self.assert_func_defs_same(defs)
+
+    def test_user_func_not_in_module_with_filename(self):
+        with temp_module('__spam__') as mod:
+            defs = load_defs(mod.__name__)
+            assert defs.__file__
+            # For now, we only address this case for __main__.
+            self.assert_func_defs_not_shareable(defs)
+
+    def test_user_func_not_in_module_without_filename(self):
+        with temp_module('__spam__') as mod:
+            defs = load_defs(mod.__name__)
+            defs.__file__ = None
+            self.assert_func_defs_not_shareable(defs)
+
+    def test_user_func_module_missing_then_imported(self):
+        with missing_defs_module('__spam__', prep=True) as modname:
+            defs = load_defs(modname)
+            # For now, we only address this case for __main__.
+            self.assert_func_defs_not_shareable(defs)
+
+    def test_user_func_module_missing_not_available(self):
+        with missing_defs_module('__spam__') as modname:
+            defs = load_defs(modname)
+            self.assert_func_defs_not_shareable(defs)
+
+    def test_nested_function(self):
+        self.assert_not_shareable(defs.NESTED_FUNCTIONS)
+
+    # exceptions
+
+    def test_user_exception_normal(self):
+        self.assert_roundtrip_not_equal([
+            defs.MimimalError('error!'),
+        ])
+        self.assert_roundtrip_equal_not_identical([
+            defs.RichError('error!', 42),
+        ])
+
+    def test_builtin_exception(self):
+        msg = 'error!'
+        try:
+            raise Exception
+        except Exception as exc:
+            caught = exc
+        special = {
+            BaseExceptionGroup: (msg, [caught]),
+            ExceptionGroup: (msg, [caught]),
+#            UnicodeError: (None, msg, None, None, None),
+            UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
+            UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
+            UnicodeTranslateError: ('', 1, 3, msg),
+        }
+        exceptions = []
+        for cls in EXCEPTION_TYPES:
+            args = special.get(cls) or (msg,)
+            exceptions.append(cls(*args))
+
+        self.assert_roundtrip_not_equal(exceptions)
+
+
 class MarshalTests(_GetXIDataTests):
 
     MODE = 'marshal'
@@ -444,22 +883,12 @@ def test_module(self):
         ])
 
     def test_class(self):
-        self.assert_not_shareable([
-            defs.Spam,
-            defs.SpamOkay,
-            defs.SpamFull,
-            defs.SubSpamFull,
-            defs.SubTuple,
-            defs.EggsNested,
-        ])
-        self.assert_not_shareable([
-            defs.Spam(),
-            defs.SpamOkay(),
-            defs.SpamFull(1, 2, 3),
-            defs.SubSpamFull(1, 2, 3),
-            defs.SubTuple([1, 2, 3]),
-            defs.EggsNested(),
-        ])
+        self.assert_not_shareable(defs.CLASSES)
+
+        instances = []
+        for cls, args in defs.CLASSES.items():
+            instances.append(cls(*args))
+        self.assert_not_shareable(instances)
 
     def test_builtin_type(self):
         self.assert_not_shareable([
diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c
index 4bfe88f2cf920c..812737e294fcb7 100644
--- a/Modules/_testinternalcapi.c
+++ b/Modules/_testinternalcapi.c
@@ -1939,6 +1939,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, 
PyObject *kwargs)
             goto error;
         }
     }
+    else if (strcmp(mode, "pickle") == 0) {
+        if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) {
+            goto error;
+        }
+    }
     else if (strcmp(mode, "marshal") == 0) {
         if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) {
             goto error;
diff --git a/Python/crossinterp.c b/Python/crossinterp.c
index 753d784a503467..a9f9b78562917e 100644
--- a/Python/crossinterp.c
+++ b/Python/crossinterp.c
@@ -3,6 +3,7 @@
 
 #include "Python.h"
 #include "marshal.h"              // PyMarshal_WriteObjectToString()
+#include "osdefs.h"               // MAXPATHLEN
 #include "pycore_ceval.h"         // _Py_simple_func
 #include "pycore_crossinterp.h"   // _PyXIData_t
 #include "pycore_initconfig.h"    // _PyStatus_OK()
@@ -10,6 +11,155 @@
 #include "pycore_typeobject.h"    // _PyStaticType_InitBuiltin()
 
 
+static Py_ssize_t
+_Py_GetMainfile(char *buffer, size_t maxlen)
+{
+    // We don't expect subinterpreters to have the __main__ module's
+    // __name__ set, but proceed just in case.
+    PyThreadState *tstate = _PyThreadState_GET();
+    PyObject *module = _Py_GetMainModule(tstate);
+    if (_Py_CheckMainModule(module) < 0) {
+        return -1;
+    }
+    Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen);
+    Py_DECREF(module);
+    return size;
+}
+
+
+static PyObject *
+import_get_module(PyThreadState *tstate, const char *modname)
+{
+    PyObject *module = NULL;
+    if (strcmp(modname, "__main__") == 0) {
+        module = _Py_GetMainModule(tstate);
+        if (_Py_CheckMainModule(module) < 0) {
+            assert(_PyErr_Occurred(tstate));
+            return NULL;
+        }
+    }
+    else {
+        module = PyImport_ImportModule(modname);
+        if (module == NULL) {
+            return NULL;
+        }
+    }
+    return module;
+}
+
+
+static PyObject *
+runpy_run_path(const char *filename, const char *modname)
+{
+    PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path");
+    if (run_path == NULL) {
+        return NULL;
+    }
+    PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname);
+    if (args == NULL) {
+        Py_DECREF(run_path);
+        return NULL;
+    }
+    PyObject *ns = PyObject_Call(run_path, args, NULL);
+    Py_DECREF(run_path);
+    Py_DECREF(args);
+    return ns;
+}
+
+
+// Return a new reference to the exception's message, or NULL (with no
+// exception set) if there isn't one.  The message is the exception's
+// args if that is a str, otherwise its first item if that is a str.
+static PyObject *
+pyerr_get_message(PyObject *exc)
+{
+    assert(!PyErr_Occurred());
+    PyObject *args = PyException_GetArgs(exc);
+    if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
+        // Release our reference (if any) and drop any error raised
+        // while inspecting the args; NULL here just means "no message".
+        Py_XDECREF(args);
+        PyErr_Clear();
+        return NULL;
+    }
+    if (PyUnicode_Check(args)) {
+        // Ownership of the reference passes to the caller.
+        return args;
+    }
+    PyObject *msg = PySequence_GetItem(args, 0);
+    Py_DECREF(args);
+    if (msg == NULL) {
+        PyErr_Clear();
+        return NULL;
+    }
+    if (!PyUnicode_Check(msg)) {
+        Py_DECREF(msg);
+        return NULL;
+    }
+    return msg;
+}
+
+#define MAX_MODNAME (255)
+#define MAX_ATTRNAME (255)
+
+struct attributeerror_info {
+    char modname[MAX_MODNAME+1];
+    char attrname[MAX_ATTRNAME+1];
+};
+
+// Parse an AttributeError message of the exact form
+//   module '<modname>' has no attribute '<attrname>'
+// into "info".  Return 0 on success and -1 if the message is missing,
+// cannot be read as UTF-8, does not match, or a name is too long.
+// On failure "info" may be partially filled in.
+static int
+_parse_attributeerror(PyObject *exc, struct attributeerror_info *info)
+{
+    assert(exc != NULL);
+    assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
+    int res = -1;
+
+    PyObject *msgobj = pyerr_get_message(exc);
+    if (msgobj == NULL) {
+        return -1;
+    }
+    const char *err = PyUnicode_AsUTF8(msgobj);
+    if (err == NULL) {
+        // The message could not be encoded (e.g. lone surrogates).
+        // Treat it like any other unparseable message.
+        PyErr_Clear();
+        goto finally;
+    }
+
+    if (strncmp(err, "module '", 8) != 0) {
+        goto finally;
+    }
+    err += 8;
+
+    // The module name runs up to the next single quote.
+    const char *matched = strchr(err, '\'');
+    if (matched == NULL) {
+        goto finally;
+    }
+    Py_ssize_t len = matched - err;
+    if (len > MAX_MODNAME) {
+        goto finally;
+    }
+    (void)strncpy(info->modname, err, len);
+    info->modname[len] = '\0';
+    // "matched" points at the closing quote, which starts the next literal.
+    err = matched;
+
+    if (strncmp(err, "' has no attribute '", 20) != 0) {
+        goto finally;
+    }
+    err += 20;
+
+    // The attribute name also runs up to the next single quote.
+    matched = strchr(err, '\'');
+    if (matched == NULL) {
+        goto finally;
+    }
+    len = matched - err;
+    if (len > MAX_ATTRNAME) {
+        goto finally;
+    }
+    (void)strncpy(info->attrname, err, len);
+    info->attrname[len] = '\0';
+    err = matched + 1;
+
+    // Nothing may follow the closing quote.
+    if (strlen(err) > 0) {
+        goto finally;
+    }
+    res = 0;
+
+finally:
+    Py_DECREF(msgobj);
+    return res;
+}
+
+#undef MAX_MODNAME
+#undef MAX_ATTRNAME
+
+
 /**************/
 /* exceptions */
 /**************/
@@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate,
 }
 
 
+/* pickle C-API */
+
+// State passed to _PyPickle_Dumps().
+struct _pickle_context {
+    // The thread state supplied by the caller.
+    PyThreadState *tstate;
+};
+
+// Serialize obj by calling pickle.dumps(obj).
+// Returns the result of that call (a new reference), or NULL if
+// pickle.dumps could not be imported or the call failed.
+// ctx is currently not consulted.
+static PyObject *
+_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj)
+{
+    PyObject *func = PyImport_ImportModuleAttrString("pickle", "dumps");
+    if (func != NULL) {
+        PyObject *pickled = PyObject_CallOneArg(func, obj);
+        Py_DECREF(func);
+        return pickled;
+    }
+    return NULL;
+}
+
+
+// The outcome of syncing one module for unpickling.
+struct sync_module_result {
+    // The module object returned by import_get_module().
+    PyObject *module;
+    // The value returned by runpy_run_path() for the module's file.
+    PyObject *loaded;
+    // The exception recorded when syncing the module failed.
+    PyObject *failed;
+};
+
+// Per-module sync state.
+struct sync_module {
+    // The file to run for the module, or NULL if there is none.
+    const char *filename;
+    // NOTE(review): not referenced by the code visible here;
+    // presumably backing storage for "filename" -- confirm.
+    char _filename[MAXPATHLEN+1];
+    // The remembered result, released by sync_module_clear().
+    struct sync_module_result cached;
+};
+
+// Reset a sync_module: forget the filename and release
+// every reference held by the cached result.
+static void
+sync_module_clear(struct sync_module *data)
+{
+    struct sync_module_result *cached = &data->cached;
+    data->filename = NULL;
+    Py_CLEAR(cached->module);
+    Py_CLEAR(cached->loaded);
+    Py_CLEAR(cached->failed);
+}
+
+
+// State passed to _PyPickle_Loads().
+struct _unpickle_context {
+    // The thread state supplied by the caller.
+    PyThreadState *tstate;
+    // We only special-case the __main__ module,
+    // since other modules behave consistently.
+    struct sync_module main;
+};
+
+// Release everything held by the context (the state for __main__).
+static void
+_unpickle_context_clear(struct _unpickle_context *ctx)
+{
+    sync_module_clear(&ctx->main);
+}
+
+// Return the cached sync result for the named module.
+// Only "__main__" is tracked; for any other name the returned result
+// has only "failed" set, to PyExc_NotImplementedError.
+static struct sync_module_result
+_unpickle_context_get_module(struct _unpickle_context *ctx,
+                             const char *modname)
+{
+    if (strcmp(modname, "__main__") != 0) {
+        struct sync_module_result unsupported = {
+            .failed = PyExc_NotImplementedError,
+        };
+        return unsupported;
+    }
+    return ctx->main.cached;
+}
+
+// Sync the named module for unpickling and remember the outcome.
+//
+// Only "__main__" is supported.  For it, the module is looked up with
+// import_get_module() and its file (ctx->main.filename) is run with
+// runpy_run_path(); the result is stored in ctx->main.cached, which
+// owns the references, and a copy of that struct is returned.
+//
+// On failure, "failed" is set in the result and "module" and "loaded"
+// are NULL.  For any other module name nothing is cached and "failed"
+// is PyExc_NotImplementedError.
+static struct sync_module_result
+_unpickle_context_set_module(struct _unpickle_context *ctx,
+                             const char *modname)
+{
+    struct sync_module_result res = {0};
+    struct sync_module_result *cached = NULL;
+    const char *filename = NULL;
+    if (strcmp(modname, "__main__") == 0) {
+        cached = &ctx->main.cached;
+        filename = ctx->main.filename;
+    }
+    else {
+        // Nothing gets cached in this case, so the reference
+        // is never released and a borrowed one is enough.
+        res.failed = PyExc_NotImplementedError;
+        goto finally;
+    }
+
+    res.module = import_get_module(ctx->tstate, modname);
+    if (res.module == NULL) {
+        res.failed = _PyErr_GetRaisedException(ctx->tstate);
+        assert(res.failed != NULL);
+        goto finally;
+    }
+
+    if (filename == NULL) {
+        Py_CLEAR(res.module);
+        // This value ends up in the cache, which sync_module_clear()
+        // later releases with Py_CLEAR(), so it must be a strong
+        // reference like the other cached values.
+        res.failed = Py_NewRef(PyExc_NotImplementedError);
+        goto finally;
+    }
+    res.loaded = runpy_run_path(filename, modname);
+    if (res.loaded == NULL) {
+        Py_CLEAR(res.module);
+        res.failed = _PyErr_GetRaisedException(ctx->tstate);
+        assert(res.failed != NULL);
+        goto finally;
+    }
+
+finally:
+    if (cached != NULL) {
+        // The module is only expected to be synced once.
+        assert(cached->module == NULL);
+        assert(cached->loaded == NULL);
+        assert(cached->failed == NULL);
+        *cached = res;
+    }
+    return res;
+}
+
+
+// Try to recover from an AttributeError raised while unpickling by
+// setting the missing attribute on the module named in the message.
+// The value is taken from the "loaded" mapping of the module's sync
+// result.  Returns 0 if the attribute was set and -1 otherwise.
+static int
+_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc)
+{
+    // The caller must check if an exception is set or not when -1 is returned.
+    assert(!_PyErr_Occurred(ctx->tstate));
+    assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
+
+    struct attributeerror_info names;
+    if (_parse_attributeerror(exc, &names) < 0) {
+        return -1;
+    }
+
+    // Get the module, syncing it first if that has not been tried yet.
+    struct sync_module_result synced =
+            _unpickle_context_get_module(ctx, names.modname);
+    if (synced.failed != NULL) {
+        // It must have failed previously.
+        return -1;
+    }
+    if (synced.module == NULL) {
+        synced = _unpickle_context_set_module(ctx, names.modname);
+        if (synced.failed != NULL) {
+            return -1;
+        }
+        assert(synced.module != NULL);
+    }
+
+    // Bail out if it is unexpectedly set already.
+    if (PyObject_HasAttrString(synced.module, names.attrname)) {
+        return -1;
+    }
+
+    // Try setting the attribute.
+    PyObject *attr = NULL;
+    int found = PyDict_GetItemStringRef(
+            synced.loaded, names.attrname, &attr);
+    if (found <= 0) {
+        return -1;
+    }
+    assert(attr != NULL);
+    int rc = PyObject_SetAttrString(synced.module, names.attrname, attr);
+    Py_DECREF(attr);
+    return (rc < 0) ? -1 : 0;
+}
+
+// Deserialize "pickled" by calling pickle.loads(pickled).
+//
+// If ctx is not NULL and the call fails with an AttributeError, try to
+// supply the missing attribute (see _handle_unpickle_missing_attr())
+// and call pickle.loads() again.  This repeats until the call succeeds,
+// fails with some other exception, or the fix-up fails; in the last case
+// the AttributeError from pickle.loads() is the one left set.
+//
+// Returns a new reference, or NULL with an exception set.
+static PyObject *
+_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled)
+{
+    PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads");
+    if (loads == NULL) {
+        return NULL;
+    }
+    PyObject *obj = PyObject_CallOneArg(loads, pickled);
+    if (ctx != NULL) {
+        while (obj == NULL) {
+            assert(_PyErr_Occurred(ctx->tstate));
+            if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+                // We leave other failures unhandled.
+                break;
+            }
+            // Try setting the attr if not set.
+            // (The exception is taken out of the thread state first,
+            // since the handler expects no exception to be set.)
+            PyObject *exc = _PyErr_GetRaisedException(ctx->tstate);
+            if (_handle_unpickle_missing_attr(ctx, exc) < 0) {
+                // Any resulting exceptions are ignored
+                // in favor of the original.
+                _PyErr_SetRaisedException(ctx->tstate, exc);
+                break;
+            }
+            Py_CLEAR(exc);
+            // Retry with the attribute set.
+            obj = PyObject_CallOneArg(loads, pickled);
+        }
+    }
+    Py_DECREF(loads);
+    return obj;
+}
+
+
+/* pickle wrapper */
+
+// Extra information stored next to the pickled data
+// (filled in by _set_pickle_xid_context()).
+struct _pickle_xid_context {
+    // __main__.__file__
+    struct {
+        // Set to point at _utf8 when a filename was found.
+        const char *utf8;
+        // The length reported by _Py_GetMainfile().
+        size_t len;
+        // The buffer that _Py_GetMainfile() writes into.
+        char _utf8[MAXPATHLEN+1];
+    } mainfile;
+};
+
+// Fill in ctx with what the unpickling side needs: currently just the
+// filename reported by _Py_GetMainfile().  If there is no filename, or
+// looking it up fails, the mainfile is left unset (utf8 == NULL and
+// len == 0).  Always returns 0; tstate is currently not consulted.
+static int
+_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx)
+{
+    // Start from a well-defined "unset" state instead of relying on the
+    // caller having zeroed *ctx: mainfile.utf8 is read when unpickling.
+    ctx->mainfile.utf8 = NULL;
+    ctx->mainfile.len = 0;
+
+    // Set mainfile if possible.
+    Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN);
+    if (len < 0) {
+        // For now we ignore any exceptions.
+        PyErr_Clear();
+    }
+    else if (len > 0) {
+        ctx->mainfile.utf8 = ctx->mainfile._utf8;
+        ctx->mainfile.len = (size_t)len;
+    }
+
+    return 0;
+}
+
+
+// The data attached to the _PyXIData_t by _PyPickle_GetXIData():
+// the pickled bytes plus the context needed to unpickle them.
+struct _shared_pickle_data {
+    _PyBytes_data_t pickled;  // Must be first if we use _PyBytes_FromXIData().
+    struct _pickle_xid_context ctx;
+};
+
+// Unpickle, in the current thread's interpreter, the object whose
+// pickled form is stored in xidata (as a struct _shared_pickle_data).
+// Returns a new reference, or NULL on failure.
+PyObject *
+_PyPickle_LoadFromXIData(_PyXIData_t *xidata)
+{
+    PyThreadState *tstate = _PyThreadState_GET();
+    struct _shared_pickle_data *data =
+                            (struct _shared_pickle_data *)xidata->data;
+
+    // We avoid copying the pickled data by wrapping it in a memoryview.
+    // The alternative is to get a bytes object using _PyBytes_FromXIData().
+    PyObject *view = PyMemoryView_FromMemory(
+            (char *)data->pickled.bytes, data->pickled.len, PyBUF_READ);
+    if (view == NULL) {
+        return NULL;
+    }
+
+    // Unpickle the object.
+    struct _unpickle_context ctx = {
+        .tstate = tstate,
+        .main = {
+            .filename = data->ctx.mainfile.utf8,
+        },
+    };
+    PyObject *obj = _PyPickle_Loads(&ctx, view);
+    Py_DECREF(view);
+    _unpickle_context_clear(&ctx);
+    if (obj != NULL) {
+        return obj;
+    }
+
+    // Hand the exception from unpickling to _set_xid_lookup_failure().
+    PyObject *cause = _PyErr_GetRaisedException(tstate);
+    assert(cause != NULL);
+    _set_xid_lookup_failure(
+                tstate, NULL, "object could not be unpickled", cause);
+    Py_DECREF(cause);
+    return NULL;
+}
+
+
+// Prepare xidata for sending obj to another interpreter as a pickle.
+//
+// obj is pickled with _PyPickle_Dumps() and the resulting bytes are
+// wrapped in a struct _shared_pickle_data, with
+// _PyPickle_LoadFromXIData() registered as the function that
+// recreates the object.
+//
+// Returns 0 on success and -1 on failure.
+int
+_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata)
+{
+    // Pickle the object.
+    struct _pickle_context ctx = {
+        .tstate = tstate,
+    };
+    PyObject *bytes = _PyPickle_Dumps(&ctx, obj);
+    if (bytes == NULL) {
+        // Hand the exception from pickling to _set_xid_lookup_failure().
+        PyObject *cause = _PyErr_GetRaisedException(tstate);
+        assert(cause != NULL);
+        _set_xid_lookup_failure(
+                    tstate, NULL, "object could not be pickled", cause);
+        Py_DECREF(cause);
+        return -1;
+    }
+
+    // If we had an "unwrapper" mechanism, we could call
+    // _PyObject_GetXIData() on the bytes object directly and add
+    // a simple unwrapper to call pickle.loads() on the bytes.
+    size_t size = sizeof(struct _shared_pickle_data);
+    struct _shared_pickle_data *shared =
+            (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped(
+                    tstate, bytes, size, _PyPickle_LoadFromXIData, xidata);
+    Py_DECREF(bytes);
+    if (shared == NULL) {
+        return -1;
+    }
+
+    // If it mattered, we could skip getting __main__.__file__
+    // when "__main__" doesn't show up in the pickle bytes.
+    if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) {
+        // Undo the setup of xidata done above.
+        _xidata_clear(xidata);
+        return -1;
+    }
+
+    return 0;
+}
+
+
 /* marshal wrapper */
 
 PyObject *

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to