Author: Robert Zaremba <robert.zare...@scale-it.pl> Branch: py3.5-fix-decimal-module-name Changeset: r90424:793b49cacddd Date: 2017-02-28 15:32 +0100 http://bitbucket.org/pypy/pypy/changeset/793b49cacddd/
Log: (stevie, robert-zaremba) FIX: test_pickle (test.test_decimal.PyPythonAPItests) Fixes: http://buildbot.pypy.org/summary/longrepr?testname=unmodified&builder=pypy-c-jit-linux-x86-64&build=4406&mod=lib-python.3.test.test_decimal We removed the per-class `__module__ = 'decimal'` hack from the classes in lib_pypy/_decimal.py and instead set the module-level __name__ to 'decimal', so that pickled objects refer to the correct module name. diff --git a/extra_tests/support.py b/extra_tests/support.py new file mode 100644 --- /dev/null +++ b/extra_tests/support.py @@ -0,0 +1,98 @@ +import contextlib +import importlib +import sys +import warnings + + +@contextlib.contextmanager +def _ignore_deprecated_imports(ignore=True): + """Context manager to suppress package and module deprecation + warnings when importing them. + + If ignore is False, this context manager has no effect. + """ + if ignore: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".+ (module|package)", + DeprecationWarning) + yield + else: + yield + + +def _save_and_remove_module(name, orig_modules): + """Helper function to save and remove a module from sys.modules + + Raise ImportError if the module can't be imported. + """ + # try to import the module and raise an error if it can't be imported + if name not in sys.modules: + __import__(name) + del sys.modules[name] + for modname in list(sys.modules): + if modname == name or modname.startswith(name + '.'): + orig_modules[modname] = sys.modules[modname] + del sys.modules[modname] + +def _save_and_block_module(name, orig_modules): + """Helper function to save and block a module in sys.modules + + Return True if the module was in sys.modules, False otherwise. + """ + saved = True + try: + orig_modules[name] = sys.modules[name] + except KeyError: + saved = False + sys.modules[name] = None + return saved + + +def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): + """Import and return a module, deliberately bypassing sys.modules. 
+ + This function imports and returns a fresh copy of the named Python module + by removing the named module from sys.modules before doing the import. + Note that unlike reload, the original module is not affected by + this operation. + + *fresh* is an iterable of additional module names that are also removed + from the sys.modules cache before doing the import. + + *blocked* is an iterable of module names that are replaced with None + in the module cache during the import to ensure that attempts to import + them raise ImportError. + + The named module and any modules named in the *fresh* and *blocked* + parameters are saved before starting the import and then reinserted into + sys.modules when the fresh import is complete. + + Module and package deprecation messages are suppressed during this import + if *deprecated* is True. + + This function will raise ImportError if the named module cannot be + imported. + """ + # NOTE: test_heapq, test_json and test_warnings include extra sanity checks + # to make sure that this utility function is working as expected + with _ignore_deprecated_imports(deprecated): + # Keep track of modules saved for later restoration as well + # as those which just need a blocking entry removed + orig_modules = {} + names_to_remove = [] + _save_and_remove_module(name, orig_modules) + try: + for fresh_name in fresh: + _save_and_remove_module(fresh_name, orig_modules) + for blocked_name in blocked: + if not _save_and_block_module(blocked_name, orig_modules): + names_to_remove.append(blocked_name) + fresh_module = importlib.import_module(name) + except ImportError: + fresh_module = None + finally: + for orig_name, module in orig_modules.items(): + sys.modules[orig_name] = module + for name_to_remove in names_to_remove: + del sys.modules[name_to_remove] + return fresh_module diff --git a/extra_tests/test_decimal.py b/extra_tests/test_decimal.py new file mode 100644 --- /dev/null +++ b/extra_tests/test_decimal.py @@ -0,0 +1,59 @@ +import pickle 
+import sys + +from support import import_fresh_module + +C = import_fresh_module('decimal', fresh=['_decimal']) +P = import_fresh_module('decimal', blocked=['_decimal']) +# import _decimal as C +# import _pydecimal as P + + +class TestPythonAPI: + + def check_equal(self, val, proto): + d = C.Decimal(val) + p = pickle.dumps(d, proto) + assert d == pickle.loads(p) + + def test_C(self): + sys.modules["decimal"] = C + import decimal + d = decimal.Decimal('1') + assert isinstance(d, C.Decimal) + assert isinstance(d, decimal.Decimal) + assert isinstance(d.as_tuple(), C.DecimalTuple) + + assert d == C.Decimal('1') + + def test_pickle(self): + v = '-3.123e81723' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + sys.modules["decimal"] = C + self.check_equal('-3.141590000', proto) + self.check_equal(v, proto) + + cd = C.Decimal(v) + pd = P.Decimal(v) + cdt = cd.as_tuple() + pdt = pd.as_tuple() + assert cdt.__module__ == pdt.__module__ + + p = pickle.dumps(cdt, proto) + r = pickle.loads(p) + assert isinstance(r, C.DecimalTuple) + assert cdt == r + + sys.modules["decimal"] = C + p = pickle.dumps(cd, proto) + sys.modules["decimal"] = P + r = pickle.loads(p) + assert isinstance(r, P.Decimal) + assert r == pd + + sys.modules["decimal"] = C + p = pickle.dumps(cdt, proto) + sys.modules["decimal"] = P + r = pickle.loads(p) + assert isinstance(r, P.DecimalTuple) + assert r == pdt diff --git a/lib_pypy/_decimal.py b/lib_pypy/_decimal.py --- a/lib_pypy/_decimal.py +++ b/lib_pypy/_decimal.py @@ -1,5 +1,9 @@ # Implementation of the "decimal" module, based on libmpdec library. 
+__xname__ = __name__ # sys.modules lookup (--without-threads) +__name__ = 'decimal' # For pickling + + import collections as _collections import math as _math import numbers as _numbers @@ -23,15 +27,13 @@ # Errors class DecimalException(ArithmeticError): - __module__ = 'decimal' def handle(self, context, *args): pass class Clamped(DecimalException): - __module__ = 'decimal' + pass class InvalidOperation(DecimalException): - __module__ = 'decimal' def handle(self, context, *args): if args: ans = _dec_from_triple(args[0]._sign, args[0]._int, 'n', True) @@ -39,41 +41,35 @@ return _NaN class ConversionSyntax(InvalidOperation): - __module__ = 'decimal' def handle(self, context, *args): return _NaN class DivisionByZero(DecimalException, ZeroDivisionError): - __module__ = 'decimal' def handle(self, context, sign, *args): return _SignedInfinity[sign] class DivisionImpossible(InvalidOperation): - __module__ = 'decimal' def handle(self, context, *args): return _NaN class DivisionUndefined(InvalidOperation, ZeroDivisionError): - __module__ = 'decimal' def handle(self, context, *args): return _NaN class Inexact(DecimalException): - __module__ = 'decimal' + pass class InvalidContext(InvalidOperation): - __module__ = 'decimal' def handle(self, context, *args): return _NaN class Rounded(DecimalException): - __module__ = 'decimal' + pass class Subnormal(DecimalException): - __module__ = 'decimal' + pass class Overflow(Inexact, Rounded): - __module__ = 'decimal' def handle(self, context, sign, *args): if context.rounding in (ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_HALF_DOWN, ROUND_UP): @@ -90,10 +86,10 @@ context.Emax-context.prec+1) class Underflow(Inexact, Rounded, Subnormal): - __module__ = 'decimal' + pass class FloatOperation(DecimalException, TypeError): - __module__ = 'decimal' + pass __version__ = "1.70" @@ -107,7 +103,7 @@ def getcontext(): """Returns this thread's context. 
- + If this thread does not yet have a context, returns a new context and sets this thread's context. New contexts are copies of DefaultContext. @@ -173,8 +169,6 @@ _DEC_MINALLOC = 4 class Decimal(object): - __module__ = 'decimal' - __slots__ = ('_mpd', '_data') def __new__(cls, value="0", context=None): @@ -326,7 +320,7 @@ builder.append(b'E') builder.append(str(exponent).encode()) - return cls._from_bytes(b''.join(builder), context, exact=exact) + return cls._from_bytes(b''.join(builder), context, exact=exact) @classmethod def from_float(cls, value): @@ -481,7 +475,7 @@ numerator = Decimal._from_int(other.numerator, context) if not _mpdec.mpd_isspecial(self._mpd): # multiplied = self * other.denominator - # + # # Prevent Overflow in the following multiplication. # The result of the multiplication is # only used in mpd_qcmp, which can handle values that @@ -542,7 +536,7 @@ _mpdec.mpd_qset_ssize(p._mpd, self._PyHASH_MODULUS, maxctx, status_ptr) ten = self._new_empty() - _mpdec.mpd_qset_ssize(ten._mpd, 10, + _mpdec.mpd_qset_ssize(ten._mpd, 10, maxctx, status_ptr) inv10_p = self._new_empty() _mpdec.mpd_qset_ssize(inv10_p._mpd, self._PyHASH_10INV, @@ -755,7 +749,7 @@ number_class = _make_unary_operation('number_class') to_eng_string = _make_unary_operation('to_eng_string') - + def fma(self, other, third, context=None): context = _getcontext(context) return context.fma(self, other, third) @@ -790,7 +784,7 @@ result = int.from_bytes(s, 'little', signed=False) if _mpdec.mpd_isnegative(x._mpd) and not _mpdec.mpd_iszero(x._mpd): result = -result - return result + return result def __int__(self): return self._to_int(_mpdec.MPD_ROUND_DOWN) @@ -798,10 +792,10 @@ __trunc__ = __int__ def __floor__(self): - return self._to_int(_mpdec.MPD_ROUND_FLOOR) + return self._to_int(_mpdec.MPD_ROUND_FLOOR) def __ceil__(self): - return self._to_int(_mpdec.MPD_ROUND_CEILING) + return self._to_int(_mpdec.MPD_ROUND_CEILING) def to_integral(self, rounding=None, context=None): context = 
_getcontext(context) @@ -817,7 +811,7 @@ return result to_integral_value = to_integral - + def to_integral_exact(self, rounding=None, context=None): context = _getcontext(context) workctx = context.copy() @@ -886,7 +880,7 @@ if _mpdec.mpd_isspecial(self._mpd): return 0 return _mpdec.mpd_adjexp(self._mpd) - + @property def real(self): return self @@ -916,7 +910,7 @@ fmt = specifier.encode('utf-8') context = getcontext() - replace_fillchar = False + replace_fillchar = False if fmt and fmt[0] == 0: # NUL fill character: must be replaced with a valid UTF-8 char # before calling mpd_parse_fmt_str(). @@ -975,7 +969,7 @@ result = result.replace(b'\xff', b'\0') return result.decode('utf-8') - + # Register Decimal as a kind of Number (an abstract base class). # However, do not register it as Real (because Decimals are not # interoperable with floats). @@ -988,7 +982,7 @@ # Rounding _ROUNDINGS = { - 'ROUND_DOWN': _mpdec.MPD_ROUND_DOWN, + 'ROUND_DOWN': _mpdec.MPD_ROUND_DOWN, 'ROUND_HALF_UP': _mpdec.MPD_ROUND_HALF_UP, 'ROUND_HALF_EVEN': _mpdec.MPD_ROUND_HALF_EVEN, 'ROUND_CEILING': _mpdec.MPD_ROUND_CEILING, @@ -1047,8 +1041,6 @@ clamp - If 1, change exponents if too high (Default 0) """ - __module__ = 'decimal' - __slots__ = ('_ctx', '_capitals') def __new__(cls, prec=None, rounding=None, Emin=None, Emax=None, @@ -1068,7 +1060,7 @@ ctx.round = _mpdec.MPD_ROUND_HALF_EVEN ctx.clamp = 0 ctx.allcr = 1 - + self._capitals = 1 return self @@ -1291,7 +1283,7 @@ if b is NotImplemented: return b, b return a, b - + def _make_unary_method(name, mpd_func_name): mpd_func = getattr(_mpdec, mpd_func_name) @@ -1570,7 +1562,7 @@ def copy(self): return self._as_dict() - + def __len__(self): return len(_SIGNALS) @@ -1629,7 +1621,7 @@ def __enter__(self): self.status_ptr = _ffi.new("uint32_t*") return self.context._ctx, self.status_ptr - + def __exit__(self, *args): status = self.status_ptr[0] # May raise a DecimalException _______________________________________________ pypy-commit mailing list 
pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit