Author: Stian Andreassen
Branch: improve-rbigint
Changeset: r56757:95fea225f922
Date: 2012-08-19 22:21 +0200
http://bitbucket.org/pypy/pypy/changeset/95fea225f922/

Log:    Progress?

diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -109,9 +109,10 @@
         hop.exception_cannot_occur()
 
 class rbigint(object):
+    """This is a reimplementation of longs using a list of digits."""
     _immutable_ = True
     _immutable_fields_ = ["_digits"]
-    """This is a reimplementation of longs using a list of digits."""
+    
 
     def __init__(self, digits=[NULLDIGIT], sign=0, size=0):
         if not we_are_translated():
@@ -743,12 +744,12 @@
 
         z = rbigint([NULLDIGIT] * (oldsize + 1), self.sign, (oldsize + 1))
         accum = _widen_digit(0)
-
-        for i in range(oldsize):
+        i = 0
+        while i < oldsize:
             accum += self.widedigit(i) << int_other
             z.setdigit(i, accum)
             accum >>= SHIFT
-            
+            i += 1
         z.setdigit(oldsize, accum)
         z._normalize()
         return z
@@ -1105,6 +1106,84 @@
     z._normalize()
     return z
 
+def _x_mul(a, b, digit=0):
+    """
+    Grade school multiplication, ignoring the signs.
+    Returns the absolute value of the product, or None if error.
+    """
+
+    size_a = a.numdigits()
+    size_b = b.numdigits()
+
+    if a is b:
+        # Efficient squaring per HAC, Algorithm 14.16:
+        # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+        # Gives slightly less than a 2x speedup when a == b,
+        # via exploiting that each entry in the multiplication
+        # pyramid appears twice (except for the size_a squares).
+        z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
+        i = UDIGIT_TYPE(0)
+        while i < size_a:
+            f = a.widedigit(i)
+            pz = i << 1
+            pa = i + 1
+
+            carry = z.widedigit(pz) + f * f
+            z.setdigit(pz, carry)
+            pz += 1
+            carry >>= SHIFT
+            assert carry <= MASK
+
+            # Now f is added in twice in each column of the pyramid
+            # in which it appears.  Same as adding f<<1 once.
+            f <<= 1
+            while pa < size_a:
+                carry += z.widedigit(pz) + a.widedigit(pa) * f
+                pa += 1
+                z.setdigit(pz, carry)
+                pz += 1
+                carry >>= SHIFT
+            if carry:
+                carry += z.widedigit(pz)
+                z.setdigit(pz, carry)
+                pz += 1
+                carry >>= SHIFT
+            if carry:
+                z.setdigit(pz, z.widedigit(pz) + carry)
+            assert (carry >> SHIFT) == 0
+            i += 1
+        z._normalize()
+        return z
+    
+    elif digit:
+        if digit & (digit - 1) == 0:
+            return b.lqshift(ptwotable[digit])
+        
+        # Even if it's not a power of two it can still be useful.
+        return _muladd1(b, digit)
+        
+    z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
+    # Grade school long multiplication.
+    i = UDIGIT_TYPE(0)
+    while i < size_a:
+        carry = 0
+        f = a.widedigit(i)
+        pz = i
+        pb = 0
+        while pb < size_b:
+            carry += z.widedigit(pz) + b.widedigit(pb) * f
+            pb += 1
+            z.setdigit(pz, carry)
+            pz += 1
+            carry >>= SHIFT
+            assert carry <= MASK
+        if carry:
+            assert pz >= 0
+            z.setdigit(pz, z.widedigit(pz) + carry)
+        assert (carry >> SHIFT) == 0
+        i += 1
+    z._normalize()
+    return z
 
 def _kmul_split(n, size):
     """
@@ -1429,10 +1508,12 @@
     
     carry = 0
     assert 0 <= d and d < SHIFT
-    for i in range(m):
+    i = 0
+    while i < m:
         acc = a.widedigit(i) << d | carry
         z.setdigit(i, acc)
         carry = acc >> SHIFT
+        i += 1
         
     return carry
 
@@ -1446,10 +1527,12 @@
     mask = (1 << d) - 1
     
     assert 0 <= d and d < SHIFT
-    for i in range(m-1, 0, -1):
+    i = m-1
+    while i >= 0:
         acc = (carry << SHIFT) | a.widedigit(i)
         carry = acc & mask
         z.setdigit(i, acc >> d)
+        i -= 1
         
     return carry
 
@@ -1462,10 +1545,10 @@
     v = rbigint([NULLDIGIT] * (size_v + 1), 1, size_v + 1)
     w = rbigint([NULLDIGIT] * size_w, 1, size_w)
     
-    """/normalize: shift w1 left so that its top digit is >= PyLong_BASE/2.
+    """ normalize: shift w1 left so that its top digit is >= PyLong_BASE/2.
         shift v1 left by the same amount. Results go into w and v. """
         
-    d = SHIFT - bits_in_digit(w1.digit(size_w-1))
+    d = SHIFT - bits_in_digit(w1.digit(abs(size_w-1)))
     carry = _v_lshift(w, w1, size_w, d)
     assert carry == 0
     carry = _v_lshift(v, v1, size_v, d)
@@ -1475,16 +1558,14 @@
         
     """ Now v->ob_digit[size_v-1] < w->ob_digit[size_w-1], so quotient has
         at most (and usually exactly) k = size_v - size_w digits. """
-        
-    size_a = size_v - size_w + 1
-    a = rbigint([NULLDIGIT] * size_a, 1, size_a)
+    k = size_v - size_w
+    assert k > 0
+    a = rbigint([NULLDIGIT] * k, 1, k)
     
-    wm1 = w.widedigit(abs(size_w-1))
+    wm1 = w.digit(abs(size_w-1))
     wm2 = w.widedigit(abs(size_w-2))
 
     j = size_v
-    k = size_a - 1
-    assert k > 0
     while k >= 0:
         assert j >= 0
         """ inner loop: divide vk[0:size_w+1] by w0[0:size_w], giving
@@ -1494,12 +1575,12 @@
         if j >= size_v:
             vtop = 0
         else:
-            vtop = v.widedigit(j)
+            vtop = v.digit(j)
         assert vtop <= wm1
-        vv = (vtop << SHIFT | v.widedigit(abs(j-1)))
-        q = vv / wm1
-        r = vv - (wm1 * q)
-        while wm2 * q > (r << SHIFT | v.widedigit(abs(j-2))):
+        vv = (vtop << SHIFT) | v.widedigit(abs(j-1))
+        q = UDIGIT_MASK(vv / wm1)
+        r = vv - wm1 * q
+        while wm2 * q > ((r << SHIFT) | v.widedigit(abs(j-2))):
             q -= 1
             r += wm1
             if r > MASK:
@@ -1517,24 +1598,21 @@
             i += 1
         
         # add w back if q was too large (this branch taken rarely)
-        assert vtop+zhi == -1 or vtop + zhi == 0
         if vtop + zhi < 0:
-            carry = _widen_digit(0)
+            carry = UDIGIT_TYPE(0)
             i = 0
             while i < size_w:
-                carry += v.widedigit(k+i) + w.widedigit(i)
+                carry += v.udigit(k+i) + w.udigit(i)
                 v.setdigit(k+i, carry)
                 carry >>= SHIFT
                 i += 1
             q -= 1
             
         # store quotient digit
-        a.setdigit(k, q)
         k -= 1
         j -= 1
+        a.setdigit(k, q)
         
-        
-    
     carry = _v_rshift(w, v, size_w, d)
     assert carry == 0
     
@@ -1882,7 +1960,8 @@
                 break
             basebits += 1
 
-        for i in range(size_a):
+        i = 0
+        while i < size_a:
             accum |= a.widedigit(i) << accumbits
             accumbits += SHIFT
             assert accumbits >= basebits
@@ -1899,6 +1978,8 @@
                 else:
                     if accum <= 0:
                         break
+                        
+            i += 1
     else:
         # Not 0, and base not a power of 2.  Divide repeatedly by
         # base, but for speed use the highest power of base that
@@ -2014,8 +2095,8 @@
         size_z = max(size_a, size_b)
 
     z = rbigint([NULLDIGIT] * size_z, 1, size_z)
-
-    for i in range(size_z):
+    i = 0
+    while i < size_z:
         if i < size_a:
             diga = a.digit(i) ^ maska
         else:
@@ -2031,7 +2112,8 @@
             z.setdigit(i, diga | digb)
         elif op == '^':
             z.setdigit(i, diga ^ digb)
-
+        i += 1
+        
     z._normalize()
     if negz == 0:
         return z
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to