Author: Carl Friedrich Bolz-Tereick <cfb...@gmx.de>
Branch:
Changeset: r95895:7a4d0769c63d
Date: 2019-02-08 11:01 +0100
http://bitbucket.org/pypy/pypy/changeset/7a4d0769c63d/
Log: merge math-improvements diff too long, truncating to 2000 out of 2276 lines diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst --- a/pypy/doc/whatsnew-head.rst +++ b/pypy/doc/whatsnew-head.rst @@ -13,3 +13,9 @@ The zlib module's compressobj and decompressobj now expose copy methods as they do on CPython. + + +.. math-improvements + +Improve performance of long operations where one of the operands fits into +an int. \ No newline at end of file diff --git a/pypy/objspace/std/intobject.py b/pypy/objspace/std/intobject.py --- a/pypy/objspace/std/intobject.py +++ b/pypy/objspace/std/intobject.py @@ -299,7 +299,7 @@ return ix -def _pow_ovf2long(space, iv, iw, w_modulus): +def _pow_ovf2long(space, iv, w_iv, iw, w_iw, w_modulus): if space.is_none(w_modulus) and _recover_with_smalllong(space): from pypy.objspace.std.smalllongobject import _pow as _pow_small try: @@ -308,9 +308,12 @@ return _pow_small(space, r_longlong(iv), iw, r_longlong(0)) except (OverflowError, ValueError): pass - from pypy.objspace.std.longobject import W_LongObject - w_iv = W_LongObject.fromint(space, iv) - w_iw = W_LongObject.fromint(space, iw) + from pypy.objspace.std.longobject import W_LongObject, W_AbstractLongObject + if w_iv is None or not isinstance(w_iv, W_AbstractLongObject): + w_iv = W_LongObject.fromint(space, iv) + if w_iw is None or not isinstance(w_iw, W_AbstractLongObject): + w_iw = W_LongObject.fromint(space, iw) + return w_iv.descr_pow(space, w_iw, w_modulus) @@ -318,7 +321,7 @@ op = getattr(operator, opname, None) assert op or ovf2small - def ovf2long(space, x, y): + def ovf2long(space, x, w_x, y, w_y): """Handle overflowing to smalllong or long""" if _recover_with_smalllong(space): if ovf2small: @@ -330,9 +333,12 @@ b = r_longlong(y) return W_SmallLongObject(op(a, b)) - from pypy.objspace.std.longobject import W_LongObject - w_x = W_LongObject.fromint(space, x) - w_y = W_LongObject.fromint(space, y) + from pypy.objspace.std.longobject import W_LongObject, 
W_AbstractLongObject + if w_x is None or not isinstance(w_x, W_AbstractLongObject): + w_x = W_LongObject.fromint(space, x) + if w_y is None or not isinstance(w_y, W_AbstractLongObject): + w_y = W_LongObject.fromint(space, y) + return getattr(w_x, 'descr_' + opname)(space, w_y) return ovf2long @@ -496,12 +502,18 @@ # can't return NotImplemented (space.pow doesn't do full # ternary, i.e. w_modulus.__zpow__(self, w_exponent)), so # handle it ourselves - return _pow_ovf2long(space, x, y, w_modulus) + return _pow_ovf2long(space, x, self, y, w_exponent, w_modulus) try: result = _pow(space, x, y, z) - except (OverflowError, ValueError): - return _pow_ovf2long(space, x, y, w_modulus) + except OverflowError: + return _pow_ovf2long(space, x, self, y, w_exponent, w_modulus) + except ValueError: + # float result, so let avoid a roundtrip in rbigint. + self = self.descr_float(space) + w_exponent = w_exponent.descr_float(space) + return space.pow(self, w_exponent, space.w_None) + return space.newint(result) @unwrap_spec(w_modulus=WrappedDefault(None)) @@ -546,7 +558,7 @@ try: z = ovfcheck(op(x, y)) except OverflowError: - return ovf2long(space, x, y) + return ovf2long(space, x, self, y, w_other) else: z = op(x, y) return wrapint(space, z) @@ -568,7 +580,7 @@ try: z = ovfcheck(op(y, x)) except OverflowError: - return ovf2long(space, y, x) + return ovf2long(space, y, w_other, x, self) # XXX write a test else: z = op(y, x) return wrapint(space, z) @@ -599,7 +611,7 @@ try: return func(space, x, y) except OverflowError: - return ovf2long(space, x, y) + return ovf2long(space, x, self, y, w_other) else: return func(space, x, y) @@ -614,7 +626,7 @@ try: return func(space, y, x) except OverflowError: - return ovf2long(space, y, x) + return ovf2long(space, y, w_other, x, self) else: return func(space, y, x) diff --git a/pypy/objspace/std/longobject.py b/pypy/objspace/std/longobject.py --- a/pypy/objspace/std/longobject.py +++ b/pypy/objspace/std/longobject.py @@ -308,28 +308,47 @@ 
@unwrap_spec(w_modulus=WrappedDefault(None)) def descr_pow(self, space, w_exponent, w_modulus=None): + exp_int = 0 + exp_bigint = None + sign = 0 + if isinstance(w_exponent, W_AbstractIntObject): - w_exponent = w_exponent.descr_long(space) + exp_int = w_exponent.int_w(space) + if exp_int > 0: + sign = 1 + elif exp_int < 0: + sign = -1 elif not isinstance(w_exponent, W_AbstractLongObject): return space.w_NotImplemented + else: + exp_bigint = w_exponent.asbigint() + sign = exp_bigint.sign if space.is_none(w_modulus): - if w_exponent.asbigint().sign < 0: + if sign < 0: self = self.descr_float(space) w_exponent = w_exponent.descr_float(space) return space.pow(self, w_exponent, space.w_None) - return W_LongObject(self.num.pow(w_exponent.asbigint())) + if not exp_bigint: + return W_LongObject(self.num.int_pow(exp_int)) + else: + return W_LongObject(self.num.pow(exp_bigint)) + elif isinstance(w_modulus, W_AbstractIntObject): w_modulus = w_modulus.descr_long(space) + elif not isinstance(w_modulus, W_AbstractLongObject): return space.w_NotImplemented - if w_exponent.asbigint().sign < 0: + if sign < 0: raise oefmt(space.w_TypeError, "pow() 2nd argument cannot be negative when 3rd " "argument specified") try: - result = self.num.pow(w_exponent.asbigint(), w_modulus.asbigint()) + if not exp_bigint: + result = self.num.int_pow(exp_int, w_modulus.asbigint()) + else: + result = self.num.pow(exp_bigint, w_modulus.asbigint()) except ValueError: raise oefmt(space.w_ValueError, "pow 3rd argument cannot be 0") return W_LongObject(result) @@ -372,22 +391,16 @@ descr_gt = _make_descr_cmp('gt') descr_ge = _make_descr_cmp('ge') - def _make_generic_descr_binop_noncommutative(opname): - methname = opname + '_' if opname in ('and', 'or') else opname - descr_rname = 'descr_r' + opname - op = getattr(rbigint, methname) + def descr_sub(self, space, w_other): + if isinstance(w_other, W_AbstractIntObject): + return W_LongObject(self.num.int_sub(w_other.int_w(space))) + elif not 
isinstance(w_other, W_AbstractLongObject): + return space.w_NotImplemented + return W_LongObject(self.num.sub(w_other.asbigint())) - @func_renamer('descr_' + opname) - @delegate_other - def descr_binop(self, space, w_other): - return W_LongObject(op(self.num, w_other.asbigint())) - - @func_renamer(descr_rname) - @delegate_other - def descr_rbinop(self, space, w_other): - return W_LongObject(op(w_other.asbigint(), self.num)) - - return descr_binop, descr_rbinop + @delegate_other + def descr_rsub(self, space, w_other): + return W_LongObject(w_other.asbigint().sub(self.num)) def _make_generic_descr_binop(opname): if opname not in COMMUTATIVE_OPS: @@ -419,28 +432,23 @@ return descr_binop, descr_rbinop descr_add, descr_radd = _make_generic_descr_binop('add') - descr_sub, descr_rsub = _make_generic_descr_binop_noncommutative('sub') + descr_mul, descr_rmul = _make_generic_descr_binop('mul') descr_and, descr_rand = _make_generic_descr_binop('and') descr_or, descr_ror = _make_generic_descr_binop('or') descr_xor, descr_rxor = _make_generic_descr_binop('xor') - def _make_descr_binop(func, int_func=None): + def _make_descr_binop(func, int_func): opname = func.__name__[1:] - if int_func: - @func_renamer('descr_' + opname) - def descr_binop(self, space, w_other): - if isinstance(w_other, W_AbstractIntObject): - return int_func(self, space, w_other.int_w(space)) - elif not isinstance(w_other, W_AbstractLongObject): - return space.w_NotImplemented - return func(self, space, w_other) - else: - @delegate_other - @func_renamer('descr_' + opname) - def descr_binop(self, space, w_other): - return func(self, space, w_other) + @func_renamer('descr_' + opname) + def descr_binop(self, space, w_other): + if isinstance(w_other, W_AbstractIntObject): + return int_func(self, space, w_other.int_w(space)) + elif not isinstance(w_other, W_AbstractLongObject): + return space.w_NotImplemented + return func(self, space, w_other) + @delegate_other @func_renamer('descr_r' + opname) def 
descr_rbinop(self, space, w_other): @@ -460,10 +468,10 @@ raise oefmt(space.w_OverflowError, "shift count too large") return W_LongObject(self.num.lshift(shift)) - def _int_lshift(self, space, w_other): - if w_other < 0: + def _int_lshift(self, space, other): + if other < 0: raise oefmt(space.w_ValueError, "negative shift count") - return W_LongObject(self.num.lshift(w_other)) + return W_LongObject(self.num.lshift(other)) descr_lshift, descr_rlshift = _make_descr_binop(_lshift, _int_lshift) @@ -476,11 +484,11 @@ raise oefmt(space.w_OverflowError, "shift count too large") return newlong(space, self.num.rshift(shift)) - def _int_rshift(self, space, w_other): - if w_other < 0: + def _int_rshift(self, space, other): + if other < 0: raise oefmt(space.w_ValueError, "negative shift count") - return newlong(space, self.num.rshift(w_other)) + return newlong(space, self.num.rshift(other)) descr_rshift, descr_rrshift = _make_descr_binop(_rshift, _int_rshift) def _floordiv(self, space, w_other): @@ -491,17 +499,18 @@ "long division or modulo by zero") return newlong(space, z) - def _floordiv(self, space, w_other): + def _int_floordiv(self, space, other): try: - z = self.num.floordiv(w_other.asbigint()) + z = self.num.int_floordiv(other) except ZeroDivisionError: raise oefmt(space.w_ZeroDivisionError, "long division or modulo by zero") return newlong(space, z) - descr_floordiv, descr_rfloordiv = _make_descr_binop(_floordiv) + descr_floordiv, descr_rfloordiv = _make_descr_binop(_floordiv, _int_floordiv) _div = func_with_new_name(_floordiv, '_div') - descr_div, descr_rdiv = _make_descr_binop(_div) + _int_div = func_with_new_name(_int_floordiv, '_int_div') + descr_div, descr_rdiv = _make_descr_binop(_div, _int_div) def _mod(self, space, w_other): try: @@ -511,9 +520,9 @@ "long division or modulo by zero") return newlong(space, z) - def _int_mod(self, space, w_other): + def _int_mod(self, space, other): try: - z = self.num.int_mod(w_other) + z = self.num.int_mod(other) except 
ZeroDivisionError: raise oefmt(space.w_ZeroDivisionError, "long division or modulo by zero") @@ -527,7 +536,16 @@ raise oefmt(space.w_ZeroDivisionError, "long division or modulo by zero") return space.newtuple([newlong(space, div), newlong(space, mod)]) - descr_divmod, descr_rdivmod = _make_descr_binop(_divmod) + + def _int_divmod(self, space, other): + try: + div, mod = self.num.int_divmod(other) + except ZeroDivisionError: + raise oefmt(space.w_ZeroDivisionError, + "long division or modulo by zero") + return space.newtuple([newlong(space, div), newlong(space, mod)]) + + descr_divmod, descr_rdivmod = _make_descr_binop(_divmod, _int_divmod) def newlong(space, bigint): diff --git a/pypy/objspace/std/test/test_intobject.py b/pypy/objspace/std/test/test_intobject.py --- a/pypy/objspace/std/test/test_intobject.py +++ b/pypy/objspace/std/test/test_intobject.py @@ -679,6 +679,11 @@ x = int(321) assert x.__rlshift__(333) == 1422567365923326114875084456308921708325401211889530744784729710809598337369906606315292749899759616L + def test_some_rops(self): + import sys + x = int(-sys.maxint) + assert x.__rsub__(2) == (2 + sys.maxint) + class AppTestIntShortcut(AppTestInt): spaceconfig = {"objspace.std.intshortcut": True} diff --git a/rpython/rlib/rarithmetic.py b/rpython/rlib/rarithmetic.py --- a/rpython/rlib/rarithmetic.py +++ b/rpython/rlib/rarithmetic.py @@ -612,6 +612,7 @@ r_ulonglong = build_int('r_ulonglong', False, 64) r_longlonglong = build_int('r_longlonglong', True, 128) +r_ulonglonglong = build_int('r_ulonglonglong', False, 128) longlongmax = r_longlong(LONGLONG_TEST - 1) if r_longlong is not r_int: diff --git a/rpython/rlib/rbigint.py b/rpython/rlib/rbigint.py --- a/rpython/rlib/rbigint.py +++ b/rpython/rlib/rbigint.py @@ -27,6 +27,7 @@ else: UDIGIT_MASK = longlongmask LONG_TYPE = rffi.__INT128_T + ULONG_TYPE = rffi.__UINT128_T if LONG_BIT > SHIFT: STORE_TYPE = lltype.Signed UNSIGNED_TYPE = lltype.Unsigned @@ -40,6 +41,7 @@ STORE_TYPE = lltype.Signed UNSIGNED_TYPE 
= lltype.Unsigned LONG_TYPE = rffi.LONGLONG + ULONG_TYPE = rffi.ULONGLONG MASK = int((1 << SHIFT) - 1) FLOAT_MULTIPLIER = float(1 << SHIFT) @@ -97,6 +99,9 @@ def _widen_digit(x): return rffi.cast(LONG_TYPE, x) +def _unsigned_widen_digit(x): + return rffi.cast(ULONG_TYPE, x) + @specialize.argtype(0) def _store_digit(x): return rffi.cast(STORE_TYPE, x) @@ -108,6 +113,7 @@ NULLDIGIT = _store_digit(0) ONEDIGIT = _store_digit(1) +NULLDIGITS = [NULLDIGIT] def _check_digits(l): for x in l: @@ -133,22 +139,26 @@ def specialize_call(self, hop): hop.exception_cannot_occur() +def intsign(i): + return -1 if i < 0 else 1 class rbigint(object): """This is a reimplementation of longs using a list of digits.""" _immutable_ = True - _immutable_fields_ = ["_digits"] - - def __init__(self, digits=[NULLDIGIT], sign=0, size=0): + _immutable_fields_ = ["_digits[*]", "size", "sign"] + + def __init__(self, digits=NULLDIGITS, sign=0, size=0): if not we_are_translated(): _check_digits(digits) make_sure_not_resized(digits) self._digits = digits + assert size >= 0 self.size = size or len(digits) + self.sign = sign - # __eq__ and __ne__ method exist for testingl only, they are not RPython! + # __eq__ and __ne__ method exist for testing only, they are not RPython! 
@not_rpython def __eq__(self, other): if not isinstance(other, rbigint): @@ -159,6 +169,7 @@ def __ne__(self, other): return not (self == other) + @specialize.argtype(1) def digit(self, x): """Return the x'th digit, as an int.""" return self._digits[x] @@ -170,6 +181,12 @@ return _widen_digit(self._digits[x]) widedigit._always_inline_ = True + def uwidedigit(self, x): + """Return the x'th digit, as a long long int if needed + to have enough room to contain two digits.""" + return _unsigned_widen_digit(self._digits[x]) + uwidedigit._always_inline_ = True + def udigit(self, x): """Return the x'th digit, as an unsigned int.""" return _load_unsigned_digit(self._digits[x]) @@ -183,7 +200,9 @@ setdigit._always_inline_ = True def numdigits(self): - return self.size + w = self.size + assert w > 0 + return w numdigits._always_inline_ = True @staticmethod @@ -196,13 +215,15 @@ if intval < 0: sign = -1 ival = -r_uint(intval) + carry = ival >> SHIFT elif intval > 0: sign = 1 ival = r_uint(intval) + carry = 0 else: return NULLRBIGINT - carry = ival >> SHIFT + if carry: return rbigint([_store_digit(ival & MASK), _store_digit(carry)], sign, 2) @@ -509,23 +530,22 @@ return True @jit.elidable - def int_eq(self, other): + def int_eq(self, iother): """ eq with int """ - - if not int_in_valid_range(other): - # Fallback to Long. - return self.eq(rbigint.fromint(other)) + if not int_in_valid_range(iother): + # Fallback to Long. 
+ return self.eq(rbigint.fromint(iother)) if self.numdigits() > 1: return False - return (self.sign * self.digit(0)) == other + return (self.sign * self.digit(0)) == iother def ne(self, other): return not self.eq(other) - def int_ne(self, other): - return not self.int_eq(other) + def int_ne(self, iother): + return not self.int_eq(iother) @jit.elidable def lt(self, other): @@ -563,59 +583,38 @@ return False @jit.elidable - def int_lt(self, other): + def int_lt(self, iother): """ lt where other is an int """ - if not int_in_valid_range(other): + if not int_in_valid_range(iother): # Fallback to Long. - return self.lt(rbigint.fromint(other)) - - osign = 1 - if other == 0: - osign = 0 - elif other < 0: - osign = -1 - - if self.sign > osign: - return False - elif self.sign < osign: - return True - - digits = self.numdigits() - - if digits > 1: - if osign == 1: - return False - else: - return True - - d1 = self.sign * self.digit(0) - if d1 < other: - return True - return False + return self.lt(rbigint.fromint(iother)) + + return _x_int_lt(self, iother, False) def le(self, other): return not other.lt(self) - def int_le(self, other): - # Alternative that might be faster, reimplant this. as a check with other + 1. But we got to check for overflow - # or reduce valid range. - - if self.int_eq(other): - return True - return self.int_lt(other) + def int_le(self, iother): + """ le where iother is an int """ + + if not int_in_valid_range(iother): + # Fallback to Long. 
+ return self.le(rbigint.fromint(iother)) + + return _x_int_lt(self, iother, True) def gt(self, other): return other.lt(self) - def int_gt(self, other): - return not self.int_le(other) + def int_gt(self, iother): + return not self.int_le(iother) def ge(self, other): return not self.lt(other) - def int_ge(self, other): - return not self.int_lt(other) + def int_ge(self, iother): + return not self.int_lt(iother) @jit.elidable def hash(self): @@ -635,20 +634,20 @@ return result @jit.elidable - def int_add(self, other): - if not int_in_valid_range(other): + def int_add(self, iother): + if not int_in_valid_range(iother): # Fallback to long. - return self.add(rbigint.fromint(other)) + return self.add(rbigint.fromint(iother)) elif self.sign == 0: - return rbigint.fromint(other) - elif other == 0: + return rbigint.fromint(iother) + elif iother == 0: return self - sign = -1 if other < 0 else 1 + sign = intsign(iother) if self.sign == sign: - result = _x_int_add(self, other) + result = _x_int_add(self, iother) else: - result = _x_int_sub(self, other) + result = _x_int_sub(self, iother) result.sign *= -1 result.sign *= sign return result @@ -658,7 +657,7 @@ if other.sign == 0: return self elif self.sign == 0: - return rbigint(other._digits[:other.size], -other.sign, other.size) + return rbigint(other._digits[:other.numdigits()], -other.sign, other.numdigits()) elif self.sign == other.sign: result = _x_sub(self, other) else: @@ -667,93 +666,94 @@ return result @jit.elidable - def int_sub(self, other): - if not int_in_valid_range(other): + def int_sub(self, iother): + if not int_in_valid_range(iother): # Fallback to long. 
- return self.sub(rbigint.fromint(other)) - elif other == 0: + return self.sub(rbigint.fromint(iother)) + elif iother == 0: return self elif self.sign == 0: - return rbigint.fromint(-other) - elif self.sign == (-1 if other < 0 else 1): - result = _x_int_sub(self, other) + return rbigint.fromint(-iother) + elif self.sign == intsign(iother): + result = _x_int_sub(self, iother) else: - result = _x_int_add(self, other) + result = _x_int_add(self, iother) result.sign *= self.sign return result @jit.elidable - def mul(self, b): - asize = self.numdigits() - bsize = b.numdigits() - - a = self - - if asize > bsize: - a, b, asize, bsize = b, a, bsize, asize - - if a.sign == 0 or b.sign == 0: + def mul(self, other): + selfsize = self.numdigits() + othersize = other.numdigits() + + if selfsize > othersize: + self, other, selfsize, othersize = other, self, othersize, selfsize + + if self.sign == 0 or other.sign == 0: return NULLRBIGINT - if asize == 1: - if a._digits[0] == ONEDIGIT: - return rbigint(b._digits[:b.size], a.sign * b.sign, b.size) - elif bsize == 1: - res = b.widedigit(0) * a.widedigit(0) + if selfsize == 1: + if self._digits[0] == ONEDIGIT: + return rbigint(other._digits[:othersize], self.sign * other.sign, othersize) + elif othersize == 1: + res = other.uwidedigit(0) * self.udigit(0) carry = res >> SHIFT if carry: - return rbigint([_store_digit(res & MASK), _store_digit(carry)], a.sign * b.sign, 2) + return rbigint([_store_digit(res & MASK), _store_digit(carry)], self.sign * other.sign, 2) else: - return rbigint([_store_digit(res & MASK)], a.sign * b.sign, 1) - - result = _x_mul(a, b, a.digit(0)) + return rbigint([_store_digit(res & MASK)], self.sign * other.sign, 1) + + result = _x_mul(self, other, self.digit(0)) elif USE_KARATSUBA: - if a is b: + if self is other: i = KARATSUBA_SQUARE_CUTOFF else: i = KARATSUBA_CUTOFF - if asize <= i: - result = _x_mul(a, b) - """elif 2 * asize <= bsize: - result = _k_lopsided_mul(a, b)""" + if selfsize <= i: + result = 
_x_mul(self, other) + """elif 2 * selfsize <= othersize: + result = _k_lopsided_mul(self, other)""" else: - result = _k_mul(a, b) + result = _k_mul(self, other) else: - result = _x_mul(a, b) - - result.sign = a.sign * b.sign + result = _x_mul(self, other) + + result.sign = self.sign * other.sign return result @jit.elidable - def int_mul(self, b): - if not int_in_valid_range(b): + def int_mul(self, iother): + if not int_in_valid_range(iother): # Fallback to long. - return self.mul(rbigint.fromint(b)) - - if self.sign == 0 or b == 0: + return self.mul(rbigint.fromint(iother)) + + if self.sign == 0 or iother == 0: return NULLRBIGINT asize = self.numdigits() - digit = abs(b) - bsign = -1 if b < 0 else 1 + digit = abs(iother) + + othersign = intsign(iother) if digit == 1: - return rbigint(self._digits[:self.size], self.sign * bsign, asize) + if othersign == 1: + return self + return rbigint(self._digits[:asize], self.sign * othersign, asize) elif asize == 1: - res = self.widedigit(0) * digit + udigit = r_uint(digit) + res = self.uwidedigit(0) * udigit carry = res >> SHIFT if carry: - return rbigint([_store_digit(res & MASK), _store_digit(carry)], self.sign * bsign, 2) + return rbigint([_store_digit(res & MASK), _store_digit(carry)], self.sign * othersign, 2) else: - return rbigint([_store_digit(res & MASK)], self.sign * bsign, 1) - + return rbigint([_store_digit(res & MASK)], self.sign * othersign, 1) elif digit & (digit - 1) == 0: result = self.lqshift(ptwotable[digit]) else: result = _muladd1(self, digit) - result.sign = self.sign * bsign + result.sign = self.sign * othersign return result @jit.elidable @@ -763,12 +763,10 @@ @jit.elidable def floordiv(self, other): - if self.sign == 1 and other.numdigits() == 1 and other.sign == 1: - digit = other.digit(0) - if digit == 1: - return rbigint(self._digits[:self.size], 1, self.size) - elif digit and digit & (digit - 1) == 0: - return self.rshift(ptwotable[digit]) + if other.numdigits() == 1: + otherint = other.digit(0) * 
other.sign + assert int_in_valid_range(otherint) + return self.int_floordiv(otherint) div, mod = _divrem(self, other) if mod.sign * other.sign == -1: @@ -782,6 +780,37 @@ return self.floordiv(other) @jit.elidable + def int_floordiv(self, iother): + if not int_in_valid_range(iother): + # Fallback to long. + return self.floordiv(rbigint.fromint(iother)) + + if iother == 0: + raise ZeroDivisionError("long division by zero") + + digit = abs(iother) + assert digit > 0 + + if self.sign == 1 and iother > 0: + if digit == 1: + return self + elif digit & (digit - 1) == 0: + return self.rqshift(ptwotable[digit]) + + div, mod = _divrem1(self, digit) + + if mod != 0 and self.sign * intsign(iother) == -1: + if div.sign == 0: + return ONENEGATIVERBIGINT + div = div.int_add(1) + div.sign = self.sign * intsign(iother) + div._normalize() + return div + + def int_div(self, iother): + return self.int_floordiv(iother) + + @jit.elidable def mod(self, other): if other.sign == 0: raise ZeroDivisionError("long division or modulo by zero") @@ -799,50 +828,50 @@ return mod @jit.elidable - def int_mod(self, other): - if other == 0: + def int_mod(self, iother): + if iother == 0: raise ZeroDivisionError("long division or modulo by zero") if self.sign == 0: return NULLRBIGINT - elif not int_in_valid_range(other): + elif not int_in_valid_range(iother): # Fallback to long. 
- return self.mod(rbigint.fromint(other)) + return self.mod(rbigint.fromint(iother)) if 1: # preserve indentation to preserve history - digit = abs(other) + digit = abs(iother) if digit == 1: return NULLRBIGINT elif digit == 2: modm = self.digit(0) & 1 if modm: - return ONENEGATIVERBIGINT if other < 0 else ONERBIGINT + return ONENEGATIVERBIGINT if iother < 0 else ONERBIGINT return NULLRBIGINT elif digit & (digit - 1) == 0: mod = self.int_and_(digit - 1) else: # Perform - size = self.numdigits() - 1 + size = UDIGIT_TYPE(self.numdigits() - 1) if size > 0: - rem = self.widedigit(size) - size -= 1 - while size >= 0: - rem = ((rem << SHIFT) + self.widedigit(size)) % digit + wrem = self.widedigit(size) + while size > 0: size -= 1 + wrem = ((wrem << SHIFT) | self.digit(size)) % digit + rem = _store_digit(wrem) else: - rem = self.digit(0) % digit + rem = _store_digit(self.digit(0) % digit) if rem == 0: return NULLRBIGINT - mod = rbigint([_store_digit(rem)], -1 if self.sign < 0 else 1, 1) - - if mod.sign * (-1 if other < 0 else 1) == -1: - mod = mod.int_add(other) + mod = rbigint([rem], -1 if self.sign < 0 else 1, 1) + + if mod.sign * intsign(iother) == -1: + mod = mod.int_add(iother) return mod @jit.elidable - def divmod(v, w): + def divmod(self, other): """ The / and % operators are now defined in terms of divmod(). The expression a mod b has the value a - b*floor(a/b). @@ -859,46 +888,78 @@ have different signs. We then subtract one from the 'div' part of the outcome to keep the invariant intact. 
""" - div, mod = _divrem(v, w) - if mod.sign * w.sign == -1: - mod = mod.add(w) + div, mod = _divrem(self, other) + if mod.sign * other.sign == -1: + mod = mod.add(other) if div.sign == 0: return ONENEGATIVERBIGINT, mod div = div.int_sub(1) return div, mod @jit.elidable - def pow(a, b, c=None): + def int_divmod(self, iother): + """ Divmod with int """ + + if iother == 0: + raise ZeroDivisionError("long division or modulo by zero") + + wsign = intsign(iother) + if not int_in_valid_range(iother) or (wsign == -1 and self.sign != wsign): + # Just fallback. + return self.divmod(rbigint.fromint(iother)) + + digit = abs(iother) + assert digit > 0 + + div, mod = _divrem1(self, digit) + # _divrem1 doesn't fix the sign + if div.size == 1 and div._digits[0] == NULLDIGIT: + div.sign = 0 + else: + div.sign = self.sign * wsign + if self.sign < 0: + mod = -mod + if mod and self.sign * wsign == -1: + mod += iother + if div.sign == 0: + div = ONENEGATIVERBIGINT + else: + div = div.int_sub(1) + mod = rbigint.fromint(mod) + return div, mod + + @jit.elidable + def pow(self, other, modulus=None): negativeOutput = False # if x<0 return negative output # 5-ary values. If the exponent is large enough, table is - # precomputed so that table[i] == a**i % c for i in range(32). + # precomputed so that table[i] == self**i % modulus for i in range(32). # python translation: the table is computed when needed. 
- if b.sign < 0: # if exponent is negative - if c is not None: + if other.sign < 0: # if exponent is negative + if modulus is not None: raise TypeError( "pow() 2nd argument " "cannot be negative when 3rd argument specified") # XXX failed to implement raise ValueError("bigint pow() too negative") - size_b = b.numdigits() - - if c is not None: - if c.sign == 0: + size_b = UDIGIT_TYPE(other.numdigits()) + + if modulus is not None: + if modulus.sign == 0: raise ValueError("pow() 3rd argument cannot be 0") # if modulus < 0: # negativeOutput = True # modulus = -modulus - if c.sign < 0: + if modulus.sign < 0: negativeOutput = True - c = c.neg() + modulus = modulus.neg() # if modulus == 1: # return 0 - if c.numdigits() == 1 and c._digits[0] == ONEDIGIT: + if modulus.numdigits() == 1 and modulus._digits[0] == ONEDIGIT: return NULLRBIGINT # Reduce base by modulus in some cases: @@ -910,63 +971,61 @@ # base % modulus instead. # We could _always_ do this reduction, but mod() isn't cheap, # so we only do it when it buys something. 
- if a.sign < 0 or a.numdigits() > c.numdigits(): - a = a.mod(c) - - elif b.sign == 0: + if self.sign < 0 or self.numdigits() > modulus.numdigits(): + self = self.mod(modulus) + elif other.sign == 0: return ONERBIGINT - elif a.sign == 0: + elif self.sign == 0: return NULLRBIGINT elif size_b == 1: - if b._digits[0] == NULLDIGIT: - return ONERBIGINT if a.sign == 1 else ONENEGATIVERBIGINT - elif b._digits[0] == ONEDIGIT: - return a - elif a.numdigits() == 1: - adigit = a.digit(0) - digit = b.digit(0) + if other._digits[0] == ONEDIGIT: + return self + elif self.numdigits() == 1 and modulus is None: + adigit = self.digit(0) + digit = other.digit(0) if adigit == 1: - if a.sign == -1 and digit % 2: + if self.sign == -1 and digit % 2: return ONENEGATIVERBIGINT return ONERBIGINT elif adigit & (adigit - 1) == 0: - ret = a.lshift(((digit-1)*(ptwotable[adigit]-1)) + digit-1) - if a.sign == -1 and not digit % 2: + ret = self.lshift(((digit-1)*(ptwotable[adigit]-1)) + digit-1) + if self.sign == -1 and not digit % 2: ret.sign = 1 return ret - # At this point a, b, and c are guaranteed non-negative UNLESS - # c is NULL, in which case a may be negative. */ - - z = rbigint([ONEDIGIT], 1, 1) + # At this point self, other, and modulus are guaranteed non-negative UNLESS + # modulus is NULL, in which case self may be negative. 
*/ + + z = ONERBIGINT # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result) # into helper function result = _help_mult(x, y, c) if size_b <= FIVEARY_CUTOFF: # Left-to-right binary exponentiation (HAC Algorithm 14.79) # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf - size_b -= 1 - while size_b >= 0: - bi = b.digit(size_b) + + while size_b > 0: + size_b -= 1 + bi = other.digit(size_b) j = 1 << (SHIFT-1) while j != 0: - z = _help_mult(z, z, c) + z = _help_mult(z, z, modulus) if bi & j: - z = _help_mult(z, a, c) + z = _help_mult(z, self, modulus) j >>= 1 - size_b -= 1 + else: # Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) - # This is only useful in the case where c != None. + # This is only useful in the case where modulus != None. # z still holds 1L table = [z] * 32 table[0] = z for i in range(1, 32): - table[i] = _help_mult(table[i-1], a, c) + table[i] = _help_mult(table[i-1], self, modulus) # Note that here SHIFT is not a multiple of 5. The difficulty - # is to extract 5 bits at a time from 'b', starting from the + # is to extract 5 bits at a time from 'other', starting from the # most significant digits, so that at the end of the algorithm # it falls exactly to zero. # m = max number of bits = i * SHIFT @@ -985,37 +1044,120 @@ index = (accum >> j) & 0x1f else: # 'accum' does not have enough digit. 
- # must get the next digit from 'b' in order to complete + # must get the next digit from 'other' in order to complete if size_b == 0: break # Done size_b -= 1 assert size_b >= 0 - bi = b.udigit(size_b) + bi = other.udigit(size_b) index = ((accum << (-j)) | (bi >> (j+SHIFT))) & 0x1f accum = bi j += SHIFT # for k in range(5): - z = _help_mult(z, z, c) + z = _help_mult(z, z, modulus) if index: - z = _help_mult(z, table[index], c) + z = _help_mult(z, table[index], modulus) # assert j == -5 if negativeOutput and z.sign != 0: - z = z.sub(c) + z = z.sub(modulus) + return z + + @jit.elidable + def int_pow(self, iother, modulus=None): + negativeOutput = False # if x<0 return negative output + + # 5-ary values. If the exponent is large enough, table is + # precomputed so that table[i] == self**i % modulus for i in range(32). + # python translation: the table is computed when needed. + + if iother < 0: # if exponent is negative + if modulus is not None: + raise TypeError( + "pow() 2nd argument " + "cannot be negative when 3rd argument specified") + # XXX failed to implement + raise ValueError("bigint pow() too negative") + + assert iother >= 0 + if modulus is not None: + if modulus.sign == 0: + raise ValueError("pow() 3rd argument cannot be 0") + + # if modulus < 0: + # negativeOutput = True + # modulus = -modulus + if modulus.sign < 0: + negativeOutput = True + modulus = modulus.neg() + + # if modulus == 1: + # return 0 + if modulus.numdigits() == 1 and modulus._digits[0] == ONEDIGIT: + return NULLRBIGINT + + # Reduce base by modulus in some cases: + # 1. If base < 0. Forcing the base non-neg makes things easier. + # 2. If base is obviously larger than the modulus. The "small + # exponent" case later can multiply directly by base repeatedly, + # while the "large exponent" case multiplies directly by base 31 + # times. It can be unboundedly faster to multiply by + # base % modulus instead. 
+ # We could _always_ do this reduction, but mod() isn't cheap, + # so we only do it when it buys something. + if self.sign < 0 or self.numdigits() > modulus.numdigits(): + self = self.mod(modulus) + elif iother == 0: + return ONERBIGINT + elif self.sign == 0: + return NULLRBIGINT + elif iother == 1: + return self + elif self.numdigits() == 1: + adigit = self.digit(0) + if adigit == 1: + if self.sign == -1 and iother % 2: + return ONENEGATIVERBIGINT + return ONERBIGINT + elif adigit & (adigit - 1) == 0: + ret = self.lshift(((iother-1)*(ptwotable[adigit]-1)) + iother-1) + if self.sign == -1 and not iother % 2: + ret.sign = 1 + return ret + + # At this point self, iother, and modulus are guaranteed non-negative UNLESS + # modulus is NULL, in which case self may be negative. */ + + z = ONERBIGINT + + # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result) + # into helper function result = _help_mult(x, y, modulus) + # Left-to-right binary exponentiation (HAC Algorithm 14.79) + # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf + j = 1 << (SHIFT-1) + + while j != 0: + z = _help_mult(z, z, modulus) + if iother & j: + z = _help_mult(z, self, modulus) + j >>= 1 + + if negativeOutput and z.sign != 0: + z = z.sub(modulus) return z @jit.elidable def neg(self): - return rbigint(self._digits, -self.sign, self.size) + return rbigint(self._digits, -self.sign, self.numdigits()) @jit.elidable def abs(self): if self.sign != -1: return self - return rbigint(self._digits, 1, self.size) + return rbigint(self._digits, 1, self.numdigits()) @jit.elidable def invert(self): #Implement ~x as -(x + 1) @@ -1041,15 +1183,15 @@ # So we can avoid problems with eq, AND avoid the need for normalize. 
if self.sign == 0: return self - return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.size + wordshift) + return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.numdigits() + wordshift) oldsize = self.numdigits() newsize = oldsize + wordshift + 1 z = rbigint([NULLDIGIT] * newsize, self.sign, newsize) - accum = _widen_digit(0) + accum = _unsigned_widen_digit(0) j = 0 while j < oldsize: - accum += self.widedigit(j) << remshift + accum += self.uwidedigit(j) << remshift z.setdigit(wordshift, accum) accum >>= SHIFT wordshift += 1 @@ -1061,7 +1203,7 @@ z._normalize() return z - lshift._always_inline_ = True # It's so fast that it's always benefitial. + lshift._always_inline_ = True # It's so fast that it's always beneficial. @jit.elidable def lqshift(self, int_other): @@ -1071,17 +1213,17 @@ oldsize = self.numdigits() z = rbigint([NULLDIGIT] * (oldsize + 1), self.sign, (oldsize + 1)) - accum = _widen_digit(0) + accum = _unsigned_widen_digit(0) i = 0 while i < oldsize: - accum += self.widedigit(i) << int_other + accum += self.uwidedigit(i) << int_other z.setdigit(i, accum) accum >>= SHIFT i += 1 z.setdigit(oldsize, accum) z._normalize() return z - lqshift._always_inline_ = True # It's so fast that it's always benefitial. + lqshift._always_inline_ = True # It's so fast that it's always beneficial. @jit.elidable def rshift(self, int_other, dont_invert=False): @@ -1112,6 +1254,31 @@ z._normalize() return z rshift._always_inline_ = 'try' # It's so fast that it's always benefitial. 
+ + @jit.elidable + def rqshift(self, int_other): + wordshift = int_other / SHIFT + loshift = int_other % SHIFT + newsize = self.numdigits() - wordshift + + if newsize <= 0: + return NULLRBIGINT + + hishift = SHIFT - loshift + z = rbigint([NULLDIGIT] * newsize, self.sign, newsize) + i = 0 + + while i < newsize: + digit = self.udigit(wordshift) + newdigit = (digit >> loshift) + if i+1 < newsize: + newdigit |= (self.udigit(wordshift+1) << hishift) + z.setdigit(i, newdigit) + i += 1 + wordshift += 1 + z._normalize() + return z + rshift._always_inline_ = 'try' # It's so fast that it's always beneficial. @jit.elidable def abs_rshift_and_mask(self, bigshiftcount, mask): @@ -1167,24 +1334,24 @@ return _bitwise(self, '&', other) @jit.elidable - def int_and_(self, other): - return _int_bitwise(self, '&', other) + def int_and_(self, iother): + return _int_bitwise(self, '&', iother) @jit.elidable def xor(self, other): return _bitwise(self, '^', other) @jit.elidable - def int_xor(self, other): - return _int_bitwise(self, '^', other) + def int_xor(self, iother): + return _int_bitwise(self, '^', iother) @jit.elidable def or_(self, other): return _bitwise(self, '|', other) @jit.elidable - def int_or_(self, other): - return _int_bitwise(self, '|', other) + def int_or_(self, iother): + return _int_bitwise(self, '|', iother) @jit.elidable def oct(self): @@ -1218,7 +1385,10 @@ for d in digits: l = l << SHIFT l += intmask(d) - return l * self.sign + result = l * self.sign + if result == 0: + assert self.sign == 0 + return result def _normalize(self): i = self.numdigits() @@ -1227,11 +1397,10 @@ i -= 1 assert i > 0 - if i != self.numdigits(): - self.size = i - if self.numdigits() == 1 and self._digits[0] == NULLDIGIT: + self.size = i + if i == 1 and self._digits[0] == NULLDIGIT: self.sign = 0 - self._digits = [NULLDIGIT] + self._digits = NULLDIGITS _normalize._always_inline_ = True @@ -1256,8 +1425,8 @@ def __repr__(self): return "<rbigint digits=%s, sign=%s, size=%d, len=%d, %s>" % 
(self._digits, - self.sign, self.size, len(self._digits), - self.str()) + self.sign, self.numdigits(), len(self._digits), + self.tolong()) ONERBIGINT = rbigint([ONEDIGIT], 1, 1) ONENEGATIVERBIGINT = rbigint([ONEDIGIT], -1, 1) @@ -1322,7 +1491,7 @@ if x > 0: return digits_from_nonneg_long(x), 1 elif x == 0: - return [NULLDIGIT], 0 + return NULLDIGITS, 0 elif x != most_neg_value_of_same_type(x): # normal case return digits_from_nonneg_long(-x), -1 @@ -1340,7 +1509,7 @@ def args_from_long(x): if x >= 0: if x == 0: - return [NULLDIGIT], 0 + return NULLDIGITS, 0 else: return digits_from_nonneg_long(x), 1 else: @@ -1450,7 +1619,7 @@ if adigit == bdigit: return NULLRBIGINT - + return rbigint.fromint(adigit - bdigit) z = rbigint([NULLDIGIT] * size_a, 1, size_a) @@ -1497,11 +1666,11 @@ z = rbigint([NULLDIGIT] * (size_a + size_b), 1) i = UDIGIT_TYPE(0) while i < size_a: - f = a.widedigit(i) + f = a.uwidedigit(i) pz = i << 1 pa = i + 1 - carry = z.widedigit(pz) + f * f + carry = z.uwidedigit(pz) + f * f z.setdigit(pz, carry) pz += 1 carry >>= SHIFT @@ -1511,18 +1680,18 @@ # pyramid it appears. Same as adding f<<1 once. 
f <<= 1 while pa < size_a: - carry += z.widedigit(pz) + a.widedigit(pa) * f + carry += z.uwidedigit(pz) + a.uwidedigit(pa) * f pa += 1 z.setdigit(pz, carry) pz += 1 carry >>= SHIFT if carry: - carry += z.widedigit(pz) + carry += z.udigit(pz) z.setdigit(pz, carry) pz += 1 carry >>= SHIFT if carry: - z.setdigit(pz, z.widedigit(pz) + carry) + z.setdigit(pz, z.udigit(pz) + carry) assert (carry >> SHIFT) == 0 i += 1 z._normalize() @@ -1543,29 +1712,29 @@ size_a1 = UDIGIT_TYPE(size_a - 1) size_b1 = UDIGIT_TYPE(size_b - 1) while i < size_a1: - f0 = a.widedigit(i) - f1 = a.widedigit(i + 1) + f0 = a.uwidedigit(i) + f1 = a.uwidedigit(i + 1) pz = i - carry = z.widedigit(pz) + b.widedigit(0) * f0 + carry = z.uwidedigit(pz) + b.uwidedigit(0) * f0 z.setdigit(pz, carry) pz += 1 carry >>= SHIFT j = UDIGIT_TYPE(0) while j < size_b1: - # this operation does not overflow using + # this operation does not overflow using # SHIFT = (LONG_BIT // 2) - 1 = B - 1; in fact before it # carry and z.widedigit(pz) are less than 2**(B - 1); # b.widedigit(j + 1) * f0 < (2**(B-1) - 1)**2; so # carry + z.widedigit(pz) + b.widedigit(j + 1) * f0 + # b.widedigit(j) * f1 < 2**(2*B - 1) - 2**B < 2**LONG)BIT - 1 - carry += z.widedigit(pz) + b.widedigit(j + 1) * f0 + \ - b.widedigit(j) * f1 + carry += z.uwidedigit(pz) + b.uwidedigit(j + 1) * f0 + \ + b.uwidedigit(j) * f1 z.setdigit(pz, carry) pz += 1 carry >>= SHIFT j += 1 # carry < 2**(B + 1) - 2 - carry += z.widedigit(pz) + b.widedigit(size_b1) * f1 + carry += z.uwidedigit(pz) + b.uwidedigit(size_b1) * f1 z.setdigit(pz, carry) pz += 1 carry >>= SHIFT @@ -1576,17 +1745,17 @@ i += 2 if size_a & 1: pz = size_a1 - f = a.widedigit(pz) + f = a.uwidedigit(pz) pb = 0 - carry = _widen_digit(0) + carry = _unsigned_widen_digit(0) while pb < size_b: - carry += z.widedigit(pz) + b.widedigit(pb) * f + carry += z.uwidedigit(pz) + b.uwidedigit(pb) * f pb += 1 z.setdigit(pz, carry) pz += 1 carry >>= SHIFT if carry: - z.setdigit(pz, z.widedigit(pz) + carry) + 
z.setdigit(pz, z.udigit(pz) + carry) z._normalize() return z @@ -1602,8 +1771,8 @@ size_lo = min(size_n, size) # We use "or" her to avoid having a check where list can be empty in _normalize. - lo = rbigint(n._digits[:size_lo] or [NULLDIGIT], 1) - hi = rbigint(n._digits[size_lo:n.size] or [NULLDIGIT], 1) + lo = rbigint(n._digits[:size_lo] or NULLDIGITS, 1) + hi = rbigint(n._digits[size_lo:size_n] or NULLDIGITS, 1) lo._normalize() hi._normalize() return hi, lo @@ -1708,113 +1877,16 @@ ret._normalize() return ret -""" (*) Why adding t3 can't "run out of room" above. - -Let f(x) mean the floor of x and c(x) mean the ceiling of x. Some facts -to start with: - -1. For any integer i, i = c(i/2) + f(i/2). In particular, - bsize = c(bsize/2) + f(bsize/2). -2. shift = f(bsize/2) -3. asize <= bsize -4. Since we call k_lopsided_mul if asize*2 <= bsize, asize*2 > bsize in this - routine, so asize > bsize/2 >= f(bsize/2) in this routine. - -We allocated asize + bsize result digits, and add t3 into them at an offset -of shift. This leaves asize+bsize-shift allocated digit positions for t3 -to fit into, = (by #1 and #2) asize + f(bsize/2) + c(bsize/2) - f(bsize/2) = -asize + c(bsize/2) available digit positions. - -bh has c(bsize/2) digits, and bl at most f(size/2) digits. So bh+hl has -at most c(bsize/2) digits + 1 bit. - -If asize == bsize, ah has c(bsize/2) digits, else ah has at most f(bsize/2) -digits, and al has at most f(bsize/2) digits in any case. So ah+al has at -most (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 1 bit. - -The product (ah+al)*(bh+bl) therefore has at most - - c(bsize/2) + (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits - -and we have asize + c(bsize/2) available digit positions. We need to show -this is always enough. An instance of c(bsize/2) cancels out in both, so -the question reduces to whether asize digits is enough to hold -(asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits. 
If asize < bsize, -then we're asking whether asize digits >= f(bsize/2) digits + 2 bits. By #4, -asize is at least f(bsize/2)+1 digits, so this in turn reduces to whether 1 -digit is enough to hold 2 bits. This is so since SHIFT=15 >= 2. If -asize == bsize, then we're asking whether bsize digits is enough to hold -c(bsize/2) digits + 2 bits, or equivalently (by #1) whether f(bsize/2) digits -is enough to hold 2 bits. This is so if bsize >= 2, which holds because -bsize >= KARATSUBA_CUTOFF >= 2. - -Note that since there's always enough room for (ah+al)*(bh+bl), and that's -clearly >= each of ah*bh and al*bl, there's always enough room to subtract -ah*bh and al*bl too. -""" - -def _k_lopsided_mul(a, b): - # Not in use anymore, only account for like 1% performance. Perhaps if we - # Got rid of the extra list allocation this would be more effective. - """ - b has at least twice the digits of a, and a is big enough that Karatsuba - would pay off *if* the inputs had balanced sizes. View b as a sequence - of slices, each with a->ob_size digits, and multiply the slices by a, - one at a time. This gives k_mul balanced inputs to work with, and is - also cache-friendly (we compute one double-width slice of the result - at a time, then move on, never bactracking except for the helpful - single-width slice overlap between successive partial sums). - """ - asize = a.numdigits() - bsize = b.numdigits() - # nbdone is # of b digits already multiplied - - assert asize > KARATSUBA_CUTOFF - assert 2 * asize <= bsize - - # Allocate result space, and zero it out. - ret = rbigint([NULLDIGIT] * (asize + bsize), 1) - - # Successive slices of b are copied into bslice. - #bslice = rbigint([0] * asize, 1) - # XXX we cannot pre-allocate, see comments below! - # XXX prevent one list from being created. - bslice = rbigint(sign=1) - - nbdone = 0 - while bsize > 0: - nbtouse = min(bsize, asize) - - # Multiply the next slice of b by a. 
- - #bslice.digits[:nbtouse] = b.digits[nbdone : nbdone + nbtouse] - # XXX: this would be more efficient if we adopted CPython's - # way to store the size, instead of resizing the list! - # XXX change the implementation, encoding length via the sign. - bslice._digits = b._digits[nbdone : nbdone + nbtouse] - bslice.size = nbtouse - product = _k_mul(a, bslice) - - # Add into result. - _v_iadd(ret, nbdone, ret.numdigits() - nbdone, - product, product.numdigits()) - - bsize -= nbtouse - nbdone += nbtouse - - ret._normalize() - return ret - def _inplace_divrem1(pout, pin, n): """ Divide bigint pin by non-zero digit n, storing quotient in pout, and returning the remainder. It's OK for pin == pout on entry. """ - rem = _widen_digit(0) + rem = _unsigned_widen_digit(0) assert n > 0 and n <= MASK size = pin.numdigits() - 1 while size >= 0: - rem = (rem << SHIFT) | pin.widedigit(size) + rem = (rem << SHIFT) | pin.udigit(size) hi = rem // n pout.setdigit(size, hi) rem -= hi * n @@ -1891,14 +1963,15 @@ def _muladd1(a, n, extra=0): """Multiply by a single digit and add a single digit, ignoring the sign. """ + assert n > 0 size_a = a.numdigits() z = rbigint([NULLDIGIT] * (size_a+1), 1) assert extra & MASK == extra - carry = _widen_digit(extra) + carry = _unsigned_widen_digit(extra) i = 0 while i < size_a: - carry += a.widedigit(i) * n + carry += a.uwidedigit(i) * n z.setdigit(i, carry) carry >>= SHIFT i += 1 @@ -1912,10 +1985,10 @@ """ carry = 0 - assert 0 <= d and d < SHIFT + #assert 0 <= d and d < SHIFT i = 0 while i < m: - acc = a.widedigit(i) << d | carry + acc = a.uwidedigit(i) << d | carry z.setdigit(i, acc) carry = acc >> SHIFT i += 1 @@ -1927,14 +2000,14 @@ * result in z[0:m], and return the d bits shifted out of the bottom. 
""" - carry = _widen_digit(0) - acc = _widen_digit(0) + carry = _unsigned_widen_digit(0) + acc = _unsigned_widen_digit(0) mask = (1 << d) - 1 - assert 0 <= d and d < SHIFT + #assert 0 <= d and d < SHIFT i = m-1 while i >= 0: - acc = (carry << SHIFT) | a.widedigit(i) + acc = (carry << SHIFT) | a.udigit(i) carry = acc & mask z.setdigit(i, acc >> d) i -= 1 @@ -1989,10 +2062,17 @@ else: vtop = v.widedigit(j) assert vtop <= wm1 + vv = (vtop << SHIFT) | v.widedigit(abs(j-1)) + + # Hints to make division just as fast as doing it unsigned. But avoids casting to get correct results. + assert vv >= 0 + assert wm1 >= 1 + q = vv / wm1 - r = vv - wm1 * q - while wm2 * q > ((r << SHIFT) | v.widedigit(abs(j-2))): + r = vv % wm1 # This seems to be slightly faster on widen digits than vv - wm1 * q. + vj2 = v.digit(abs(j-2)) + while wm2 * q > ((r << SHIFT) | vj2): q -= 1 r += wm1 @@ -2059,6 +2139,36 @@ rem.sign = - rem.sign return z, rem +def _x_int_lt(a, b, eq=False): + """ Compare bigint a with int b for less than or less than or equal """ + osign = 1 + if b == 0: + osign = 0 + elif b < 0: + osign = -1 + + if a.sign > osign: + return False + elif a.sign < osign: + return True + + digits = a.numdigits() + + if digits > 1: + if osign == 1: + return False + else: + return True + + d1 = a.sign * a.digit(0) + if eq: + if d1 <= b: + return True + else: + if d1 < b: + return True + return False + # ______________ conversions to double _______________ def _AsScaledDouble(v): @@ -2764,7 +2874,7 @@ elif s[p] == '+': p += 1 - a = rbigint() + a = NULLRBIGINT tens = 1 dig = 0 ord0 = ord('0') @@ -2785,7 +2895,7 @@ base = parser.base if (base & (base - 1)) == 0 and base >= 2: return parse_string_from_binary_base(parser) - a = rbigint() + a = NULLRBIGINT digitmax = BASE_MAX[base] tens, dig = 1, 0 while True: diff --git a/rpython/rlib/test/test_rbigint.py b/rpython/rlib/test/test_rbigint.py --- a/rpython/rlib/test/test_rbigint.py +++ b/rpython/rlib/test/test_rbigint.py @@ -95,6 +95,46 @@ r2 = op1 
// op2 assert r1.tolong() == r2 + def test_int_floordiv(self): + x = 1000L + r = rbigint.fromlong(x) + r2 = r.int_floordiv(10) + assert r2.tolong() == 100L + + for op1 in gen_signs(long_vals): + for op2 in signed_int_vals: + if not op2: + continue + rl_op1 = rbigint.fromlong(op1) + r1 = rl_op1.int_floordiv(op2) + r2 = op1 // op2 + assert r1.tolong() == r2 + + assert pytest.raises(ZeroDivisionError, r.int_floordiv, 0) + + # Error pointed out by Armin Rigo + n = sys.maxint+1 + r = rbigint.fromlong(n) + assert r.int_floordiv(int(-n)).tolong() == -1L + + for x in int_vals: + if not x: + continue + r = rbigint.fromlong(x) + rn = rbigint.fromlong(-x) + res = r.int_floordiv(x) + res2 = r.int_floordiv(-x) + res3 = rn.int_floordiv(x) + assert res.tolong() == 1L + assert res2.tolong() == -1L + assert res3.tolong() == -1L + + def test_floordiv2(self): + n1 = rbigint.fromlong(sys.maxint + 1) + n2 = rbigint.fromlong(-(sys.maxint + 1)) + assert n1.floordiv(n2).tolong() == -1L + assert n2.floordiv(n1).tolong() == -1L + def test_truediv(self): for op1 in gen_signs(long_vals_not_too_big): rl_op1 = rbigint.fromlong(op1) @@ -185,9 +225,26 @@ r4 = pow(op1, op2, op3) assert r3.tolong() == r4 + def test_int_pow(self): + for op1 in gen_signs(long_vals_not_too_big): + rl_op1 = rbigint.fromlong(op1) + for op2 in [0, 1, 2, 8, 9, 10, 11, 127, 128, 129]: + r1 = rl_op1.int_pow(op2) + r2 = op1 ** op2 + assert r1.tolong() == r2 + + for op3 in gen_signs(long_vals_not_too_big): + if not op3: + continue + r3 = rl_op1.int_pow(op2, rbigint.fromlong(op3)) + r4 = pow(op1, op2, op3) + print op1, op2, op3 + assert r3.tolong() == r4 + def test_pow_raises(self): r1 = rbigint.fromint(2) r0 = rbigint.fromint(0) + py.test.raises(ValueError, r1.int_pow, 2, r0) py.test.raises(ValueError, r1.pow, r1, r0) def test_touint(self): @@ -601,6 +658,9 @@ # test special optimization case in rshift: assert rbigint.fromlong(-(1 << 100)).rshift(5).tolong() == -(1 << 100) >> 5 + # Chek value accuracy. 
+ assert rbigint.fromlong(18446744073709551615L).rshift(1).tolong() == 18446744073709551615L >> 1 + def test_qshift(self): for x in range(10): for y in range(1, 161, 16): @@ -610,11 +670,18 @@ for z in range(1, 31): res1 = f1.lqshift(z).tolong() + res2 = f1.rqshift(z).tolong() res3 = nf1.lqshift(z).tolong() assert res1 == num << z + assert res2 == num >> z assert res3 == -num << z + # Large digit + for x in range((1 << SHIFT) - 10, (1 << SHIFT) + 10): + f1 = rbigint.fromlong(x) + assert f1.rqshift(SHIFT).tolong() == x >> SHIFT + assert f1.rqshift(SHIFT+1).tolong() == x >> (SHIFT+1) def test_from_list_n_bits(self): for x in ([3L ** 30L, 5L ** 20L, 7 ** 300] + @@ -864,6 +931,27 @@ assert rem.tolong() == _rem + def test_int_divmod(self): + for x in long_vals: + for y in int_vals + [-sys.maxint-1]: + if not y: + continue + for sx, sy in (1, 1), (1, -1), (-1, -1), (-1, 1): + sx *= x + sy *= y + if sy == sys.maxint + 1: + continue + f1 = rbigint.fromlong(sx) + div, rem = f1.int_divmod(sy) + div1, rem1 = f1.divmod(rbigint.fromlong(sy)) + _div, _rem = divmod(sx, sy) + print sx, sy, " | ", div.tolong(), rem.tolong() + assert div1.tolong() == _div + assert rem1.tolong() == _rem + assert div.tolong() == _div + assert rem.tolong() == _rem + py.test.raises(ZeroDivisionError, rbigint.fromlong(x).int_divmod, 0) + # testing Karatsuba stuff def test__v_iadd(self): f1 = bigint([lobj.MASK] * 10, 1) @@ -1067,8 +1155,14 @@ except Exception as e: pytest.raises(type(e), f1.pow, f2, f3) else: - v = f1.pow(f2, f3) - assert v.tolong() == res + v1 = f1.pow(f2, f3) + try: + v2 = f1.int_pow(f2.toint(), f3) + except OverflowError: + pass + else: + assert v2.tolong() == res + assert v1.tolong() == res @given(biglongs, biglongs) @example(510439143470502793407446782273075179618477362188870662225920, @@ -1088,6 +1182,18 @@ a, b = f1.divmod(f2) assert (a.tolong(), b.tolong()) == res + @given(biglongs, ints) + def test_int_divmod(self, x, iy): + f1 = rbigint.fromlong(x) + try: + res = divmod(x, iy) + 
except Exception as e: + pytest.raises(type(e), f1.int_divmod, iy) + else: + print x, iy + a, b = f1.int_divmod(iy) + assert (a.tolong(), b.tolong()) == res + @given(longs) def test_hash(self, x): # hash of large integers: should be equal to the hash of the @@ -1118,10 +1224,34 @@ assert ra.truediv(rb) == a / b @given(longs, longs) - def test_bitwise(self, x, y): + def test_bitwise_and_mul(self, x, y): lx = rbigint.fromlong(x) ly = rbigint.fromlong(y) - for mod in "xor and_ or_".split(): - res1 = getattr(lx, mod)(ly).tolong() + for mod in "xor and_ or_ mul".split(): + res1a = getattr(lx, mod)(ly).tolong() + res1b = getattr(ly, mod)(lx).tolong() + res2 = getattr(operator, mod)(x, y) + assert res1a == res2 + + @given(longs, ints) + def test_int_bitwise_and_mul(self, x, y): + lx = rbigint.fromlong(x) + for mod in "xor and_ or_ mul".split(): + res1 = getattr(lx, 'int_' + mod)(y).tolong() res2 = getattr(operator, mod)(x, y) assert res1 == res2 + + @given(longs, ints) + def test_int_comparison(self, x, y): + lx = rbigint.fromlong(x) + assert lx.int_lt(y) == (x < y) + assert lx.int_eq(y) == (x == y) + assert lx.int_le(y) == (x <= y) + + @given(longs, longs) + def test_int_comparison(self, x, y): + lx = rbigint.fromlong(x) + ly = rbigint.fromlong(y) + assert lx.lt(ly) == (x < y) + assert lx.eq(ly) == (x == y) + assert lx.le(ly) == (x <= y) diff --git a/rpython/rtyper/lltypesystem/ll2ctypes.py b/rpython/rtyper/lltypesystem/ll2ctypes.py --- a/rpython/rtyper/lltypesystem/ll2ctypes.py +++ b/rpython/rtyper/lltypesystem/ll2ctypes.py @@ -175,7 +175,16 @@ if res >= (1 << 127): res -= 1 << 128 return res + class c_uint128(ctypes.Array): # based on 2 ulongs + _type_ = ctypes.c_uint64 + _length_ = 2 + @property + def value(self): + res = self[0] | (self[1] << 64) + return res + _ctypes_cache[rffi.__INT128_T] = c_int128 + _ctypes_cache[rffi.__UINT128_T] = c_uint128 # for unicode strings, do not use ctypes.c_wchar because ctypes # automatically converts arrays into unicode strings. 
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py --- a/rpython/rtyper/lltypesystem/lloperation.py +++ b/rpython/rtyper/lltypesystem/lloperation.py @@ -324,6 +324,26 @@ 'lllong_rshift': LLOp(canfold=True), # args (r_longlonglong, int) 'lllong_xor': LLOp(canfold=True), + 'ulllong_is_true': LLOp(canfold=True), + 'ulllong_invert': LLOp(canfold=True), + + 'ulllong_add': LLOp(canfold=True), + 'ulllong_sub': LLOp(canfold=True), + 'ulllong_mul': LLOp(canfold=True), + 'ulllong_floordiv': LLOp(canfold=True), + 'ulllong_mod': LLOp(canfold=True), + 'ulllong_lt': LLOp(canfold=True), + 'ulllong_le': LLOp(canfold=True), + 'ulllong_eq': LLOp(canfold=True), + 'ulllong_ne': LLOp(canfold=True), + 'ulllong_gt': LLOp(canfold=True), + 'ulllong_ge': LLOp(canfold=True), + 'ulllong_and': LLOp(canfold=True), + 'ulllong_or': LLOp(canfold=True), + 'ulllong_lshift': LLOp(canfold=True), # args (r_ulonglonglong, int) + 'ulllong_rshift': LLOp(canfold=True), # args (r_ulonglonglong, int) + 'ulllong_xor': LLOp(canfold=True), + 'cast_primitive': LLOp(canfold=True), 'cast_bool_to_int': LLOp(canfold=True), 'cast_bool_to_uint': LLOp(canfold=True), diff --git a/rpython/rtyper/lltypesystem/lltype.py b/rpython/rtyper/lltypesystem/lltype.py --- a/rpython/rtyper/lltypesystem/lltype.py +++ b/rpython/rtyper/lltypesystem/lltype.py @@ -8,7 +8,7 @@ from rpython.rlib.rarithmetic import ( base_int, intmask, is_emulated_long, is_valid_int, longlonglongmask, longlongmask, maxint, normalizedinttype, r_int, r_longfloat, r_longlong, - r_longlonglong, r_singlefloat, r_uint, r_ulonglong) + r_longlonglong, r_singlefloat, r_uint, r_ulonglong, r_ulonglonglong) from rpython.rtyper.extregistry import ExtRegistryEntry from rpython.tool import leakfinder from rpython.tool.identity_dict import identity_dict @@ -676,6 +676,7 @@ _numbertypes[r_int] = _numbertypes[int] _numbertypes[r_longlonglong] = Number("SignedLongLongLong", r_longlonglong, longlonglongmask) + 
_______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit