Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r69521:aab294de242a Date: 2014-02-27 11:29 -0500 http://bitbucket.org/pypy/pypy/changeset/aab294de242a/
Log: kill duplicate code in AxisIterator, add AllButAxisIterator diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py --- a/pypy/module/micronumpy/concrete.py +++ b/pypy/module/micronumpy/concrete.py @@ -283,9 +283,10 @@ self.get_backstrides(), self.get_shape(), shape, backward_broadcast) - return iter.ArrayIterator(self, shape, r[0], r[1]) - return iter.ArrayIterator(self, self.shape, self.strides, - self.backstrides) + return iter.ArrayIterator(self, support.product(shape), shape, + r[0], r[1]) + return iter.ArrayIterator(self, self.get_size(), self.shape, + self.strides, self.backstrides) def create_axis_iter(self, shape, dim, cum): return iter.AxisIterator(self, shape, dim, cum) @@ -293,7 +294,8 @@ def create_dot_iter(self, shape, skip): r = calculate_dot_strides(self.get_strides(), self.get_backstrides(), shape, skip) - return iter.ArrayIterator(self, shape, r[0], r[1]) + return iter.ArrayIterator(self, support.product(shape), shape, + r[0], r[1]) def swapaxes(self, space, orig_arr, axis1, axis2): shape = self.get_shape()[:] diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py --- a/pypy/module/micronumpy/iter.py +++ b/pypy/module/micronumpy/iter.py @@ -82,10 +82,11 @@ _immutable_fields_ = ['array', 'start', 'size', 'ndim_m1', 'shape_m1', 'strides', 'backstrides'] - def __init__(self, array, shape, strides, backstrides): + def __init__(self, array, size, shape, strides, backstrides): + assert len(shape) == len(strides) == len(backstrides) self.array = array self.start = array.start - self.size = support.product(shape) + self.size = size self.ndim_m1 = len(shape) - 1 self.shape_m1 = [s - 1 for s in shape] self.strides = strides @@ -141,44 +142,25 @@ self.array.setitem(self.offset, elem) -class AxisIterator(ArrayIterator): - def __init__(self, array, shape, dim, cumulative): - self.shape = shape - strides = array.get_strides() - backstrides = array.get_backstrides() - if cumulative: - self.strides = strides - self.backstrides = backstrides - elif len(shape) == len(strides): +def AxisIterator(array, shape, axis, cumulative): + strides = array.get_strides() + backstrides = array.get_backstrides() + if not cumulative: + if len(shape) == len(strides): # keepdims = True - self.strides = strides[:dim] + [0] + strides[dim + 1:] - self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:] + strides = strides[:axis] + [0] + strides[axis + 1:] + backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:] else: - self.strides = strides[:dim] + [0] + strides[dim:] - self.backstrides = backstrides[:dim] + [0] + backstrides[dim:] - self.first_line = True - self.indices = [0] * len(shape) - self._done = array.get_size() == 0 - self.offset = array.start - self.dim = dim - self.array = array + strides = strides[:axis] + [0] + strides[axis:] + backstrides = backstrides[:axis] + [0] + backstrides[axis:] + return ArrayIterator(array, support.product(shape), shape, strides, backstrides) - @jit.unroll_safe - def next(self): - for i in range(len(self.shape) - 1, -1, -1): - if self.indices[i] < self.shape[i] - 1: - if i == self.dim: - self.first_line = False - self.indices[i] += 1 - self.offset += self.strides[i] - break - else: - if i == self.dim: - self.first_line = True - self.indices[i] = 0 - self.offset -= self.backstrides[i] - else: - self._done = True - def done(self): - return self._done +def AllButAxisIterator(array, axis): + size = array.get_size() + shape = array.get_shape()[:] + backstrides = array.backstrides[:] + if size: + size /= shape[axis] + shape[axis] = backstrides[axis] = 0 + return ArrayIterator(array, size, shape, array.strides, backstrides) diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -215,16 +215,14 @@ while not out_iter.done(): axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func, dtype=dtype) - if arr_iter.done(): - w_val = identity + assert not arr_iter.done() + w_val = arr_iter.getitem().convert_to(space, dtype) + if out_iter.indices[axis] == 0: + if identity is not None: + w_val = func(dtype, identity, w_val) else: - w_val = arr_iter.getitem().convert_to(space, dtype) - if out_iter.first_line: - if identity is not None: - w_val = func(dtype, identity, w_val) - else: - cur = temp_iter.getitem() - w_val = func(dtype, cur, w_val) + cur = temp_iter.getitem() + w_val = func(dtype, cur, w_val) out_iter.setitem(w_val) if cumulative: temp_iter.setitem(w_val) diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py --- a/pypy/module/micronumpy/sort.py +++ b/pypy/module/micronumpy/sort.py @@ -11,7 +11,7 @@ from rpython.rtyper.lltypesystem import rffi, lltype from pypy.module.micronumpy import descriptor, types, constants as NPY from pypy.module.micronumpy.base import W_NDimArray -from pypy.module.micronumpy.iter import AxisIterator +from pypy.module.micronumpy.iter import AllButAxisIterator INT_SIZE = rffi.sizeof(lltype.Signed) @@ -146,21 +146,20 @@ if axis < 0 or axis >= len(shape): raise OperationError(space.w_IndexError, space.wrap( "Wrong axis %d" % axis)) - iterable_shape = shape[:axis] + [0] + shape[axis + 1:] - iter = AxisIterator(arr, iterable_shape, axis, False) + arr_iter = AllButAxisIterator(arr, axis) index_impl = index_arr.implementation - index_iter = AxisIterator(index_impl, iterable_shape, axis, False) + index_iter = AllButAxisIterator(index_impl, axis) stride_size = arr.strides[axis] index_stride_size = index_impl.strides[axis] axis_size = arr.shape[axis] - while not iter.done(): + while not arr_iter.done(): for i in range(axis_size): raw_storage_setitem(storage, i * index_stride_size + index_iter.offset, i) r = Repr(index_stride_size, stride_size, axis_size, - arr.get_storage(), storage, index_iter.offset, iter.offset) + arr.get_storage(), storage, index_iter.offset, arr_iter.offset) ArgSort(r).sort() - iter.next() + arr_iter.next() index_iter.next() return index_arr @@ -292,14 +291,13 @@ if axis < 0 or axis >= len(shape): raise OperationError(space.w_IndexError, space.wrap( "Wrong axis %d" % axis)) - iterable_shape = shape[:axis] + [0] + shape[axis + 1:] - iter = AxisIterator(arr, iterable_shape, axis, False) + arr_iter = AllButAxisIterator(arr, axis) stride_size = arr.strides[axis] axis_size = arr.shape[axis] - while not iter.done(): - r = Repr(stride_size, axis_size, arr.get_storage(), iter.offset) + while not arr_iter.done(): + r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset) ArgSort(r).sort() - iter.next() + arr_iter.next() return sort diff --git a/pypy/module/micronumpy/test/test_iter.py b/pypy/module/micronumpy/test/test_iter.py --- a/pypy/module/micronumpy/test/test_iter.py +++ b/pypy/module/micronumpy/test/test_iter.py @@ -1,8 +1,8 @@ +from pypy.module.micronumpy import support from pypy.module.micronumpy.iter import ArrayIterator class MockArray(object): - size = 1 start = 0 @@ -14,7 +14,8 @@ strides = [5, 1] backstrides = [x * (y - 1) for x,y in zip(strides, shape)] assert backstrides == [10, 4] - i = ArrayIterator(MockArray, shape, strides, backstrides) + i = ArrayIterator(MockArray, support.product(shape), shape, + strides, backstrides) i.next() i.next() i.next() @@ -32,7 +33,8 @@ strides = [1, 3] backstrides = [x * (y - 1) for x,y in zip(strides, shape)] assert backstrides == [2, 12] - i = ArrayIterator(MockArray, shape, strides, backstrides) + i = ArrayIterator(MockArray, support.product(shape), shape, + strides, backstrides) i.next() i.next() i.next() @@ -52,7 +54,8 @@ strides = [5, 1] backstrides = [x * (y - 1) for x,y in zip(strides, shape)] assert backstrides == [10, 4] - i = ArrayIterator(MockArray, shape, strides, backstrides) + i = ArrayIterator(MockArray, support.product(shape), shape, + strides, backstrides) i.next_skip_x(2) i.next_skip_x(2) i.next_skip_x(2) @@ -75,7 +78,8 @@ strides = [1, 3] backstrides = [x * (y - 1) for x,y in zip(strides, shape)] assert backstrides == [2, 12] - i = ArrayIterator(MockArray, shape, strides, backstrides) + i = ArrayIterator(MockArray, support.product(shape), shape, + strides, backstrides) i.next_skip_x(2) i.next_skip_x(2) i.next_skip_x(2) diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -237,6 +237,10 @@ dtype = out.get_dtype() else: out = W_NDimArray.from_shape(space, shape, dtype, w_instance=obj) + if obj.get_size() == 0: + if self.identity is not None: + out.fill(space, self.identity.convert_to(space, dtype)) + return out return loop.do_axis_reduce(space, shape, self.func, obj, dtype, axis, out, self.identity, cumulative, temp) if cumulative: _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit