Author: mattip
Branch: numpypy-axisops
Changeset: r50845:3cfc0b93cb23
Date: 2011-12-25 00:10 +0200
http://bitbucket.org/pypy/pypy/changeset/3cfc0b93cb23/
Log: test, implement improved AxisIterator diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py --- a/pypy/module/micronumpy/interp_iter.py +++ b/pypy/module/micronumpy/interp_iter.py @@ -106,16 +106,24 @@ # ------ other iterators that are not part of the computation frame ---------- class AxisIterator(object): - """ This object will return offsets of each start of the last stride + """ This object will return offsets of each start of a stride on the + desired dimension, starting at the desired index """ - def __init__(self, arr): + def __init__(self, arr, dim=-1, start=[]): self.arr = arr - self.indices = [0] * (len(arr.shape) - 1) + self.indices = [0] * len(arr.shape) self.done = False self.offset = arr.start - + self.dim = len(arr.shape) - 1 + if dim >= 0: + self.dim = dim + if len(start) == len(arr.shape): + for i in range(len(start)): + self.offset += arr.strides[i] * start[i] def next(self): - for i in range(len(self.arr.shape) - 2, -1, -1): + for i in range(len(self.arr.shape) - 1, -1, -1): + if i == self.dim: + continue if self.indices[i] < self.arr.shape[i] - 1: self.indices[i] += 1 self.offset += self.arr.strides[i] diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py new file mode 100644 --- /dev/null +++ b/pypy/module/micronumpy/test/test_iterators.py @@ -0,0 +1,51 @@ + +from pypy.module.micronumpy.interp_iter import AxisIterator +from pypy.module.micronumpy.interp_numarray import W_NDimArray + +class MockDtype(object): + def malloc(self, size): + return None + +class TestAxisIteratorDirect(object): + def test_axis_iterator(self): + a = W_NDimArray(7*5*3, [7, 5, 3], MockDtype(), 'C') + i = AxisIterator(a) + ret = [] + while not i.done: + ret.append(i.offset) + i.next() + assert ret == [3*v for v in range(7*5)] + i = AxisIterator(a,2) + ret = [] + while not i.done: + ret.append(i.offset) + i.next() + assert ret == [3*v for v in range(7*5)] + i = AxisIterator(a,1) 
+ ret = [] + while not i.done: + ret.append(i.offset) + i.next() + assert ret == [ 0, 1, 2, 15, 16, 17, 30, 31, 32, 45, 46, 47, + 60, 61, 62, 75, 76, 77, 90, 91, 92] + def test_axis_iterator_with_start(self): + a = W_NDimArray(7*5*3, [7, 5, 3], MockDtype(), 'C') + i = AxisIterator(a, start=[0, 0, 0]) + ret = [] + while not i.done: + ret.append(i.offset) + i.next() + assert ret == [3*v for v in range(7*5)] + i = AxisIterator(a, start=[1, 1, 0]) + ret = [] + while not i.done: + ret.append(i.offset) + i.next() + assert ret == [3*v+18 for v in range(7*5)] + i = AxisIterator(a, 1, [2, 0, 2]) + ret = [] + while not i.done: + ret.append(i.offset) + i.next() + assert ret == [v + 32 for v in [ 0, 1, 2, 15, 16, 17, 30, 31, 32, + 45, 46, 47, 60, 61, 62, 75, 76, 77, 90, 91, 92]] _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit