Author: Brian Kearns <[email protected]>
Branch:
Changeset: r69521:aab294de242a
Date: 2014-02-27 11:29 -0500
http://bitbucket.org/pypy/pypy/changeset/aab294de242a/
Log: kill duplicate code in AxisIterator, add AllButAxisIterator
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -283,9 +283,10 @@
self.get_backstrides(),
self.get_shape(), shape,
backward_broadcast)
- return iter.ArrayIterator(self, shape, r[0], r[1])
- return iter.ArrayIterator(self, self.shape, self.strides,
- self.backstrides)
+ return iter.ArrayIterator(self, support.product(shape), shape,
+ r[0], r[1])
+ return iter.ArrayIterator(self, self.get_size(), self.shape,
+ self.strides, self.backstrides)
def create_axis_iter(self, shape, dim, cum):
return iter.AxisIterator(self, shape, dim, cum)
@@ -293,7 +294,8 @@
def create_dot_iter(self, shape, skip):
r = calculate_dot_strides(self.get_strides(), self.get_backstrides(),
shape, skip)
- return iter.ArrayIterator(self, shape, r[0], r[1])
+ return iter.ArrayIterator(self, support.product(shape), shape,
+ r[0], r[1])
def swapaxes(self, space, orig_arr, axis1, axis2):
shape = self.get_shape()[:]
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -82,10 +82,11 @@
_immutable_fields_ = ['array', 'start', 'size', 'ndim_m1', 'shape_m1',
'strides', 'backstrides']
- def __init__(self, array, shape, strides, backstrides):
+ def __init__(self, array, size, shape, strides, backstrides):
+ assert len(shape) == len(strides) == len(backstrides)
self.array = array
self.start = array.start
- self.size = support.product(shape)
+ self.size = size
self.ndim_m1 = len(shape) - 1
self.shape_m1 = [s - 1 for s in shape]
self.strides = strides
@@ -141,44 +142,25 @@
self.array.setitem(self.offset, elem)
-class AxisIterator(ArrayIterator):
- def __init__(self, array, shape, dim, cumulative):
- self.shape = shape
- strides = array.get_strides()
- backstrides = array.get_backstrides()
- if cumulative:
- self.strides = strides
- self.backstrides = backstrides
- elif len(shape) == len(strides):
+def AxisIterator(array, shape, axis, cumulative):
+ strides = array.get_strides()
+ backstrides = array.get_backstrides()
+ if not cumulative:
+ if len(shape) == len(strides):
# keepdims = True
- self.strides = strides[:dim] + [0] + strides[dim + 1:]
- self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
+ strides = strides[:axis] + [0] + strides[axis + 1:]
+ backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:]
else:
- self.strides = strides[:dim] + [0] + strides[dim:]
- self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
- self.first_line = True
- self.indices = [0] * len(shape)
- self._done = array.get_size() == 0
- self.offset = array.start
- self.dim = dim
- self.array = array
+ strides = strides[:axis] + [0] + strides[axis:]
+ backstrides = backstrides[:axis] + [0] + backstrides[axis:]
+ return ArrayIterator(array, support.product(shape), shape, strides,
backstrides)
- @jit.unroll_safe
- def next(self):
- for i in range(len(self.shape) - 1, -1, -1):
- if self.indices[i] < self.shape[i] - 1:
- if i == self.dim:
- self.first_line = False
- self.indices[i] += 1
- self.offset += self.strides[i]
- break
- else:
- if i == self.dim:
- self.first_line = True
- self.indices[i] = 0
- self.offset -= self.backstrides[i]
- else:
- self._done = True
- def done(self):
- return self._done
+def AllButAxisIterator(array, axis):
+ size = array.get_size()
+ shape = array.get_shape()[:]
+ backstrides = array.backstrides[:]
+ if size:
+ size /= shape[axis]
+ shape[axis] = backstrides[axis] = 0
+ return ArrayIterator(array, size, shape, array.strides, backstrides)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -215,16 +215,14 @@
while not out_iter.done():
axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
dtype=dtype)
- if arr_iter.done():
- w_val = identity
+ assert not arr_iter.done()
+ w_val = arr_iter.getitem().convert_to(space, dtype)
+ if out_iter.indices[axis] == 0:
+ if identity is not None:
+ w_val = func(dtype, identity, w_val)
else:
- w_val = arr_iter.getitem().convert_to(space, dtype)
- if out_iter.first_line:
- if identity is not None:
- w_val = func(dtype, identity, w_val)
- else:
- cur = temp_iter.getitem()
- w_val = func(dtype, cur, w_val)
+ cur = temp_iter.getitem()
+ w_val = func(dtype, cur, w_val)
out_iter.setitem(w_val)
if cumulative:
temp_iter.setitem(w_val)
diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py
--- a/pypy/module/micronumpy/sort.py
+++ b/pypy/module/micronumpy/sort.py
@@ -11,7 +11,7 @@
from rpython.rtyper.lltypesystem import rffi, lltype
from pypy.module.micronumpy import descriptor, types, constants as NPY
from pypy.module.micronumpy.base import W_NDimArray
-from pypy.module.micronumpy.iter import AxisIterator
+from pypy.module.micronumpy.iter import AllButAxisIterator
INT_SIZE = rffi.sizeof(lltype.Signed)
@@ -146,21 +146,20 @@
if axis < 0 or axis >= len(shape):
raise OperationError(space.w_IndexError, space.wrap(
"Wrong axis %d" % axis))
- iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
- iter = AxisIterator(arr, iterable_shape, axis, False)
+ arr_iter = AllButAxisIterator(arr, axis)
index_impl = index_arr.implementation
- index_iter = AxisIterator(index_impl, iterable_shape, axis, False)
+ index_iter = AllButAxisIterator(index_impl, axis)
stride_size = arr.strides[axis]
index_stride_size = index_impl.strides[axis]
axis_size = arr.shape[axis]
- while not iter.done():
+ while not arr_iter.done():
for i in range(axis_size):
raw_storage_setitem(storage, i * index_stride_size +
index_iter.offset, i)
r = Repr(index_stride_size, stride_size, axis_size,
- arr.get_storage(), storage, index_iter.offset,
iter.offset)
+ arr.get_storage(), storage, index_iter.offset,
arr_iter.offset)
ArgSort(r).sort()
- iter.next()
+ arr_iter.next()
index_iter.next()
return index_arr
@@ -292,14 +291,13 @@
if axis < 0 or axis >= len(shape):
raise OperationError(space.w_IndexError, space.wrap(
"Wrong axis %d" % axis))
- iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
- iter = AxisIterator(arr, iterable_shape, axis, False)
+ arr_iter = AllButAxisIterator(arr, axis)
stride_size = arr.strides[axis]
axis_size = arr.shape[axis]
- while not iter.done():
- r = Repr(stride_size, axis_size, arr.get_storage(),
iter.offset)
+ while not arr_iter.done():
+ r = Repr(stride_size, axis_size, arr.get_storage(),
arr_iter.offset)
ArgSort(r).sort()
- iter.next()
+ arr_iter.next()
return sort
diff --git a/pypy/module/micronumpy/test/test_iter.py b/pypy/module/micronumpy/test/test_iter.py
--- a/pypy/module/micronumpy/test/test_iter.py
+++ b/pypy/module/micronumpy/test/test_iter.py
@@ -1,8 +1,8 @@
+from pypy.module.micronumpy import support
from pypy.module.micronumpy.iter import ArrayIterator
class MockArray(object):
- size = 1
start = 0
@@ -14,7 +14,8 @@
strides = [5, 1]
backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
assert backstrides == [10, 4]
- i = ArrayIterator(MockArray, shape, strides, backstrides)
+ i = ArrayIterator(MockArray, support.product(shape), shape,
+ strides, backstrides)
i.next()
i.next()
i.next()
@@ -32,7 +33,8 @@
strides = [1, 3]
backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
assert backstrides == [2, 12]
- i = ArrayIterator(MockArray, shape, strides, backstrides)
+ i = ArrayIterator(MockArray, support.product(shape), shape,
+ strides, backstrides)
i.next()
i.next()
i.next()
@@ -52,7 +54,8 @@
strides = [5, 1]
backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
assert backstrides == [10, 4]
- i = ArrayIterator(MockArray, shape, strides, backstrides)
+ i = ArrayIterator(MockArray, support.product(shape), shape,
+ strides, backstrides)
i.next_skip_x(2)
i.next_skip_x(2)
i.next_skip_x(2)
@@ -75,7 +78,8 @@
strides = [1, 3]
backstrides = [x * (y - 1) for x,y in zip(strides, shape)]
assert backstrides == [2, 12]
- i = ArrayIterator(MockArray, shape, strides, backstrides)
+ i = ArrayIterator(MockArray, support.product(shape), shape,
+ strides, backstrides)
i.next_skip_x(2)
i.next_skip_x(2)
i.next_skip_x(2)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -237,6 +237,10 @@
dtype = out.get_dtype()
else:
out = W_NDimArray.from_shape(space, shape, dtype,
w_instance=obj)
+ if obj.get_size() == 0:
+ if self.identity is not None:
+ out.fill(space, self.identity.convert_to(space, dtype))
+ return out
return loop.do_axis_reduce(space, shape, self.func, obj, dtype,
axis, out,
self.identity, cumulative, temp)
if cumulative:
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit