Author: stian
Branch: improve-rbigint
Changeset: r56352:0abcf5b8aaba
Date: 2012-07-06 20:01 +0200
http://bitbucket.org/pypy/pypy/changeset/0abcf5b8aaba/

Log:    Fix one test, and fix things so that a few more tests no longer fail
        (divrem still fails for some reason; I don't understand why).
        Optimize mod(), fix an issue with lshift, and fix translation (for
        some reason the last commit failed today, but worked last night, hehe).

diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -151,23 +151,24 @@
         """Return the x'th digit, as an int."""
         return self._digits[x]
     digit._always_inline_ = True
-    
+    digit._annonforceargs_ = [None, r_uint] # These are necessary because x can't always be proven non negative, no matter how hard we try.
     def widedigit(self, x):
         """Return the x'th digit, as a long long int if needed
         to have enough room to contain two digits."""
         return _widen_digit(self._digits[x])
     widedigit._always_inline_ = True
-    
+    widedigit._annonforceargs_ = [None, r_uint]
     def udigit(self, x):
         """Return the x'th digit, as an unsigned int."""
         return _load_unsigned_digit(self._digits[x])
     udigit._always_inline_ = True
-    
+    udigit._annonforceargs_ = [None, r_uint]
     def setdigit(self, x, val):
         val = _mask_digit(val)
         assert val >= 0
         self._digits[x] = _store_digit(val)
     setdigit._annspecialcase_ = 'specialize:argtype(2)'
+    digit._annonforceargs_ = [None, r_uint, None]
     setdigit._always_inline_ = True
 
     def numdigits(self):
@@ -450,23 +451,21 @@
         if a.sign == 0 or b.sign == 0:
             return rbigint()
         
-        
         if asize == 1:
-            digit = a.widedigit(0)
-            if digit == 0:
+            if a._digits[0] == NULLDIGIT:
                 return rbigint()
-            elif digit == 1:
+            elif b._digits[0] == ONEDIGIT:
                 return rbigint(b._digits, a.sign * b.sign)
             elif bsize == 1:
                 result = rbigint([NULLDIGIT] * 2, a.sign * b.sign)
-                carry = b.widedigit(0) * digit
+                carry = b.widedigit(0) * a.widedigit(0)
                 result.setdigit(0, carry)
                 carry >>= SHIFT
                 if carry:
                     result.setdigit(1, carry)
                 return result
                 
-            result =  _x_mul(a, b, digit)
+            result =  _x_mul(a, b, a.digit(0))
         elif USE_TOOMCOCK and asize >= TOOMCOOK_CUTOFF:
             result = _tc_mul(a, b)
         elif USE_KARATSUBA:
@@ -512,7 +511,21 @@
 
     @jit.elidable
     def mod(self, other):
-        div, mod = _divrem(self, other)
+        if other.numdigits() == 1:
+            # Faster.
+            i = 0
+            mod = 0
+            b = other.digit(0) * other.sign
+            while i < self.numdigits():
+                digit = self.digit(i) * self.sign
+                if digit:
+                    mod <<= SHIFT
+                    mod = (mod + digit) % b
+                
+                i += 1
+            mod = rbigint.fromint(mod)
+        else:        
+            div, mod = _divrem(self, other)
         if mod.sign * other.sign == -1:
             mod = mod.add(other)
         return mod
@@ -577,7 +590,7 @@
 
             # if modulus == 1:
             #     return 0
-            if c.numdigits() == 1 and c.digit(0) == 1:
+            if c.numdigits() == 1 and c._digits[0] == ONEDIGIT:
                 return NULLRBIGINT
 
             # if base < 0:
@@ -588,13 +601,13 @@
                 
             
         elif size_b == 1:
-            digit = b.digit(0)
-            if digit == 0:
+            if b._digits[0] == NULLDIGIT:
                 return ONERBIGINT if a.sign == 1 else ONENEGATIVERBIGINT
-            elif digit == 1:
+            elif b._digits[0] == ONEDIGIT:
                 return a
             elif a.numdigits() == 1:
                 adigit = a.digit(0)
+                digit = b.digit(0)
                 if adigit == 1:
                     if a.sign == -1 and digit % 2:
                         return ONENEGATIVERBIGINT
@@ -612,7 +625,7 @@
         
         # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result)
         # into helper function result = _help_mult(x, y, c)
-        if True: #not c or size_b <= FIVEARY_CUTOFF:
+        if not c or size_b <= FIVEARY_CUTOFF:
             # Left-to-right binary exponentiation (HAC Algorithm 14.79)
             # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
             size_b -= 1
@@ -627,7 +640,6 @@
                 size_b -= 1
                 
         else:
-            # XXX: Not working with int128! Yet
             # Left-to-right 5-ary exponentiation (HAC Algorithm 14.82)
             # This is only useful in the case where c != None.
             # z still holds 1L
@@ -662,7 +674,7 @@
                         break # Done
                         
                     size_b -= 1
-
+                    assert size_b >= 0
                     bi = b.udigit(size_b)
                     index = ((accum << (-j)) | (bi >> (j+SHIFT))) & 0x1f
                     accum = bi
@@ -706,11 +718,12 @@
         wordshift = int_other // SHIFT
         remshift  = int_other - wordshift * SHIFT
 
-        oldsize = self.numdigits()
         if not remshift:
             return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign)
-            
-        z = rbigint([NULLDIGIT] * (oldsize + wordshift + 1), self.sign)
+        
+        oldsize = self.numdigits()
+        newsize = oldsize + wordshift + 1
+        z = rbigint([NULLDIGIT] * newsize, self.sign)
         accum = _widen_digit(0)
         i = wordshift
         j = 0
@@ -720,8 +733,10 @@
             accum >>= SHIFT
             i += 1
             j += 1
-            
-        z.setdigit(oldsize, accum)
+        
+        newsize -= 1
+        assert newsize >= 0
+        z.setdigit(newsize, accum)
 
         z._positivenormalize()
         return z
@@ -830,31 +845,31 @@
             self._digits = [NULLDIGIT]
             return
         
-        while i > 1 and self.digit(i - 1) == 0:
+        while i > 1 and self._digits[i - 1] == NULLDIGIT:
             i -= 1
         assert i > 0
         if i != c:
             self._digits = self._digits[:i]
-        if self.numdigits() == 1 and self.digit(0) == 0:
+        if self.numdigits() == 1 and self._digits[0] == NULLDIGIT:
             self.sign = 0
             
-    _normalize._always_inline_ = True
+    #_normalize._always_inline_ = True
     
     def _positivenormalize(self):
         """ This function assumes numdigits > 0. Good for shifts and such """
         i = c = self.numdigits()
-        while i > 1 and self.digit(i - 1) == 0:
+        while i > 1 and self._digits[i - 1] == NULLDIGIT:
             i -= 1
         assert i > 0
         if i != c:
             self._digits = self._digits[:i]
-        if self.numdigits() == 1 and self.digit(0) == 0:
+        if self.numdigits() == 1 and self._digits[0] == NULLDIGIT:
             self.sign = 0
     _positivenormalize._always_inline_ = True
     
     def bit_length(self):
         i = self.numdigits()
-        if i == 1 and self.digit(0) == 0:
+        if i == 1 and self._digits[0] == NULLDIGIT:
             return 0
         msd = self.digit(i - 1)
         msd_bits = 0
@@ -1047,12 +1062,11 @@
         # via exploiting that each entry in the multiplication
         # pyramid appears twice (except for the size_a squares).
         z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
-        i = _load_unsigned_digit(0)
+        i = 0
         while i < size_a:
             f = a.widedigit(i)
             pz = i << 1
             pa = i + 1
-            paend = size_a
 
             carry = z.widedigit(pz) + f * f
             z.setdigit(pz, carry)
@@ -1063,7 +1077,7 @@
             # Now f is added in twice in each column of the
             # pyramid it appears.  Same as adding f<<1 once.
             f <<= 1
-            while pa < paend:
+            while pa < size_a:
                 carry += z.widedigit(pz) + a.widedigit(pa) * f
                 pa += 1
                 z.setdigit(pz, carry)
@@ -1075,8 +1089,8 @@
                 z.setdigit(pz, carry)
                 pz += 1
                 carry >>= SHIFT
-            if carry:
-                z.setdigit(pz, z.widedigit(pz) + carry)
+                if carry:
+                    z.setdigit(pz, z.widedigit(pz) + carry)
             assert (carry >> SHIFT) == 0
             i += 1
         z._positivenormalize()
@@ -1087,7 +1101,7 @@
 
     z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
     # gradeschool long mult
-    i = _load_unsigned_digit(0)
+    i = 0
     while i < size_a:
         carry = 0
         f = a.widedigit(i)
@@ -1101,6 +1115,7 @@
             carry >>= SHIFT
             assert carry <= MASK
         if carry:
+            assert pz >= 0
             z.setdigit(pz, z.widedigit(pz) + carry)
         assert (carry >> SHIFT) == 0
         i += 1
@@ -1550,7 +1565,7 @@
     w = _muladd1(w1, d)
     size_v = v1.numdigits()
     size_w = w1.numdigits()
-    assert size_v >= size_w and size_w >= 1 # (Assert checks by div()
+    assert size_v >= size_w and size_w > 1 # (Assert checks by div()
 
     """v = rbigint([NULLDIGIT] * (size_v + 1))
     w = rbigint([NULLDIGIT] * (size_w))
@@ -1565,12 +1580,13 @@
         
     size_a = size_v - size_w + 1
     a = rbigint([NULLDIGIT] * size_a, 1)
-
-    wm1 = w.widedigit(abs(size_w-1))
-    wm2 = w.widedigit(abs(size_w-2))
-    j = _load_unsigned_digit(size_v)
+    assert size_w >= 2
+    wm1 = w.widedigit(size_w-1)
+    wm2 = w.widedigit(size_w-2)
+    j = size_v
     k = size_a - 1
     while k >= 0:
+        assert j >= 2
         if j >= size_v:
             vj = 0
         else:
@@ -2099,7 +2115,7 @@
             ntostore = power
             rem = _inplace_divrem1(scratch, pin, powbase, size)
             pin = scratch  # no need to use a again
-            if pin.digit(size - 1) == 0:
+            if pin._digits[size - 1] == NULLDIGIT:
                 size -= 1
 
             # Break rem into digits.
diff --git a/pypy/rlib/test/test_rbigint.py b/pypy/rlib/test/test_rbigint.py
--- a/pypy/rlib/test/test_rbigint.py
+++ b/pypy/rlib/test/test_rbigint.py
@@ -360,7 +360,7 @@
                       for i in (10L, 5L, 0L)]
         py.test.raises(ValueError, f1.pow, f2, f3)
         #
-        MAX = 1E40
+        MAX = 1E20
         x = long(random() * MAX) + 1
         y = long(random() * MAX) + 1
         z = long(random() * MAX) + 1
@@ -521,9 +521,9 @@
     def test__x_divrem(self):
         x = 12345678901234567890L
         for i in range(100):
-            y = long(randint(0, 1 << 30))
-            y <<= 30
-            y += randint(0, 1 << 30)
+            y = long(randint(0, 1 << 60))
+            y <<= 60
+            y += randint(0, 1 << 60)
             f1 = rbigint.fromlong(x)
             f2 = rbigint.fromlong(y)
             div, rem = lobj._x_divrem(f1, f2)
@@ -532,9 +532,9 @@
     def test__divrem(self):
         x = 12345678901234567890L
         for i in range(100):
-            y = long(randint(0, 1 << 30))
-            y <<= 30
-            y += randint(0, 1 << 30)
+            y = long(randint(0, 1 << 60))
+            y <<= 60
+            y += randint(0, 1 << 60)
             for sx, sy in (1, 1), (1, -1), (-1, -1), (-1, 1):
                 sx *= x
                 sy *= y
diff --git a/pypy/rpython/lltypesystem/rlist.py b/pypy/rpython/lltypesystem/rlist.py
--- a/pypy/rpython/lltypesystem/rlist.py
+++ b/pypy/rpython/lltypesystem/rlist.py
@@ -303,12 +303,12 @@
     return l.items
 
 def ll_getitem_fast(l, index):
-    #ll_assert(index < l.length, "getitem out of bounds")
+    ll_assert(index < l.length, "getitem out of bounds")
     return l.ll_items()[index]
 ll_getitem_fast.oopspec = 'list.getitem(l, index)'
 
 def ll_setitem_fast(l, index, item):
-    #ll_assert(index < l.length, "setitem out of bounds")
+    ll_assert(index < l.length, "setitem out of bounds")
     l.ll_items()[index] = item
 ll_setitem_fast.oopspec = 'list.setitem(l, index, item)'
 
@@ -316,7 +316,7 @@
 
 @typeMethod
 def ll_fixed_newlist(LIST, length):
-    #ll_assert(length >= 0, "negative fixed list length")
+    ll_assert(length >= 0, "negative fixed list length")
     l = malloc(LIST, length)
     return l
 ll_fixed_newlist.oopspec = 'newlist(length)'
@@ -333,12 +333,12 @@
     return l
 
 def ll_fixed_getitem_fast(l, index):
-    #ll_assert(index < len(l), "fixed getitem out of bounds")
+    ll_assert(index < len(l), "fixed getitem out of bounds")
     return l[index]
 ll_fixed_getitem_fast.oopspec = 'list.getitem(l, index)'
 
 def ll_fixed_setitem_fast(l, index, item):
-    #ll_assert(index < len(l), "fixed setitem out of bounds")
+    ll_assert(index < len(l), "fixed setitem out of bounds")
     l[index] = item
 ll_fixed_setitem_fast.oopspec = 'list.setitem(l, index, item)'
 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to