Author: mattip
Branch: numpypy-axisops
Changeset: r50845:3cfc0b93cb23
Date: 2011-12-25 00:10 +0200
http://bitbucket.org/pypy/pypy/changeset/3cfc0b93cb23/

Log:    test, implement improved AxisIterator

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -106,16 +106,24 @@
 # ------ other iterators that are not part of the computation frame ----------
 
 class AxisIterator(object):
-    """ This object will return offsets of each start of the last stride
+    """ This object will return offsets of each start of a stride on the 
+        desired dimension, starting at the desired index
     """
-    def __init__(self, arr):
+    def __init__(self, arr, dim=-1, start=[]):
         self.arr = arr
-        self.indices = [0] * (len(arr.shape) - 1)
+        self.indices = [0] * len(arr.shape)
         self.done = False
         self.offset = arr.start
-
+        self.dim = len(arr.shape) - 1
+        if dim >= 0:
+            self.dim = dim
+        if len(start) == len(arr.shape):
+            for i in range(len(start)):
+                self.offset += arr.strides[i] * start[i]
     def next(self):
-        for i in range(len(self.arr.shape) - 2, -1, -1):
+        for i in range(len(self.arr.shape) - 1, -1, -1):
+            if i == self.dim:
+                continue
             if self.indices[i] < self.arr.shape[i] - 1:
                 self.indices[i] += 1
                 self.offset += self.arr.strides[i]
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -0,0 +1,51 @@
+
+from pypy.module.micronumpy.interp_iter import AxisIterator
+from pypy.module.micronumpy.interp_numarray import W_NDimArray
+
+class MockDtype(object):
+    def malloc(self, size):
+        return None
+
+class TestAxisIteratorDirect(object):
+    def test_axis_iterator(self):
+        a = W_NDimArray(7*5*3, [7, 5, 3], MockDtype(), 'C')
+        i = AxisIterator(a)
+        ret = []
+        while not i.done:
+            ret.append(i.offset)
+            i.next()
+        assert ret == [3*v for v in range(7*5)]
+        i = AxisIterator(a,2)
+        ret = []
+        while not i.done:
+            ret.append(i.offset)
+            i.next()
+        assert ret == [3*v for v in range(7*5)]
+        i = AxisIterator(a,1)
+        ret = []
+        while not i.done:
+            ret.append(i.offset)
+            i.next()
+        assert ret == [ 0,  1,  2, 15, 16, 17, 30, 31, 32, 45, 46, 47,
+                       60, 61, 62, 75, 76, 77, 90, 91, 92] 
+    def test_axis_iterator_with_start(self):
+        a = W_NDimArray(7*5*3, [7, 5, 3], MockDtype(), 'C')
+        i = AxisIterator(a, start=[0, 0, 0])
+        ret = []
+        while not i.done:
+            ret.append(i.offset)
+            i.next()
+        assert ret == [3*v for v in range(7*5)]
+        i = AxisIterator(a, start=[1, 1, 0])
+        ret = []
+        while not i.done:
+            ret.append(i.offset)
+            i.next()
+        assert ret == [3*v+18 for v in range(7*5)]
+        i = AxisIterator(a, 1, [2, 0, 2])
+        ret = []
+        while not i.done:
+            ret.append(i.offset)
+            i.next()
+        assert ret == [v + 32 for v in [ 0,  1,  2, 15, 16, 17, 30, 31, 32,
+                            45, 46, 47, 60, 61, 62, 75, 76, 77, 90, 91, 92]]
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to