Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r73855:a61621a9a9fd
Date: 2014-10-08 22:20 -0400
http://bitbucket.org/pypy/pypy/changeset/a61621a9a9fd/

Log:    add track_index flag to iterator, use to save some operations

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -88,7 +88,10 @@
 
 class ArrayIter(object):
     _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
-                          'strides[*]', 'backstrides[*]', 'factors[*]']
+                          'strides[*]', 'backstrides[*]', 'factors[*]',
+                          'track_index']
+
+    track_index = True
 
     def __init__(self, array, size, shape, strides, backstrides):
         assert len(shape) == len(strides) == len(backstrides)
@@ -126,7 +129,9 @@
     @jit.unroll_safe
     def next(self, state):
         assert state.iterator is self
-        index = state.index + 1
+        index = state.index
+        if self.track_index:
+            index += 1
         indices = state.indices
         offset = state.offset
         if self.contiguous:
@@ -158,6 +163,7 @@
     @jit.unroll_safe
     def update(self, state):
         assert state.iterator is self
+        assert self.track_index
         if not self.contiguous:
             return state
         current = state.index
@@ -172,6 +178,7 @@
 
     def done(self, state):
         assert state.iterator is self
+        assert self.track_index
         return state.index >= self.size
 
     def getitem(self, state):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -48,6 +48,7 @@
     left_iter, left_state = w_lhs.create_iter(shape)
     right_iter, right_state = w_rhs.create_iter(shape)
     out_iter, out_state = out.create_iter(shape)
+    left_iter.track_index = right_iter.track_index = False
     shapelen = len(shape)
     while not out_iter.done(out_state):
         call2_driver.jit_merge_point(shapelen=shapelen, func=func,
@@ -182,6 +183,9 @@
             iter, state = y_iter, y_state
     else:
         iter, state = x_iter, x_state
+    out_iter.track_index = x_iter.track_index = False
+    arr_iter.track_index = y_iter.track_index = False
+    iter.track_index = True
     shapelen = len(shape)
     while not iter.done(state):
         where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
@@ -299,6 +303,7 @@
     assert left_shape[-1] == right_shape[right_critical_dim]
     assert result.get_dtype() == dtype
     outi, outs = result.create_iter()
+    outi.track_index = False
     lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
     righti = AllButAxisIter(right_impl, right_critical_dim)
     lefts = lefti.reset()
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -105,7 +105,7 @@
             'guard_false': 1,
             'guard_not_invalidated': 1,
             'guard_true': 1,
-            'int_add': 7,
+            'int_add': 5,
             'int_ge': 1,
             'int_lt': 1,
             'jump': 1,
@@ -134,7 +134,7 @@
             'guard_false': 4,
             'guard_not_invalidated': 1,
             'guard_true': 3,
-            'int_add': 7,
+            'int_add': 5,
             'int_ge': 1,
             'int_is_true': 1,
             'int_lt': 1,
@@ -163,7 +163,7 @@
             'guard_false': 1,
             'guard_not_invalidated': 1,
             'guard_true': 1,
-            'int_add': 7,
+            'int_add': 5,
             'int_ge': 1,
             'int_lt': 1,
             'jump': 1,
@@ -388,7 +388,7 @@
             'guard_false': 1,
             'guard_not_invalidated': 1,
             'guard_true': 2,
-            'int_add': 8,
+            'int_add': 6,
             'int_ge': 1,
             'int_lt': 2,
             'jump': 1,
@@ -521,7 +521,7 @@
             'float_add': 1,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'int_add': 6,
+            'int_add': 4,
             'int_ge': 1,
             'jump': 1,
             'raw_load': 2,
@@ -602,12 +602,12 @@
             'float_mul': 2,
             'getarrayitem_gc': 4,
             'getarrayitem_gc_pure': 9,
-            'getfield_gc_pure': 46,
+            'getfield_gc_pure': 49,
             'guard_class': 4,
-            'guard_false': 12,
+            'guard_false': 13,
             'guard_not_invalidated': 2,
-            'guard_true': 12,
-            'int_add': 18,
+            'guard_true': 14,
+            'int_add': 17,
             'int_ge': 4,
             'int_is_true': 3,
             'int_le': 5,
@@ -651,7 +651,7 @@
             'guard_false': 1,
             'guard_not_invalidated': 1,
             'guard_true': 1,
-            'int_add': 8,
+            'int_add': 5,
             'int_ge': 1,
             'jump': 1,
             'raw_load': 2,
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to