Author: mattip <[email protected]>
Branch: nditer-external_loop
Changeset: r74288:8eac2bffd5c2
Date: 2014-10-24 15:59 +0300
http://bitbucket.org/pypy/pypy/changeset/8eac2bffd5c2/
Log: start to handle external_loop with SliceIter whose getitem() returns
an ndarray
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -216,3 +216,34 @@
size /= shape[axis]
shape[axis] = backstrides[axis] = 0
return ArrayIter(array, size, shape, array.strides, backstrides)
+
class SliceIter(ArrayIter):
    '''
    Iterator used with external loops: getitem returns a SliceArray view
    into the original array instead of a single element.

    The iterator steps through the first len(shape) dimensions of `array`;
    the remaining trailing dimensions of `array` form the shape of each
    returned view.
    '''

    def __init__(self, array, size, shape, strides, backstrides):
        ArrayIter.__init__(self, array, size, shape, strides, backstrides)
        # Geometry of every yielded view: the array dimensions beyond the
        # ones this iterator itself walks.
        n_iter_dims = len(shape)
        self.slice_shape = array.get_shape()[n_iter_dims:]
        self.slice_strides = array.strides[n_iter_dims:]
        self.slice_backstrides = array.backstrides[n_iter_dims:]

    def _make_slice(self, state):
        # Build the view starting at the iterator's current offset, with
        # self.array passed as both parent and orig_arr.  Shared by getitem
        # and setitem so the import is available to both (the original
        # setitem referenced SliceArray without importing it).  The import
        # is kept function-local, presumably to avoid a circular import
        # with concrete -- confirm.
        from pypy.module.micronumpy.concrete import SliceArray
        assert state.iterator is self
        return SliceArray(state.offset, self.slice_strides,
                          self.slice_backstrides, self.slice_shape,
                          self.array, self.array)

    def getitem(self, state):
        '''Return a SliceArray view of the array at `state`'s position.'''
        return self._make_slice(state)

    def getitem_bool(self, state):
        # XXX cannot be called
        assert False

    def setitem(self, state, elem):
        '''Writing through a SliceIter is not implemented yet.'''
        assert state.iterator is self
        # TODO: implement, e.g. by copying `elem` into
        # self._make_slice(state).  Raise explicitly rather than rely on
        # `assert False`, which is skipped when assertions are disabled.
        raise NotImplementedError("SliceIter.setitem")
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,7 +5,7 @@
from pypy.module.micronumpy import ufuncs, support, concrete
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIter
from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
shape_agreement,
shape_agreement_multiple)
@@ -186,15 +186,43 @@
ndim = max(ndim, op.ndims())
return ndim
def coalesce_axes(it, space):
    '''
    Try to merge iteration axes of nditer `it` for the 'external_loop' flag.

    Copies the logic of numpy's npyiter_coalesce_axes (used in ufunc
    iterators and in nditer's 'external_loop' flag).  Each successful pass
    shortens it.shape by one dimension, folding the dropped extent into its
    neighbour.  It also replaces every operand iterator with a SliceIter
    that walks all but the operand's last axis.  The function returns at
    the first pass in which some operand cannot be coalesced.

    `space` is currently unused.
    '''
    out_shape = it.shape[:]
    for idim in range(it.ndim - 1):
        can_coalesce = True
        for op_it, _ in it.iters:
            if op_it is None:
                continue
            assert isinstance(op_it, ArrayIter)
            op_ndim = len(op_it.shape_m1)
            if op_ndim < 2 or op_ndim != len(it.shape):
                can_coalesce = False
                break
            # NOTE(review): if backstrides are always strides * (shape - 1),
            # as they appear to be, this test never fails.  numpy's
            # coalescing instead compares stride*shape of the inner axis with
            # the stride of the next-outer axis.  Confirm the intended check.
            if (op_it.strides[-1] * op_it.shape_m1[-1] !=
                    op_it.backstrides[-1]):
                can_coalesce = False
                break
        if not can_coalesce:
            return

        if it.order == 'F':
            last = out_shape[0]
            out_shape = out_shape[1:]
            out_shape[0] *= last  # was `out_tshape`, an undefined name
            # NOTE(review): the operand iterators rebuilt below still drop
            # the *last* axis, which does not match this F-order update.
        else:
            last = out_shape[-1]
            out_shape = out_shape[:-1]
            out_shape[-1] *= last

        for i in range(len(it.iters)):
            old_iter = it.iters[i][0]
            if old_iter is None:
                # Skip the same operands the check above skips.
                continue
            shape = [s + 1 for s in old_iter.shape_m1]
            # The SliceIter only steps through shape[:-1], so its size must
            # count those positions, not every element of the old shape.
            # This matches how size shrinks when an ArrayIter drops an axis.
            new_size = 0
            if old_iter.size != 0:
                new_size = 1
                for extent in shape[:-1]:
                    new_size *= extent
            new_iter = SliceIter(old_iter.array, new_size,
                                 shape[:-1], old_iter.strides[:-1],
                                 old_iter.backstrides[:-1])
            it.iters[i] = (new_iter, new_iter.reset())
        it.shape = out_shape
class IndexIterator(object):
def __init__(self, shape, backward=False):
@@ -272,7 +300,6 @@
self.dtypes = []
# handle None or writable operands, calculate my shape
- self.iters = []
outargs = [i for i in range(len(self.seq))
if self.seq[i] is None or self.op_flags[i].rw == 'w']
if len(outargs) > 0:
@@ -332,15 +359,16 @@
#copy them from seq
self.dtypes = [s.get_dtype() for s in self.seq]
- if self.external_loop:
- coalexce_axes(self, space)
-
# create an iterator for each operand
+ self.iters = []
for i in range(len(self.seq)):
it = get_iter(space, self.order, self.seq[i], self.shape,
self.dtypes[i])
it.contiguous = False
self.iters.append((it, it.reset()))
+ if self.external_loop:
+ coalesce_axes(self, space)
+
def set_op_axes(self, space, w_op_axes):
if space.len_w(w_op_axes) != len(self.seq):
raise oefmt(space.w_ValueError,
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit