Author: Romain Guillebert <romain...@gmail.com> Branch: numpypy-nditer Changeset: r64813:8c3a4fc396d3 Date: 2013-06-06 17:21 +0200 http://bitbucket.org/pypy/pypy/changeset/8c3a4fc396d3/
Log: Refactor the way nditer iterates diff --git a/pypy/module/micronumpy/interp_nditer.py b/pypy/module/micronumpy/interp_nditer.py --- a/pypy/module/micronumpy/interp_nditer.py +++ b/pypy/module/micronumpy/interp_nditer.py @@ -5,10 +5,41 @@ from pypy.module.micronumpy.base import W_NDimArray, convert_to_array from pypy.module.micronumpy.strides import (calculate_broadcast_strides, shape_agreement_multiple) -from pypy.module.micronumpy.iter import MultiDimViewIterator +from pypy.module.micronumpy.iter import MultiDimViewIterator, SliceIterator from pypy.module.micronumpy import support from pypy.module.micronumpy.arrayimpl.concrete import SliceArray +class AbstractIterator(object): + def done(self): + raise NotImplementedError("Abstract Class") + + def next(self): + raise NotImplementedError("Abstract Class") + + def getitem(self, space, array): + raise NotImplementedError("Abstract Class") + +class IteratorMixin(object): + _mixin_ = True + def __init__(self, it, op_flags): + self.it = it + self.op_flags = op_flags + + def done(self): + return self.it.done() + + def next(self): + self.it.next() + + def getitem(self, space, array): + return self.op_flags.get_it_item(space, array, self.it) + +class BoxIterator(IteratorMixin, AbstractIterator): + pass + +class SliceIteratorWrapper(IteratorMixin, AbstractIterator): + pass + def parse_op_arg(space, name, w_op_flags, n, parse_one_arg): ret = [] if space.is_w(w_op_flags, space.w_None): @@ -53,6 +84,13 @@ #it.dtype.setitem(res, 0, it.getitem()) return W_NDimArray(res) +def get_readonly_slice(space, array, it): + #XXX Not readonly + return W_NDimArray(it.getslice()) + +def get_readwrite_slice(space, array, it): + return W_NDimArray(it.getslice()) + def parse_op_flag(space, lst): op_flag = OpFlag() for w_item in lst: @@ -191,11 +229,11 @@ self.iters=[] self.shape = iter_shape = shape_agreement_multiple(space, self.seq) if self.external_loop: - xxx find longest contiguous shape + #XXX find longest contiguous shape iter_shape = iter_shape[1:] for i in range(len(self.seq)): - 
self.iters.append(get_iter(space, self.order, - self.seq[i].implementation, iter_shape)) + self.iters.append(BoxIterator(get_iter(space, self.order, + self.seq[i].implementation, iter_shape), self.op_flags[i])) def descr_iter(self, space): return space.wrap(self) @@ -220,8 +258,7 @@ raise OperationError(space.w_StopIteration, space.w_None) res = [] for i in range(len(self.iters)): - res.append(self.op_flags[i].get_it_item(space, self.seq[i], - self.iters[i])) + res.append(self.iters[i].getitem(space, self.seq[i])) self.iters[i].next() if len(res) <2: return res[0] diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py --- a/pypy/module/micronumpy/iter.py +++ b/pypy/module/micronumpy/iter.py @@ -32,13 +32,13 @@ shape dimension which is back 25 and forward 1, which is x.strides[1] * (x.shape[1] - 1) + x.strides[0] -so if we precalculate the overflow backstride as +so if we precalculate the overflow backstride as [x.strides[i] * (x.shape[i] - 1) for i in range(len(x.shape))] we can go faster. All the calculations happen in next() next_skip_x() tries to do the iteration for a number of steps at once, -but then we cannot gaurentee that we only overflow one single shape +but then we cannot gaurentee that we only overflow one single shape dimension, perhaps we could overflow times in one big step. 
""" @@ -266,6 +266,30 @@ def reset(self): self.offset %= self.size +class SliceIterator(object): + def __init__(self, arr, stride, backstride, shape, dtype=None): + self.step = 0 + self.arr = arr + self.stride = stride + self.backstride = backstride + self.shape = shape + if dtype is None: + dtype = arr.implementation.dtype + self.dtype = dtype + self._done = False + + def done(self): + return self._done + + def next(self): + self.step += self.arr.implementation.dtype.get_size() + if self.step == self.backstride - self.arr.implementation.dtype.get_size(): + self._done = True + + def getslice(self): + from pypy.module.micronumpy.arrayimpl.concrete import SliceArray + return SliceArray(self.step, [self.stride], [self.backstride], self.shape, self.arr.implementation, self.arr, self.dtype) + class AxisIterator(base.BaseArrayIterator): def __init__(self, array, shape, dim, cumultative): self.shape = shape @@ -288,7 +312,7 @@ self.dim = dim self.array = array self.dtype = array.dtype - + def setitem(self, elem): self.dtype.setitem(self.array, self.offset, elem) _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit