Author: mattip
Branch: numpypy-axisops
Changeset: r50845:3cfc0b93cb23
Date: 2011-12-25 00:10 +0200
http://bitbucket.org/pypy/pypy/changeset/3cfc0b93cb23/
Log: test, implement improved AxisIterator
diff --git a/pypy/module/micronumpy/interp_iter.py
b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -106,16 +106,24 @@
# ------ other iterators that are not part of the computation frame ----------
class AxisIterator(object):
- """ This object will return offsets of each start of the last stride
+ """ This object will return offsets of each start of a stride on the
+ desired dimension, starting at the desired index
"""
- def __init__(self, arr):
+ def __init__(self, arr, dim=-1, start=[]):
self.arr = arr
- self.indices = [0] * (len(arr.shape) - 1)
+ self.indices = [0] * len(arr.shape)
self.done = False
self.offset = arr.start
-
+ self.dim = len(arr.shape) - 1
+ if dim >= 0:
+ self.dim = dim
+ if len(start) == len(arr.shape):
+ for i in range(len(start)):
+ self.offset += arr.strides[i] * start[i]
def next(self):
- for i in range(len(self.arr.shape) - 2, -1, -1):
+ for i in range(len(self.arr.shape) - 1, -1, -1):
+ if i == self.dim:
+ continue
if self.indices[i] < self.arr.shape[i] - 1:
self.indices[i] += 1
self.offset += self.arr.strides[i]
diff --git a/pypy/module/micronumpy/test/test_iterators.py
b/pypy/module/micronumpy/test/test_iterators.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -0,0 +1,51 @@
+
+from pypy.module.micronumpy.interp_iter import AxisIterator
+from pypy.module.micronumpy.interp_numarray import W_NDimArray
+
+class MockDtype(object):
+ def malloc(self, size):
+ return None
+
+class TestAxisIteratorDirect(object):
+ def test_axis_iterator(self):
+ a = W_NDimArray(7*5*3, [7, 5, 3], MockDtype(), 'C')
+ i = AxisIterator(a)
+ ret = []
+ while not i.done:
+ ret.append(i.offset)
+ i.next()
+ assert ret == [3*v for v in range(7*5)]
+ i = AxisIterator(a,2)
+ ret = []
+ while not i.done:
+ ret.append(i.offset)
+ i.next()
+ assert ret == [3*v for v in range(7*5)]
+ i = AxisIterator(a,1)
+ ret = []
+ while not i.done:
+ ret.append(i.offset)
+ i.next()
+ assert ret == [ 0, 1, 2, 15, 16, 17, 30, 31, 32, 45, 46, 47,
+ 60, 61, 62, 75, 76, 77, 90, 91, 92]
+ def test_axis_iterator_with_start(self):
+ a = W_NDimArray(7*5*3, [7, 5, 3], MockDtype(), 'C')
+ i = AxisIterator(a, start=[0, 0, 0])
+ ret = []
+ while not i.done:
+ ret.append(i.offset)
+ i.next()
+ assert ret == [3*v for v in range(7*5)]
+ i = AxisIterator(a, start=[1, 1, 0])
+ ret = []
+ while not i.done:
+ ret.append(i.offset)
+ i.next()
+ assert ret == [3*v+18 for v in range(7*5)]
+ i = AxisIterator(a, 1, [2, 0, 2])
+ ret = []
+ while not i.done:
+ ret.append(i.offset)
+ i.next()
+ assert ret == [v + 32 for v in [ 0, 1, 2, 15, 16, 17, 30, 31, 32,
+ 45, 46, 47, 60, 61, 62, 75, 76, 77, 90, 91, 92]]
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit