https://github.com/python/cpython/commit/cb35c11d82efd2959bda0397abcc1719bf6bb0cb commit: cb35c11d82efd2959bda0397abcc1719bf6bb0cb branch: main author: Eric Snow <ericsnowcurren...@gmail.com> committer: ericsnowcurrently <ericsnowcurren...@gmail.com> date: 2025-04-30T17:34:05-06:00 summary:
gh-132775: Add _PyPickle_GetXIData() (gh-133107) There's some extra complexity due to making sure we get things right when handling functions and classes defined in the __main__ module. This is also reflected in the tests, including the addition of extra functions in test.support.import_helper. files: M Include/internal/pycore_crossinterp.h M Lib/test/support/import_helper.py M Lib/test/test_crossinterp.py M Modules/_testinternalcapi.c M Python/crossinterp.c diff --git a/Include/internal/pycore_crossinterp.h b/Include/internal/pycore_crossinterp.h index 4b7446a1f40ccf..4b4617fdbcb2ad 100644 --- a/Include/internal/pycore_crossinterp.h +++ b/Include/internal/pycore_crossinterp.h @@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped( xid_newobjfunc, _PyXIData_t *); +// _PyObject_GetXIData() for pickle +PyAPI_DATA(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *); +PyAPI_FUNC(int) _PyPickle_GetXIData( + PyThreadState *, + PyObject *, + _PyXIData_t *); + // _PyObject_GetXIData() for marshal PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *); PyAPI_FUNC(int) _PyMarshal_GetXIData( diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 42cfe9cfa8cb72..edb734d294f287 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -1,6 +1,7 @@ import contextlib import _imp import importlib +import importlib.machinery import importlib.util import os import shutil @@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block): ) from .script_helper import assert_python_ok assert_python_ok("-S", "-c", script) + + +@contextlib.contextmanager +def module_restored(name): + """A context manager that restores a module to the original state.""" + missing = object() + orig = sys.modules.get(name, missing) + if orig is None: + mod = importlib.import_module(name) + else: + mod = type(sys)(name) + mod.__dict__.update(orig.__dict__) + sys.modules[name] = mod + try: 
+ yield mod + finally: + if orig is missing: + sys.modules.pop(name, None) + else: + sys.modules[name] = orig + + +def create_module(name, loader=None, *, ispkg=False): + """Return a new, empty module.""" + spec = importlib.machinery.ModuleSpec( + name, + loader, + origin='<import_helper>', + is_package=ispkg, + ) + return importlib.util.module_from_spec(spec) + + +def _ensure_module(name, ispkg, addparent, clearnone): + try: + mod = orig = sys.modules[name] + except KeyError: + mod = orig = None + missing = True + else: + missing = False + if mod is not None: + # It was already imported. + return mod, orig, missing + # Otherwise, None means it was explicitly disabled. + + assert name != '__main__' + if not missing: + assert orig is None, (name, sys.modules[name]) + if not clearnone: + raise ModuleNotFoundError(name) + del sys.modules[name] + # Try normal import, then fall back to adding the module. + try: + mod = importlib.import_module(name) + except ModuleNotFoundError: + if addparent and not clearnone: + addparent = None + mod = _add_module(name, ispkg, addparent) + return mod, orig, missing + + +def _add_module(spec, ispkg, addparent): + if isinstance(spec, str): + name = spec + mod = create_module(name, ispkg=ispkg) + spec = mod.__spec__ + else: + name = spec.name + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + if addparent is not False and spec.parent: + _ensure_module(spec.parent, True, addparent, bool(addparent)) + return mod + + +def add_module(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, False, parents) + + +def add_package(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. 
+ """ + return _add_module(spec, True, parents) + + +def ensure_module_imported(name, *, clearnone=True): + """Return the corresponding module. + + If it was already imported then return that. Otherwise, try + importing it (optionally clear it first if None). If that fails + then create a new empty module. + + It can be helpful to combine this with ready_to_import() and/or + isolated_modules(). + """ + if sys.modules.get(name) is not None: + mod = sys.modules[name] + else: + mod, _, _ = _force_import(name, False, True, clearnone) + return mod diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index 5ebb78b0ea9e3b..32d6fd4e94bf1b 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -1,3 +1,6 @@ +import contextlib +import importlib +import importlib.util import itertools import sys import types @@ -9,7 +12,7 @@ _interpreters = import_helper.import_module('_interpreters') from _interpreters import NotShareableError - +from test import _code_definitions as code_defs from test import _crossinterp_definitions as defs @@ -21,6 +24,88 @@ if (isinstance(o, type) and n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] +DEFS = defs +with open(code_defs.__file__) as infile: + _code_defs_text = infile.read() +with open(DEFS.__file__) as infile: + _defs_text = infile.read() + _defs_text = _defs_text.replace('from ', '# from ') +DEFS_TEXT = f""" +####################################### +# from {code_defs.__file__} + +{_code_defs_text} + +####################################### +# from {defs.__file__} + +{_defs_text} +""" +del infile, _code_defs_text, _defs_text + + +def load_defs(module=None): + """Return a new copy of the test._crossinterp_definitions module. + + The module's __name__ matches the "module" arg, which is either + a str or a module. + + If the "module" arg is a module then the just-loaded defs are also + copied into that module. + + Note that the new module is not added to sys.modules. 
+ """ + if module is None: + modname = DEFS.__name__ + elif isinstance(module, str): + modname = module + module = None + else: + modname = module.__name__ + # Create the new module and populate it. + defs = import_helper.create_module(modname) + defs.__file__ = DEFS.__file__ + exec(DEFS_TEXT, defs.__dict__) + # Copy the defs into the module arg, if any. + if module is not None: + for name, value in defs.__dict__.items(): + if name.startswith('_'): + continue + assert not hasattr(module, name), (name, getattr(module, name)) + setattr(module, name, value) + return defs + + +@contextlib.contextmanager +def using___main__(): + """Make sure __main__ module exists (and clean up after).""" + modname = '__main__' + if modname not in sys.modules: + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + else: + with import_helper.module_restored(modname) as mod: + yield mod + + +@contextlib.contextmanager +def temp_module(modname): + """Create the module and add to sys.modules, then remove it after.""" + assert modname not in sys.modules, (modname,) + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + + +@contextlib.contextmanager +def missing_defs_module(modname, *, prep=False): + assert modname not in sys.modules, (modname,) + if prep: + with import_helper.ready_to_import(modname, DEFS_TEXT): + yield modname + else: + with import_helper.isolated_modules(): + yield modname + class _GetXIDataTests(unittest.TestCase): @@ -32,52 +117,49 @@ def get_xidata(self, obj, *, mode=None): def get_roundtrip(self, obj, *, mode=None): mode = self._resolve_mode(mode) - xid =_testinternalcapi.get_crossinterp_data(obj, mode) + return self._get_roundtrip(obj, mode) + + def _get_roundtrip(self, obj, mode): + xid = _testinternalcapi.get_crossinterp_data(obj, mode) return _testinternalcapi.restore_crossinterp_data(xid) - def iter_roundtrip_values(self, values, *, mode=None): + def assert_roundtrip_identical(self, values, *, mode=None): 
mode = self._resolve_mode(mode) for obj in values: with self.subTest(obj): - xid = _testinternalcapi.get_crossinterp_data(obj, mode) - got = _testinternalcapi.restore_crossinterp_data(xid) - yield obj, got - - def assert_roundtrip_identical(self, values, *, mode=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - # XXX What about between interpreters? - self.assertIs(got, obj) + got = self._get_roundtrip(obj, mode) + self.assertIs(got, obj) def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - self.assertEqual(got, obj) - self.assertIs(type(got), - type(obj) if expecttype is None else expecttype) - -# def assert_roundtrip_equal_not_identical(self, values, *, -# mode=None, expecttype=None): -# mode = self._resolve_mode(mode) -# for obj in values: -# cls = type(obj) -# with self.subTest(obj): -# got = self._get_roundtrip(obj, mode) -# self.assertIsNot(got, obj) -# self.assertIs(type(got), type(obj)) -# self.assertEqual(got, obj) -# self.assertIs(type(got), -# cls if expecttype is None else expecttype) -# -# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None): -# mode = self._resolve_mode(mode) -# for obj in values: -# cls = type(obj) -# with self.subTest(obj): -# got = self._get_roundtrip(obj, mode) -# self.assertIsNot(got, obj) -# self.assertIs(type(got), type(obj)) -# self.assertNotEqual(got, obj) -# self.assertIs(type(got), -# cls if expecttype is None else expecttype) + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertEqual(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + + def assert_roundtrip_equal_not_identical(self, values, *, + mode=None, expecttype=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + 
self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + self.assertEqual(got, obj) + + def assert_roundtrip_not_equal(self, values, *, + mode=None, expecttype=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + self.assertNotEqual(got, obj) def assert_not_shareable(self, values, exctype=None, *, mode=None): mode = self._resolve_mode(mode) @@ -95,6 +177,363 @@ def _resolve_mode(self, mode): return mode +class PickleTests(_GetXIDataTests): + + MODE = 'pickle' + + def test_shareable(self): + self.assert_roundtrip_equal([ + # singletons + None, + True, + False, + # bytes + *(i.to_bytes(2, 'little', signed=True) + for i in range(-1, 258)), + # str + 'hello world', + '你好世界', + '', + # int + sys.maxsize, + -sys.maxsize - 1, + *range(-1, 258), + # float + 0.0, + 1.1, + -1.0, + 0.12345678, + -0.12345678, + # tuple + (), + (1,), + ("hello", "world", ), + (1, True, "hello"), + ((1,),), + ((1, 2), (3, 4)), + ((1, 2), (3, 4), (5, 6)), + ]) + # not shareable using xidata + self.assert_roundtrip_equal([ + # int + sys.maxsize + 1, + -sys.maxsize - 2, + 2**1000, + # tuple + (0, 1.0, []), + (0, 1.0, {}), + (0, 1.0, ([],)), + (0, 1.0, ({},)), + ]) + + def test_list(self): + self.assert_roundtrip_equal_not_identical([ + [], + [1, 2, 3], + [[1], (2,), {3: 4}], + ]) + + def test_dict(self): + self.assert_roundtrip_equal_not_identical([ + {}, + {1: 7, 2: 8, 3: 9}, + {1: [1], 2: (2,), 3: {3: 4}}, + ]) + + def test_set(self): + self.assert_roundtrip_equal_not_identical([ + set(), + {1, 2, 3}, + {frozenset({1}), (2,)}, + ]) + + # classes + + def assert_class_defs_same(self, defs): + # Unpickle relative to the unchanged original module. 
+ self.assert_roundtrip_identical(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_equal_not_identical(instances) + + # these don't compare equal + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls not in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_not_equal(instances) + + def assert_class_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. + for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. 
+ for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + with self.subTest(cls): + setattr(mod, cls.__name__, cls) + xid = self.get_xidata(cls) + inst = cls(*args) + instxid = self.get_xidata(inst) + instances.append( + (cls, xid, inst, instxid)) + + for cls, xid, inst, instxid in instances: + with self.subTest(cls): + delattr(mod, cls.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, cls) + self.assertNotEqual(got, cls) + + gotcls = got + got = _testinternalcapi.restore_crossinterp_data(instxid) + self.assertIsNot(got, inst) + self.assertIs(type(got), gotcls) + if cls in defs.CLASSES_WITHOUT_EQUALITY: + self.assertNotEqual(got, inst) + elif cls in defs.BUILTIN_SUBCLASSES: + self.assertEqual(got, inst) + else: + self.assertNotEqual(got, inst) + + def assert_class_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def test_user_class_normal(self): + self.assert_class_defs_same(defs) + + def test_user_class_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___unpickle_with_filename(self): + with 
using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_other_unpickle(defs, mod) + + def test_user_class_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_other_unpickle(defs, mod, fail=True) + + def test_user_class_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_class_defs_not_shareable(defs) + + def test_nested_class(self): + eggs = defs.EggsNested() + with self.assertRaises(NotShareableError): + self.get_roundtrip(eggs) + + # functions + + def assert_func_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. 
+ for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. + for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) + + captured = [] + for func in defs.TOP_FUNCTIONS: + with self.subTest(func): + setattr(mod, func.__name__, func) + xid = self.get_xidata(func) + captured.append( + (func, xid)) + + for func, xid in captured: + with self.subTest(func): + delattr(mod, func.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, func) + self.assertNotEqual(got, func) + + def assert_func_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def test_user_function_normal(self): +# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS) + self.assert_func_defs_same(defs) + + def test_user_func_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___unpickle_with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_other_unpickle(defs, mod) + + def 
test_user_func_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_other_unpickle(defs, mod, fail=True) + + def test_user_func_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_func_defs_not_shareable(defs) + + def test_nested_function(self): + self.assert_not_shareable(defs.NESTED_FUNCTIONS) + + # exceptions + + def test_user_exception_normal(self): + self.assert_roundtrip_not_equal([ + defs.MimimalError('error!'), + ]) + self.assert_roundtrip_equal_not_identical([ + defs.RichError('error!', 42), + ]) + + def test_builtin_exception(self): + msg = 'error!' 
+ try: + raise Exception + except Exception as exc: + caught = exc + special = { + BaseExceptionGroup: (msg, [caught]), + ExceptionGroup: (msg, [caught]), +# UnicodeError: (None, msg, None, None, None), + UnicodeEncodeError: ('utf-8', '', 1, 3, msg), + UnicodeDecodeError: ('utf-8', b'', 1, 3, msg), + UnicodeTranslateError: ('', 1, 3, msg), + } + exceptions = [] + for cls in EXCEPTION_TYPES: + args = special.get(cls) or (msg,) + exceptions.append(cls(*args)) + + self.assert_roundtrip_not_equal(exceptions) + + class MarshalTests(_GetXIDataTests): MODE = 'marshal' @@ -444,22 +883,12 @@ def test_module(self): ]) def test_class(self): - self.assert_not_shareable([ - defs.Spam, - defs.SpamOkay, - defs.SpamFull, - defs.SubSpamFull, - defs.SubTuple, - defs.EggsNested, - ]) - self.assert_not_shareable([ - defs.Spam(), - defs.SpamOkay(), - defs.SpamFull(1, 2, 3), - defs.SubSpamFull(1, 2, 3), - defs.SubTuple([1, 2, 3]), - defs.EggsNested(), - ]) + self.assert_not_shareable(defs.CLASSES) + + instances = [] + for cls, args in defs.CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) def test_builtin_type(self): self.assert_not_shareable([ diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index 4bfe88f2cf920c..812737e294fcb7 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -1939,6 +1939,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } } + else if (strcmp(mode, "pickle") == 0) { + if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) { + goto error; + } + } else if (strcmp(mode, "marshal") == 0) { if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) { goto error; diff --git a/Python/crossinterp.c b/Python/crossinterp.c index 753d784a503467..a9f9b78562917e 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -3,6 +3,7 @@ #include "Python.h" #include "marshal.h" // PyMarshal_WriteObjectToString() +#include "osdefs.h" // MAXPATHLEN #include 
"pycore_ceval.h" // _Py_simple_func #include "pycore_crossinterp.h" // _PyXIData_t #include "pycore_initconfig.h" // _PyStatus_OK() @@ -10,6 +11,155 @@ #include "pycore_typeobject.h" // _PyStaticType_InitBuiltin() +static Py_ssize_t +_Py_GetMainfile(char *buffer, size_t maxlen) +{ + // We don't expect subinterpreters to have the __main__ module's + // __name__ set, but proceed just in case. + PyThreadState *tstate = _PyThreadState_GET(); + PyObject *module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + return -1; + } + Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen); + Py_DECREF(module); + return size; +} + + +static PyObject * +import_get_module(PyThreadState *tstate, const char *modname) +{ + PyObject *module = NULL; + if (strcmp(modname, "__main__") == 0) { + module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + assert(_PyErr_Occurred(tstate)); + return NULL; + } + } + else { + module = PyImport_ImportModule(modname); + if (module == NULL) { + return NULL; + } + } + return module; +} + + +static PyObject * +runpy_run_path(const char *filename, const char *modname) +{ + PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path"); + if (run_path == NULL) { + return NULL; + } + PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname); + if (args == NULL) { + Py_DECREF(run_path); + return NULL; + } + PyObject *ns = PyObject_Call(run_path, args, NULL); + Py_DECREF(run_path); + Py_DECREF(args); + return ns; +} + + +static PyObject * +pyerr_get_message(PyObject *exc) +{ + assert(!PyErr_Occurred()); + PyObject *args = PyException_GetArgs(exc); + if (args == NULL || args == Py_None || PyObject_Size(args) < 1) { + return NULL; + } + if (PyUnicode_Check(args)) { + return args; + } + PyObject *msg = PySequence_GetItem(args, 0); + Py_DECREF(args); + if (msg == NULL) { + PyErr_Clear(); + return NULL; + } + if (!PyUnicode_Check(msg)) { + Py_DECREF(msg); + return NULL; + } + return 
msg; +} + +#define MAX_MODNAME (255) +#define MAX_ATTRNAME (255) + +struct attributeerror_info { + char modname[MAX_MODNAME+1]; + char attrname[MAX_ATTRNAME+1]; +}; + +static int +_parse_attributeerror(PyObject *exc, struct attributeerror_info *info) +{ + assert(exc != NULL); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + int res = -1; + + PyObject *msgobj = pyerr_get_message(exc); + if (msgobj == NULL) { + return -1; + } + const char *err = PyUnicode_AsUTF8(msgobj); + + if (strncmp(err, "module '", 8) != 0) { + goto finally; + } + err += 8; + + const char *matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + Py_ssize_t len = matched - err; + if (len > MAX_MODNAME) { + goto finally; + } + (void)strncpy(info->modname, err, len); + info->modname[len] = '\0'; + err = matched; + + if (strncmp(err, "' has no attribute '", 20) != 0) { + goto finally; + } + err += 20; + + matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + len = matched - err; + if (len > MAX_ATTRNAME) { + goto finally; + } + (void)strncpy(info->attrname, err, len); + info->attrname[len] = '\0'; + err = matched + 1; + + if (strlen(err) > 0) { + goto finally; + } + res = 0; + +finally: + Py_DECREF(msgobj); + return res; +} + +#undef MAX_MODNAME +#undef MAX_ATTRNAME + + /**************/ /* exceptions */ /**************/ @@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate, } +/* pickle C-API */ + +struct _pickle_context { + PyThreadState *tstate; +}; + +static PyObject * +_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj) +{ + PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps"); + if (dumps == NULL) { + return NULL; + } + PyObject *bytes = PyObject_CallOneArg(dumps, obj); + Py_DECREF(dumps); + return bytes; +} + + +struct sync_module_result { + PyObject *module; + PyObject *loaded; + PyObject *failed; +}; + +struct sync_module { + const char *filename; + char _filename[MAXPATHLEN+1]; + struct 
sync_module_result cached; +}; + +static void +sync_module_clear(struct sync_module *data) +{ + data->filename = NULL; + Py_CLEAR(data->cached.module); + Py_CLEAR(data->cached.loaded); + Py_CLEAR(data->cached.failed); +} + + +struct _unpickle_context { + PyThreadState *tstate; + // We only special-case the __main__ module, + // since other modules behave consistently. + struct sync_module main; +}; + +static void +_unpickle_context_clear(struct _unpickle_context *ctx) +{ + sync_module_clear(&ctx->main); +} + +static struct sync_module_result +_unpickle_context_get_module(struct _unpickle_context *ctx, + const char *modname) +{ + if (strcmp(modname, "__main__") == 0) { + return ctx->main.cached; + } + else { + return (struct sync_module_result){ + .failed = PyExc_NotImplementedError, + }; + } +} + +static struct sync_module_result +_unpickle_context_set_module(struct _unpickle_context *ctx, + const char *modname) +{ + struct sync_module_result res = {0}; + struct sync_module_result *cached = NULL; + const char *filename = NULL; + if (strcmp(modname, "__main__") == 0) { + cached = &ctx->main.cached; + filename = ctx->main.filename; + } + else { + res.failed = PyExc_NotImplementedError; + goto finally; + } + + res.module = import_get_module(ctx->tstate, modname); + if (res.module == NULL) { + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + + if (filename == NULL) { + Py_CLEAR(res.module); + res.failed = PyExc_NotImplementedError; + goto finally; + } + res.loaded = runpy_run_path(filename, modname); + if (res.loaded == NULL) { + Py_CLEAR(res.module); + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + +finally: + if (cached != NULL) { + assert(cached->module == NULL); + assert(cached->loaded == NULL); + assert(cached->failed == NULL); + *cached = res; + } + return res; +} + + +static int +_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject 
*exc) +{ + // The caller must check if an exception is set or not when -1 is returned. + assert(!_PyErr_Occurred(ctx->tstate)); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + struct attributeerror_info info; + if (_parse_attributeerror(exc, &info) < 0) { + return -1; + } + + // Get the module. + struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname); + if (mod.failed != NULL) { + // It must have failed previously. + return -1; + } + if (mod.module == NULL) { + mod = _unpickle_context_set_module(ctx, info.modname); + if (mod.failed != NULL) { + return -1; + } + assert(mod.module != NULL); + } + + // Bail out if it is unexpectedly set already. + if (PyObject_HasAttrString(mod.module, info.attrname)) { + return -1; + } + + // Try setting the attribute. + PyObject *value = NULL; + if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) { + return -1; + } + assert(value != NULL); + int res = PyObject_SetAttrString(mod.module, info.attrname, value); + Py_DECREF(value); + if (res < 0) { + return -1; + } + + return 0; +} + +static PyObject * +_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled) +{ + PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads"); + if (loads == NULL) { + return NULL; + } + PyObject *obj = PyObject_CallOneArg(loads, pickled); + if (ctx != NULL) { + while (obj == NULL) { + assert(_PyErr_Occurred(ctx->tstate)); + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + // We leave other failures unhandled. + break; + } + // Try setting the attr if not set. + PyObject *exc = _PyErr_GetRaisedException(ctx->tstate); + if (_handle_unpickle_missing_attr(ctx, exc) < 0) { + // Any resulting exceptions are ignored + // in favor of the original. + _PyErr_SetRaisedException(ctx->tstate, exc); + break; + } + Py_CLEAR(exc); + // Retry with the attribute set. 
+ obj = PyObject_CallOneArg(loads, pickled); + } + } + Py_DECREF(loads); + return obj; +} + + +/* pickle wrapper */ + +struct _pickle_xid_context { + // __main__.__file__ + struct { + const char *utf8; + size_t len; + char _utf8[MAXPATHLEN+1]; + } mainfile; +}; + +static int +_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx) +{ + // Set mainfile if possible. + Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN); + if (len < 0) { + // For now we ignore any exceptions. + PyErr_Clear(); + } + else if (len > 0) { + ctx->mainfile.utf8 = ctx->mainfile._utf8; + ctx->mainfile.len = (size_t)len; + } + + return 0; +} + + +struct _shared_pickle_data { + _PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData(). + struct _pickle_xid_context ctx; +}; + +PyObject * +_PyPickle_LoadFromXIData(_PyXIData_t *xidata) +{ + PyThreadState *tstate = _PyThreadState_GET(); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)xidata->data; + // We avoid copying the pickled data by wrapping it in a memoryview. + // The alternative is to get a bytes object using _PyBytes_FromXIData(). + PyObject *pickled = PyMemoryView_FromMemory( + (char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ); + if (pickled == NULL) { + return NULL; + } + + // Unpickle the object. + struct _unpickle_context ctx = { + .tstate = tstate, + .main = { + .filename = shared->ctx.mainfile.utf8, + }, + }; + PyObject *obj = _PyPickle_Loads(&ctx, pickled); + Py_DECREF(pickled); + _unpickle_context_clear(&ctx); + if (obj == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be unpickled", cause); + Py_DECREF(cause); + } + return obj; +} + + +int +_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) +{ + // Pickle the object. 
+ struct _pickle_context ctx = { + .tstate = tstate, + }; + PyObject *bytes = _PyPickle_Dumps(&ctx, obj); + if (bytes == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be pickled", cause); + Py_DECREF(cause); + return -1; + } + + // If we had an "unwrapper" mechnanism, we could call + // _PyObject_GetXIData() on the bytes object directly and add + // a simple unwrapper to call pickle.loads() on the bytes. + size_t size = sizeof(struct _shared_pickle_data); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped( + tstate, bytes, size, _PyPickle_LoadFromXIData, xidata); + Py_DECREF(bytes); + if (shared == NULL) { + return -1; + } + + // If it mattered, we could skip getting __main__.__file__ + // when "__main__" doesn't show up in the pickle bytes. + if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) { + _xidata_clear(xidata); + return -1; + } + + return 0; +} + + /* marshal wrapper */ PyObject * _______________________________________________ Python-checkins mailing list -- python-checkins@python.org To unsubscribe send an email to python-checkins-le...@python.org https://mail.python.org/mailman3/lists/python-checkins.python.org/ Member address: arch...@mail-archive.com