Author: stian Branch: improve-rbigint Changeset: r56320:398e2b212e8b Date: 2012-06-23 05:15 +0200 http://bitbucket.org/pypy/pypy/changeset/398e2b212e8b/
Log: Reorganize to make room for Toom-Cook (WIP). This also gave the first pow(a,b,c) benchmark a boost, perhaps because some tricks kick in earlier. diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py --- a/pypy/rlib/rbigint.py +++ b/pypy/rlib/rbigint.py @@ -36,6 +36,8 @@ KARATSUBA_CUTOFF = 38 KARATSUBA_SQUARE_CUTOFF = 2 * KARATSUBA_CUTOFF +USE_TOOMCOCK = False +TOOMCOOK_CUTOFF = 102 # For exponentiation, use the binary left-to-right algorithm # unless the exponent contains more than FIVEARY_CUTOFF digits. @@ -365,12 +367,44 @@ result._normalize() return result - def mul(self, other): - if USE_KARATSUBA: - result = _k_mul(self, other) + def mul(self, b): + asize = self.numdigits() + bsize = b.numdigits() + + a = self + + if asize > bsize: + a, b, asize, bsize = b, a, bsize, asize + + if a.sign == 0 or b.sign == 0: + return rbigint() + + if asize == 1: + digit = a.digit(0) + if digit == 0: + return rbigint() + elif digit == 1: + return rbigint(b._digits[:], a.sign * b.sign) + + result = _x_mul(a, b, digit) + elif USE_TOOMCOCK and asize >= TOOMCOOK_CUTOFF: + result = _tc_mul(a, b) + elif USE_KARATSUBA: + if a is b: + i = KARATSUBA_SQUARE_CUTOFF + else: + i = KARATSUBA_CUTOFF + + if asize <= i: + result = _x_mul(a, b) + elif 2 * asize <= bsize: + result = _k_lopsided_mul(a, b) + else: + result = _k_mul(a, b) else: - result = _x_mul(self, other) - result.sign = self.sign * other.sign + result = _x_mul(a, b) + + result.sign = a.sign * b.sign return result def truediv(self, other): @@ -848,15 +882,6 @@ """ size_a = a.numdigits() - - if size_a == 1: - # Special case. - digit = a.digit(0) - if digit == 0: - return rbigint([NULLDIGIT], 1) - elif digit == 1: - return rbigint(b._digits[:], 1) # We assume b was normalized already. - size_b = b.numdigits() if a is b: @@ -927,6 +952,103 @@ return z +def _tcmul_split(n, size): + """ + A helper for Karatsuba multiplication (k_mul). 
+ Takes a bigint "n" and an integer "size" representing the place to + split, and sets low and high such that abs(n) == (high << size) + low, + viewing the shift as being by digits. The sign bit is ignored, and + the return values are >= 0. + """ + + assert size > 0 + + size_n = n.numdigits() + shift = min(size_n, size) + + lo = rbigint(n._digits[:shift], 1) + if size_n >= (shift * 2): + mid = rbigint(n._digits[shift:shift >> 1], 1) + hi = rbigint(n._digits[shift >> 1:], 1) + else: + mid = rbigint(n._digits[shift:], 1) + hi = rbigint([NULLDIGIT] * ((shift * 3) - size_n), 1) + lo._normalize() + mid._normalize() + hi._normalize() + return hi, mid, lo + +# Declear a simple 2 as constants for our toom cook +POINT2 = rbigint.fromint(2) +def _tc_mul(a, b): + """ + Toom Cook + """ + asize = a.numdigits() + bsize = b.numdigits() + + # Split a & b into hi, mid and lo pieces. + shift = bsize >> 1 + ah, am, al = _tcmul_split(a, shift) + assert ah.sign == 1 # the split isn't degenerate + + if a is b: + bh = ah + bm = am + bl = al + else: + bh, bm, bl = _tcmul_split(b, shift) + + # 1. Allocate result space. + ret = rbigint([NULLDIGIT] * (asize + bsize), 1) + + # 2. w points + pO = al.add(ah) + p1 = pO.add(am) + pn1 = pO.sub(am) + pn2 = pn1.add(ah).mul(POINT2).sub(al) + + qO = bl.add(bh) + q1 = qO.add(bm) + qn1 = qO.sub(bm) + qn2 = qn1.add(bh).mul(POINT2).sub(bl) + + w0 = al.mul(bl) + winf = ah.mul(bh) + w1 = p1.mul(q1) + wn1 = pn1.mul(qn1) + wn2 = pn2.mul(qn2) + + # 3. The important stuff + # XXX: Need a faster / 3 and /2 like in GMP! + r0 = w0 + r4 = winf + r3 = _divrem1(wn2.sub(wn1), 3)[0] + r1 = _divrem1(w1.sub(wn1), 2)[0] + r2 = wn1.sub(w0) + r3 = _divrem1(r2.sub(r3), 2)[0].add(r4.mul(POINT2)) + r2 = r2.add(r1).sub(r4) + r1 = r1.sub(r3) + + # Now we fit r+ r2 + r4 into the new string. + # Now we got to add the r1 and r3 in the mid shift. 
This is TODO (aga, not fixed yet) + pointer = r0.numdigits() + ret._digits[:pointer] = r0._digits + + pointer2 = pointer + r2.numdigits() + ret._digits[pointer:pointer2] = r2._digits + + pointer3 = pointer2 + r4.numdigits() + ret._digits[pointer2:pointer3] = r4._digits + + # TODO!!!! + #_v_iadd(ret, shift, i, r1, r1.numdigits()) + #_v_iadd(ret, shift >> 1, i, r3, r3.numdigits()) + + ret._normalize() + return ret + + def _kmul_split(n, size): """ A helper for Karatsuba multiplication (k_mul). @@ -952,6 +1074,7 @@ """ asize = a.numdigits() bsize = b.numdigits() + # (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl # Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl # Then the original product is @@ -959,30 +1082,6 @@ # By picking X to be a power of 2, "*X" is just shifting, and it's # been reduced to 3 multiplies on numbers half the size. - # We want to split based on the larger number; fiddle so that b - # is largest. - if asize > bsize: - a, b, asize, bsize = b, a, bsize, asize - - # Use gradeschool math when either number is too small. - if a is b: - i = KARATSUBA_SQUARE_CUTOFF - else: - i = KARATSUBA_CUTOFF - if asize <= i: - if a.sign == 0: - return rbigint() # zero - else: - return _x_mul(a, b) - - # If a is small compared to b, splitting on b gives a degenerate - # case with ah==0, and Karatsuba may be (even much) less efficient - # than "grade school" then. However, we can still win, by viewing - # b as a string of "big digits", each of width a->ob_size. That - # leads to a sequence of balanced calls to k_mul. - if 2 * asize <= bsize: - return _k_lopsided_mul(a, b) - # Split a & b into hi & lo pieces. shift = bsize >> 1 ah, al = _kmul_split(a, shift) @@ -1013,7 +1112,7 @@ ret = rbigint([NULLDIGIT] * (asize + bsize), 1) # 2. t1 <- ah*bh, and copy into high digits of result. 
- t1 = _k_mul(ah, bh) + t1 = ah.mul(bh) assert t1.sign >= 0 assert 2*shift + t1.numdigits() <= ret.numdigits() ret._digits[2*shift : 2*shift + t1.numdigits()] = t1._digits @@ -1026,7 +1125,7 @@ ## i * sizeof(digit)); # 3. t2 <- al*bl, and copy into the low digits. - t2 = _k_mul(al, bl) + t2 = al.mul(bl) assert t2.sign >= 0 assert t2.numdigits() <= 2*shift # no overlap with high digits ret._digits[:t2.numdigits()] = t2._digits @@ -1051,7 +1150,7 @@ else: t2 = _x_add(bh, bl) - t3 = _k_mul(t1, t2) + t3 = t1.mul(t2) assert t3.sign >=0 # Add t3. It's not obvious why we can't run out of room here. diff --git a/pypy/translator/goal/targetbigintbenchmark.py b/pypy/translator/goal/targetbigintbenchmark.py --- a/pypy/translator/goal/targetbigintbenchmark.py +++ b/pypy/translator/goal/targetbigintbenchmark.py @@ -19,13 +19,13 @@ 6.647562 Pypy with improvements: - 6.048997 - 10.091559 - 14.680590 - 1.635417 - 12.023154 - 14.320596 - 6.439088 + 5.797121 + 10.068798 + 14.770187 + 1.620009 + 12.054951 + 14.292367 + 6.440351 """ _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit