Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r73853:c9577be74d54
Date: 2014-10-08 20:21 -0400
http://bitbucket.org/pypy/pypy/changeset/c9577be74d54/

Log:    optimize iterator next if array is contiguous

diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -45,6 +45,7 @@
         return space.wrap(self.state.index)
 
     def descr_coords(self, space):
+        self.state = self.iter.update(self.state)
         return space.newtuple([space.wrap(c) for c in self.state.indices])
 
     def descr_getitem(self, space, w_idx):
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -93,7 +93,8 @@
     def __init__(self, array, size, shape, strides, backstrides):
         assert len(shape) == len(strides) == len(backstrides)
         _update_contiguous_flags(array)
-        self.contiguous = array.flags & NPY.ARRAY_C_CONTIGUOUS
+        self.contiguous = (array.flags & NPY.ARRAY_C_CONTIGUOUS and
+                           array.shape == shape and array.strides == strides)
 
         self.array = array
         self.size = size
@@ -128,15 +129,18 @@
         index = state.index + 1
         indices = state.indices
         offset = state.offset
-        for i in xrange(self.ndim_m1, -1, -1):
-            idx = indices[i]
-            if idx < self.shape_m1[i]:
-                indices[i] = idx + 1
-                offset += self.strides[i]
-                break
-            else:
-                indices[i] = 0
-                offset -= self.backstrides[i]
+        if self.contiguous:
+            offset += self.array.dtype.elsize
+        else:
+            for i in xrange(self.ndim_m1, -1, -1):
+                idx = indices[i]
+                if idx < self.shape_m1[i]:
+                    indices[i] = idx + 1
+                    offset += self.strides[i]
+                    break
+                else:
+                    indices[i] = 0
+                    offset -= self.backstrides[i]
         return IterState(self, index, indices, offset)
 
     @jit.unroll_safe
@@ -151,6 +155,21 @@
                 current %= self.factors[i]
         return IterState(self, index, None, offset)
 
+    @jit.unroll_safe
+    def update(self, state):
+        assert state.iterator is self
+        if not self.contiguous:
+            return state
+        current = state.index
+        indices = state.indices
+        for i in xrange(len(self.shape_m1)):
+            if self.factors[i] != 0:
+                indices[i] = current / self.factors[i]
+                current %= self.factors[i]
+            else:
+                indices[i] = 0
+        return IterState(self, state.index, indices, state.offset)
+
     def done(self, state):
         assert state.iterator is self
         return state.index >= self.size
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -229,6 +229,7 @@
                                             dtype=dtype)
         assert not arr_iter.done(arr_state)
         w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        out_state = out_iter.update(out_state)
         if out_state.indices[axis] == 0:
             if identity is not None:
                 w_val = func(dtype, identity, w_val)
@@ -360,6 +361,7 @@
     while not arr_iter.done(arr_state):
         nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
         if arr_iter.getitem_bool(arr_state):
+            arr_state = arr_iter.update(arr_state)
             for d in dims:
                 res_iter.setitem(res_state, box(arr_state.indices[d]))
                 res_state = res_iter.next(res_state)
@@ -453,9 +455,10 @@
         else:
             val = val.convert_to(space, dtype)
         arr_iter.setitem(arr_state, val)
-        # need to repeat i_nput values until all assignments are done
         arr_state = arr_iter.goto(arr_state.index + step)
         val_state = val_iter.next(val_state)
+        if val_iter.done(val_state):
+            val_state = val_iter.reset(val_state)
         length -= 1
 
 fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -313,6 +313,7 @@
         # create an iterator for each operand
         for i in range(len(self.seq)):
             it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i])
+            it.contiguous = False
             self.iters.append((it, it.reset()))
 
     def set_op_axes(self, space, w_op_axes):
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
--- a/pypy/module/micronumpy/test/test_iterators.py
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -3,7 +3,15 @@
 
 
 class MockArray(object):
-    start = 0
+    flags = 0
+
+    class dtype:
+        elsize = 1
+
+    def __init__(self, shape, strides, start=0):
+        self.shape = shape
+        self.strides = strides
+        self.start = start
 
 
 class TestIterDirect(object):
@@ -14,19 +22,24 @@
         strides = [5, 1]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [10, 4]
-        i = ArrayIter(MockArray, support.product(shape), shape,
+        i = ArrayIter(MockArray(shape, strides), support.product(shape), shape,
                       strides, backstrides)
+        assert i.contiguous
         s = i.reset()
         s = i.next(s)
         s = i.next(s)
         s = i.next(s)
         assert s.offset == 3
         assert not i.done(s)
+        assert s.indices == [0,0]
+        s = i.update(s)
         assert s.indices == [0,3]
         #cause a dimension overflow
         s = i.next(s)
         s = i.next(s)
         assert s.offset == 5
+        assert s.indices == [0,3]
+        s = i.update(s)
         assert s.indices == [1,0]
 
         #Now what happens if the array is transposed? strides[-1] != 1
@@ -34,8 +47,9 @@
         strides = [1, 3]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [2, 12]
-        i = ArrayIter(MockArray, support.product(shape), shape,
+        i = ArrayIter(MockArray(shape, strides), support.product(shape), shape,
                       strides, backstrides)
+        assert not i.contiguous
         s = i.reset()
         s = i.next(s)
         s = i.next(s)
@@ -54,10 +68,10 @@
         strides = [1, 3]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [2, 12]
-        a = MockArray()
-        a.start = 42
+        a = MockArray(shape, strides, 42)
         i = ArrayIter(a, support.product(shape), shape,
                       strides, backstrides)
+        assert not i.contiguous
         s = i.reset()
         assert s.index == 0
         assert s.indices == [0, 0]
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -101,17 +101,17 @@
         self.check_trace_count(1)
         self.check_simple_loop({
             'float_add': 1,
-            'getarrayitem_gc': 3,
+            'getarrayitem_gc': 1,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'guard_true': 3,
-            'int_add': 9,
+            'guard_true': 1,
+            'int_add': 7,
             'int_ge': 1,
-            'int_lt': 3,
+            'int_lt': 1,
             'jump': 1,
             'raw_load': 2,
             'raw_store': 1,
-            'setarrayitem_gc': 3,
+            'setarrayitem_gc': 1,
         })
 
     def define_pow():
@@ -130,18 +130,18 @@
             'float_eq': 3,
             'float_mul': 2,
             'float_ne': 1,
-            'getarrayitem_gc': 3,
+            'getarrayitem_gc': 1,
             'guard_false': 4,
             'guard_not_invalidated': 1,
-            'guard_true': 5,
-            'int_add': 9,
+            'guard_true': 3,
+            'int_add': 7,
             'int_ge': 1,
             'int_is_true': 1,
-            'int_lt': 3,
+            'int_lt': 1,
             'jump': 1,
             'raw_load': 2,
             'raw_store': 1,
-            'setarrayitem_gc': 3,
+            'setarrayitem_gc': 1,
         })
 
     def define_pow_int():
@@ -159,17 +159,17 @@
         del get_stats().loops[0]   # we don't care about it
         self.check_simple_loop({
             'call': 1,
-            'getarrayitem_gc': 3,
+            'getarrayitem_gc': 1,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'guard_true': 3,
-            'int_add': 9,
+            'guard_true': 1,
+            'int_add': 7,
             'int_ge': 1,
-            'int_lt': 3,
+            'int_lt': 1,
             'jump': 1,
             'raw_load': 2,
             'raw_store': 1,
-            'setarrayitem_gc': 3,
+            'setarrayitem_gc': 1,
         })
 
     def define_sum():
@@ -384,17 +384,17 @@
         self.check_trace_count(1)
         self.check_simple_loop({
             'float_add': 1,
-            'getarrayitem_gc': 3,
+            'getarrayitem_gc': 2,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'guard_true': 3,
-            'int_add': 9,
+            'guard_true': 2,
+            'int_add': 8,
             'int_ge': 1,
-            'int_lt': 3,
+            'int_lt': 2,
             'jump': 1,
             'raw_load': 2,
             'raw_store': 1,
-            'setarrayitem_gc': 3,
+            'setarrayitem_gc': 2,
         })
 
     def define_take():
@@ -519,17 +519,13 @@
         self.check_trace_count(1)
         self.check_simple_loop({
             'float_add': 1,
-            'getarrayitem_gc': 3,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'guard_true': 3,
-            'int_add': 9,
+            'int_add': 6,
             'int_ge': 1,
-            'int_lt': 3,
             'jump': 1,
             'raw_load': 2,
             'raw_store': 1,
-            'setarrayitem_gc': 3,
         })
 
     def define_flat_getitem():
@@ -544,17 +540,13 @@
         assert result == 10.0
         self.check_trace_count(1)
         self.check_simple_loop({
-            'getarrayitem_gc': 1,
             'guard_false': 1,
-            'guard_true': 1,
-            'int_add': 5,
+            'int_add': 4,
             'int_ge': 1,
-            'int_lt': 1,
             'int_mul': 1,
             'jump': 1,
             'raw_load': 1,
             'raw_store': 1,
-            'setarrayitem_gc': 1,
         })
 
     def define_flat_setitem():
@@ -570,18 +562,17 @@
         assert result == 1.0
         self.check_trace_count(1)
         self.check_simple_loop({
-            'getarrayitem_gc': 1,
+            'guard_false': 1,
             'guard_not_invalidated': 1,
-            'guard_true': 2,
-            'int_add': 5,
+            'guard_true': 1,
+            'int_add': 4,
+            'int_ge': 1,
             'int_gt': 1,
-            'int_lt': 1,
             'int_mul': 1,
             'int_sub': 1,
             'jump': 1,
             'raw_load': 1,
             'raw_store': 1,
-            'setarrayitem_gc': 1,
         })
 
     def define_dot():
@@ -609,24 +600,25 @@
         self.check_resops({
             'float_add': 2,
             'float_mul': 2,
-            'getarrayitem_gc': 7,
-            'getarrayitem_gc_pure': 15,
-            'getfield_gc_pure': 52,
+            'getarrayitem_gc': 4,
+            'getarrayitem_gc_pure': 9,
+            'getfield_gc_pure': 46,
             'guard_class': 4,
-            'guard_false': 14,
+            'guard_false': 12,
             'guard_not_invalidated': 2,
-            'guard_true': 13,
-            'int_add': 25,
+            'guard_true': 12,
+            'int_add': 18,
             'int_ge': 4,
-            'int_le': 8,
-            'int_lt': 11,
-            'int_sub': 4,
+            'int_is_true': 3,
+            'int_le': 5,
+            'int_lt': 8,
+            'int_sub': 3,
             'jump': 3,
             'new_with_vtable': 7,
             'raw_load': 6,
             'raw_store': 1,
             'same_as': 2,
-            'setarrayitem_gc': 10,
+            'setarrayitem_gc': 7,
             'setfield_gc': 22,
         })
 
@@ -656,15 +648,12 @@
         self.check_trace_count(1)
         self.check_simple_loop({
             'float_ne': 1,
-            'getarrayitem_gc': 4,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'guard_true': 5,
-            'int_add': 12,
+            'guard_true': 1,
+            'int_add': 8,
             'int_ge': 1,
-            'int_lt': 4,
             'jump': 1,
             'raw_load': 2,
             'raw_store': 1,
-            'setarrayitem_gc': 4,
         })
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to