Author: Brian Kearns <bdkea...@gmail.com>
Branch: 
Changeset: r69521:aab294de242a
Date: 2014-02-27 11:29 -0500
http://bitbucket.org/pypy/pypy/changeset/aab294de242a/

Log:    kill duplicate code in AxisIterator, add AllButAxisIterator

diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -283,9 +283,10 @@
                                             self.get_backstrides(),
                                             self.get_shape(), shape,
                                             backward_broadcast)
-            return iter.ArrayIterator(self, shape, r[0], r[1])
-        return iter.ArrayIterator(self, self.shape, self.strides,
-                                                    self.backstrides)
+            return iter.ArrayIterator(self, support.product(shape), shape,
+                                      r[0], r[1])
+        return iter.ArrayIterator(self, self.get_size(), self.shape,
+                                  self.strides, self.backstrides)
 
     def create_axis_iter(self, shape, dim, cum):
         return iter.AxisIterator(self, shape, dim, cum)
@@ -293,7 +294,8 @@
     def create_dot_iter(self, shape, skip):
         r = calculate_dot_strides(self.get_strides(), self.get_backstrides(),
                                   shape, skip)
-        return iter.ArrayIterator(self, shape, r[0], r[1])
+        return iter.ArrayIterator(self, support.product(shape), shape,
+                                  r[0], r[1])
 
     def swapaxes(self, space, orig_arr, axis1, axis2):
         shape = self.get_shape()[:]
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -82,10 +82,11 @@
     _immutable_fields_ = ['array', 'start', 'size', 'ndim_m1', 'shape_m1',
                           'strides', 'backstrides']
 
-    def __init__(self, array, shape, strides, backstrides):
+    def __init__(self, array, size, shape, strides, backstrides):
+        assert len(shape) == len(strides) == len(backstrides)
         self.array = array
         self.start = array.start
-        self.size = support.product(shape)
+        self.size = size
         self.ndim_m1 = len(shape) - 1
         self.shape_m1 = [s - 1 for s in shape]
         self.strides = strides
@@ -141,44 +142,25 @@
         self.array.setitem(self.offset, elem)
 
 
-class AxisIterator(ArrayIterator):
-    def __init__(self, array, shape, dim, cumulative):
-        self.shape = shape
-        strides = array.get_strides()
-        backstrides = array.get_backstrides()
-        if cumulative:
-            self.strides = strides
-            self.backstrides = backstrides
-        elif len(shape) == len(strides):
+def AxisIterator(array, shape, axis, cumulative):
+    strides = array.get_strides()
+    backstrides = array.get_backstrides()
+    if not cumulative:
+        if len(shape) == len(strides):
             # keepdims = True
-            self.strides = strides[:dim] + [0] + strides[dim + 1:]
-            self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
+            strides = strides[:axis] + [0] + strides[axis + 1:]
+            backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:]
         else:
-            self.strides = strides[:dim] + [0] + strides[dim:]
-            self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
-        self.first_line = True
-        self.indices = [0] * len(shape)
-        self._done = array.get_size() == 0
-        self.offset = array.start
-        self.dim = dim
-        self.array = array
+            strides = strides[:axis] + [0] + strides[axis:]
+            backstrides = backstrides[:axis] + [0] + backstrides[axis:]
+    return ArrayIterator(array, support.product(shape), shape, strides, backstrides)
 
-    @jit.unroll_safe
-    def next(self):
-        for i in range(len(self.shape) - 1, -1, -1):
-            if self.indices[i] < self.shape[i] - 1:
-                if i == self.dim:
-                    self.first_line = False
-                self.indices[i] += 1
-                self.offset += self.strides[i]
-                break
-            else:
-                if i == self.dim:
-                    self.first_line = True
-                self.indices[i] = 0
-                self.offset -= self.backstrides[i]
-        else:
-            self._done = True
 
-    def done(self):
-        return self._done
+def AllButAxisIterator(array, axis):
+    size = array.get_size()
+    shape = array.get_shape()[:]
+    backstrides = array.backstrides[:]
+    if size:
+        size /= shape[axis]
+    shape[axis] = backstrides[axis] = 0
+    return ArrayIterator(array, size, shape, array.strides, backstrides)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -215,16 +215,14 @@
     while not out_iter.done():
         axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
                                             dtype=dtype)
-        if arr_iter.done():
-            w_val = identity
+        assert not arr_iter.done()
+        w_val = arr_iter.getitem().convert_to(space, dtype)
+        if out_iter.indices[axis] == 0:
+            if identity is not None:
+                w_val = func(dtype, identity, w_val)
         else:
-            w_val = arr_iter.getitem().convert_to(space, dtype)
-            if out_iter.first_line:
-                if identity is not None:
-                    w_val = func(dtype, identity, w_val)
-            else:
-                cur = temp_iter.getitem()
-                w_val = func(dtype, cur, w_val)
+            cur = temp_iter.getitem()
+            w_val = func(dtype, cur, w_val)
         out_iter.setitem(w_val)
         if cumulative:
             temp_iter.setitem(w_val)
diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py
--- a/pypy/module/micronumpy/sort.py
+++ b/pypy/module/micronumpy/sort.py
@@ -11,7 +11,7 @@
 from rpython.rtyper.lltypesystem import rffi, lltype
 from pypy.module.micronumpy import descriptor, types, constants as NPY
 from pypy.module.micronumpy.base import W_NDimArray
-from pypy.module.micronumpy.iter import AxisIterator
+from pypy.module.micronumpy.iter import AllButAxisIterator
 
 INT_SIZE = rffi.sizeof(lltype.Signed)
 
@@ -146,21 +146,20 @@
             if axis < 0 or axis >= len(shape):
                 raise OperationError(space.w_IndexError, space.wrap(
                                                     "Wrong axis %d" % axis))
-            iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
-            iter = AxisIterator(arr, iterable_shape, axis, False)
+            arr_iter = AllButAxisIterator(arr, axis)
             index_impl = index_arr.implementation
-            index_iter = AxisIterator(index_impl, iterable_shape, axis, False)
+            index_iter = AllButAxisIterator(index_impl, axis)
             stride_size = arr.strides[axis]
             index_stride_size = index_impl.strides[axis]
             axis_size = arr.shape[axis]
-            while not iter.done():
+            while not arr_iter.done():
                 for i in range(axis_size):
                     raw_storage_setitem(storage, i * index_stride_size +
                                         index_iter.offset, i)
                 r = Repr(index_stride_size, stride_size, axis_size,
-                         arr.get_storage(), storage, index_iter.offset, iter.offset)
+                         arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
                 ArgSort(r).sort()
-                iter.next()
+                arr_iter.next()
                 index_iter.next()
         return index_arr
 
@@ -292,14 +291,13 @@
             if axis < 0 or axis >= len(shape):
                 raise OperationError(space.w_IndexError, space.wrap(
                                                     "Wrong axis %d" % axis))
-            iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
-            iter = AxisIterator(arr, iterable_shape, axis, False)
+            arr_iter = AllButAxisIterator(arr, axis)
             stride_size = arr.strides[axis]
             axis_size = arr.shape[axis]
-            while not iter.done():
-                r = Repr(stride_size, axis_size, arr.get_storage(), iter.offset)
+            while not arr_iter.done():
+                r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset)
                 ArgSort(r).sort()
-                iter.next()
+                arr_iter.next()
 
     return sort
 
diff --git a/pypy/module/micronumpy/test/test_iter.py b/pypy/module/micronumpy/test/test_iter.py
--- a/pypy/module/micronumpy/test/test_iter.py
+++ b/pypy/module/micronumpy/test/test_iter.py
@@ -1,8 +1,8 @@
+from pypy.module.micronumpy import support
 from pypy.module.micronumpy.iter import ArrayIterator
 
 
 class MockArray(object):
-    size = 1
     start = 0
 
 
@@ -14,7 +14,8 @@
         strides = [5, 1]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [10, 4]
-        i = ArrayIterator(MockArray, shape, strides, backstrides)
+        i = ArrayIterator(MockArray, support.product(shape), shape,
+                          strides, backstrides)
         i.next()
         i.next()
         i.next()
@@ -32,7 +33,8 @@
         strides = [1, 3]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [2, 12]
-        i = ArrayIterator(MockArray, shape, strides, backstrides)
+        i = ArrayIterator(MockArray, support.product(shape), shape,
+                          strides, backstrides)
         i.next()
         i.next()
         i.next()
@@ -52,7 +54,8 @@
         strides = [5, 1]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [10, 4]
-        i = ArrayIterator(MockArray, shape, strides, backstrides)
+        i = ArrayIterator(MockArray, support.product(shape), shape,
+                          strides, backstrides)
         i.next_skip_x(2)
         i.next_skip_x(2)
         i.next_skip_x(2)
@@ -75,7 +78,8 @@
         strides = [1, 3]
         backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
         assert backstrides == [2, 12]
-        i = ArrayIterator(MockArray, shape, strides, backstrides)
+        i = ArrayIterator(MockArray, support.product(shape), shape,
+                          strides, backstrides)
         i.next_skip_x(2)
         i.next_skip_x(2)
         i.next_skip_x(2)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -237,6 +237,10 @@
                 dtype = out.get_dtype()
             else:
                 out = W_NDimArray.from_shape(space, shape, dtype, w_instance=obj)
+            if obj.get_size() == 0:
+                if self.identity is not None:
+                    out.fill(space, self.identity.convert_to(space, dtype))
+                return out
             return loop.do_axis_reduce(space, shape, self.func, obj, dtype, axis, out,
                                        self.identity, cumulative, temp)
         if cumulative:
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to