Author: Brian Kearns <bdkea...@gmail.com>
Branch: 
Changeset: r74812:e9c67f6fba33
Date: 2014-12-04 12:23 -0500
http://bitbucket.org/pypy/pypy/changeset/e9c67f6fba33/

Log:    avoid tracking an iterator index in axis_reduce

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -239,10 +239,9 @@
             state = x_state
     return out
 
-axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
-                                    greens=['shapelen',
-                                            'func', 'dtype'],
-                                    reds='auto')
+axis_reduce_driver = jit.JitDriver(name='numpy_axis_reduce',
+                                   greens=['shapelen', 'func', 'dtype'],
+                                   reds='auto')
 
 def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity, cumulative,
                    temp):
@@ -255,14 +254,16 @@
         temp_iter = out_iter  # hack
         temp_state = out_state
     arr_iter, arr_state = arr.create_iter()
+    arr_iter.track_index = False
     if identity is not None:
         identity = identity.convert_to(space, dtype)
     shapelen = len(shape)
     while not out_iter.done(out_state):
-        axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
-                                            dtype=dtype)
-        assert not arr_iter.done(arr_state)
+        axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                           dtype=dtype)
         w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        arr_state = arr_iter.next(arr_state)
+
         out_indices = out_iter.indices(out_state)
         if out_indices[axis] == 0:
             if identity is not None:
@@ -270,6 +271,7 @@
         else:
             cur = temp_iter.getitem(temp_state)
             w_val = func(dtype, cur, w_val)
+
         out_iter.setitem(out_state, w_val)
         out_state = out_iter.next(out_state)
         if cumulative:
@@ -277,7 +279,6 @@
             temp_state = temp_iter.next(temp_state)
         else:
             temp_state = out_state
-        arr_state = arr_iter.next(arr_state)
     return out
 
 
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -197,7 +197,7 @@
             'guard_false': 2,
             'guard_not_invalidated': 1,
             'guard_true': 1,
-            'int_add': 5,
+            'int_add': 4,
             'int_ge': 1,
             'int_is_zero': 1,
             'int_lt': 1,
@@ -212,13 +212,13 @@
             'getarrayitem_gc_pure': 7,
             'getfield_gc_pure': 56,
             'guard_class': 3,
-            'guard_false': 11,
-            'guard_nonnull': 8,
+            'guard_false': 12,
+            'guard_nonnull': 11,
             'guard_nonnull_class': 3,
             'guard_not_invalidated': 2,
-            'guard_true': 12,
-            'guard_value': 4,
-            'int_add': 17,
+            'guard_true': 10,
+            'guard_value': 5,
+            'int_add': 13,
             'int_ge': 4,
             'int_is_true': 4,
             'int_is_zero': 4,
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to