Author: Brian Kearns <bdkea...@gmail.com>
Branch: numpy-refactor
Changeset: r69510:b2e159740743
Date: 2014-02-27 03:09 -0500
http://bitbucket.org/pypy/pypy/changeset/b2e159740743/

Log:    use the new iterator where possible

diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -274,6 +274,17 @@
                              backstrides)
         return loop.setslice(space, self.get_shape(), impl, self)
 
+    def create_iter(self, shape=None, backward_broadcast=False):
+        if shape is not None and \
+                support.product(shape) > support.product(self.get_shape()):
+            r = calculate_broadcast_strides(self.get_strides(),
+                                            self.get_backstrides(),
+                                            self.get_shape(), shape,
+                                            backward_broadcast)
+            return iter.MultiDimViewIterator(self, self.start,
+                                             r[0], r[1], shape)
+        return iter.ArrayIterator(self)
+
     def create_axis_iter(self, shape, dim, cum):
         return iter.AxisIterator(self, shape, dim, cum)
 
@@ -333,26 +344,6 @@
         self.backstrides = backstrides
         self.storage = storage
 
-    def create_iter(self, shape=None, backward_broadcast=False, require_index=False):
-        if shape is not None and \
-                support.product(shape) > support.product(self.get_shape()):
-            r = calculate_broadcast_strides(self.get_strides(),
-                                            self.get_backstrides(),
-                                            self.get_shape(), shape,
-                                            backward_broadcast)
-            return iter.MultiDimViewIterator(self, self.start,
-                                             r[0], r[1], shape)
-        if not require_index:
-            return iter.ConcreteArrayIterator(self)
-        if len(self.get_shape()) <= 1:
-            return iter.OneDimViewIterator(self, self.start,
-                                           self.get_strides(),
-                                           self.get_shape())
-        return iter.MultiDimViewIterator(self, self.start,
-                                         self.get_strides(),
-                                         self.get_backstrides(),
-                                         self.get_shape())
-
     def fill(self, space, box):
         self.dtype.itemtype.fill(self.storage, self.dtype.elsize,
                                  box, 0, self.size, 0)
@@ -438,24 +429,6 @@
     def fill(self, space, box):
         loop.fill(self, box.convert_to(space, self.dtype))
 
-    def create_iter(self, shape=None, backward_broadcast=False, require_index=False):
-        if shape is not None and \
-                support.product(shape) > support.product(self.get_shape()):
-            r = calculate_broadcast_strides(self.get_strides(),
-                                            self.get_backstrides(),
-                                            self.get_shape(), shape,
-                                            backward_broadcast)
-            return iter.MultiDimViewIterator(self, self.start,
-                                             r[0], r[1], shape)
-        if len(self.get_shape()) <= 1:
-            return iter.OneDimViewIterator(self, self.start,
-                                           self.get_strides(),
-                                           self.get_shape())
-        return iter.MultiDimViewIterator(self, self.start,
-                                         self.get_strides(),
-                                         self.get_backstrides(),
-                                         self.get_shape())
-
     def set_shape(self, space, orig_array, new_shape):
         if len(self.get_shape()) < 2 or self.size == 0:
             # TODO: this code could be refactored into calc_strides
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -19,7 +19,7 @@
     def get_shape(self):
         return self.shape
 
-    def create_iter(self, shape=None, backward_broadcast=False, require_index=False):
+    def create_iter(self, shape=None, backward_broadcast=False):
         assert isinstance(self.base(), W_NDimArray)
         return self.base().create_iter()
 
@@ -33,7 +33,6 @@
 
     def reset(self):
         self.iter = self.base.create_iter()
-        self.index = 0
 
     def descr_len(self, space):
         return space.wrap(self.base.get_size())
@@ -43,14 +42,13 @@
             raise OperationError(space.w_StopIteration, space.w_None)
         w_res = self.iter.getitem()
         self.iter.next()
-        self.index += 1
         return w_res
 
     def descr_index(self, space):
-        return space.wrap(self.index)
+        return space.wrap(self.iter.index)
 
     def descr_coords(self, space):
-        coords = self.base.to_coords(space, space.wrap(self.index))
+        coords = self.base.to_coords(space, space.wrap(self.iter.index))
         return space.newtuple([space.wrap(c) for c in coords])
 
     def descr_getitem(self, space, w_idx):
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -37,10 +37,9 @@
 All the calculations happen in next()
 
 next_skip_x(steps) tries to do the iteration for a number of steps at once,
-but then we cannot gaurentee that we only overflow one single shape
+but then we cannot guarantee that we only overflow one single shape
 dimension, perhaps we could overflow times in one big step.
 """
-
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy import support
 from rpython.rlib import jit
@@ -107,6 +106,11 @@
                 self.indices[i] = 0
                 self.offset -= self.backstrides[i]
 
+    def next_skip_x(self, step):
+        # XXX implement
+        for _ in range(step):
+            self.next()
+
     def done(self):
         return self.index >= self.size
 
@@ -120,70 +124,7 @@
         self.array.setitem(self.offset, elem)
 
 
-class ConcreteArrayIterator(ArrayIterator):
-    _immutable_fields_ = ['array', 'skip', 'size']
-
-    def __init__(self, array):
-        self.array = array
-        self.offset = 0
-        self.skip = array.dtype.elsize
-        self.size = array.size
-
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
-
-    def getitem(self):
-        return self.array.getitem(self.offset)
-
-    def getitem_bool(self):
-        return self.array.getitem_bool(self.offset)
-
-    def next(self):
-        self.offset += self.skip
-
-    def next_skip_x(self, x):
-        self.offset += self.skip * x
-
-    def done(self):
-        return self.offset >= self.size
-
-    def reset(self):
-        self.offset %= self.size
-
-
-class OneDimViewIterator(ConcreteArrayIterator):
-    def __init__(self, array, start, strides, shape):
-        self.array = array
-        self.offset = start
-        self.index = 0
-        assert len(strides) == len(shape)
-        if len(shape) == 0:
-            self.skip = array.dtype.elsize
-            self.size = 1
-        else:
-            assert len(shape) == 1
-            self.skip = strides[0]
-            self.size = shape[0]
-
-    def next(self):
-        self.offset += self.skip
-        self.index += 1
-
-    def next_skip_x(self, x):
-        self.offset += self.skip * x
-        self.index += x
-
-    def done(self):
-        return self.index >= self.size
-
-    def reset(self):
-        self.offset %= self.size
-
-    def get_index(self, d):
-        return self.index
-
-
-class MultiDimViewIterator(ConcreteArrayIterator):
+class MultiDimViewIterator(ArrayIterator):
     def __init__(self, array, start, strides, backstrides, shape):
         self.indexes = [0] * len(shape)
         self.array = array
@@ -232,9 +173,6 @@
     def reset(self):
         self.offset %= self.size
 
-    def get_index(self, d):
-        return self.indexes[d]
-
 
 class AxisIterator(ArrayIterator):
     def __init__(self, array, shape, dim, cumulative):
@@ -258,12 +196,6 @@
         self.dim = dim
         self.array = array
 
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
-
-    def getitem(self):
-        return self.array.getitem(self.offset)
-
     @jit.unroll_safe
     def next(self):
         for i in range(len(self.shape) - 1, -1, -1):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -332,7 +332,7 @@
 
 def nonzero(res, arr, box):
     res_iter = res.create_iter()
-    arr_iter = arr.create_iter(require_index=True)
+    arr_iter = arr.create_iter()
     shapelen = len(arr.shape)
     dtype = arr.dtype
     dims = range(shapelen)
@@ -340,7 +340,7 @@
         nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
         if arr_iter.getitem_bool():
             for d in dims:
-                res_iter.setitem(box(arr_iter.get_index(d)))
+                res_iter.setitem(box(arr_iter.indices[d]))
                 res_iter.next()
         arr_iter.next()
     return res
@@ -436,8 +436,6 @@
         arr_iter.next_skip_x(step)
         length -= 1
         val_iter.next()
-        # WTF numpy?
-        val_iter.reset()
 
 fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
                                   greens = ['itemsize', 'dtype'],
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -280,11 +280,10 @@
         s.append(suffix)
         return s.build()
 
-    def create_iter(self, shape=None, backward_broadcast=False, require_index=False):
+    def create_iter(self, shape=None, backward_broadcast=False):
         assert isinstance(self.implementation, BaseConcreteArray)
         return self.implementation.create_iter(
-            shape=shape, backward_broadcast=backward_broadcast,
-            require_index=require_index)
+            shape=shape, backward_broadcast=backward_broadcast)
 
     def create_axis_iter(self, shape, dim, cum):
         return self.implementation.create_axis_iter(shape, dim, cum)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to