Author: mattip <[email protected]>
Branch: nditer-external_loop
Changeset: r74288:8eac2bffd5c2
Date: 2014-10-24 15:59 +0300
http://bitbucket.org/pypy/pypy/changeset/8eac2bffd5c2/

Log:    start to handle external_loop with a SliceIter, whose getitem() returns
        an ndarray slice (a SliceArray view)

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -216,3 +216,46 @@
         size /= shape[axis]
     shape[axis] = backstrides[axis] = 0
     return ArrayIter(array, size, shape, array.strides, backstrides)
+
+class SliceIter(ArrayIter):
+    '''
+    Iterator used with nditer's external_loop mode.
+
+    It walks only the leading len(shape) axes described by shape, strides
+    and backstrides.  The remaining trailing axes of array are not walked;
+    instead getitem() returns them as a SliceArray view starting at the
+    current offset.  setitem() is not implemented yet.
+    '''
+
+    def __init__(self, array, size, shape, strides, backstrides):
+        ArrayIter.__init__(self, array, size, shape, strides, backstrides)
+        # geometry of the per-step view: every axis of array past the
+        # len(shape) leading axes this iterator walks
+        self.slice_shape = array.get_shape()[len(shape):]
+        self.slice_strides = array.strides[len(shape):]
+        self.slice_backstrides = array.backstrides[len(shape):]
+
+    def _slice_at(self, state):
+        '''Return a SliceArray over the trailing axes at state.offset.'''
+        # imported locally, presumably to avoid an import cycle
+        from pypy.module.micronumpy.concrete import SliceArray
+        assert state.iterator is self
+        # NOTE(review): self.array is passed for both of the last two
+        # SliceArray arguments -- confirm that is intended
+        return SliceArray(state.offset, self.slice_strides,
+                          self.slice_backstrides, self.slice_shape,
+                          self.array, self.array)
+
+    def getitem(self, state):
+        '''Return the SliceArray view for the current position.'''
+        return self._slice_at(state)
+
+    def getitem_bool(self, state):
+        # XXX cannot be called
+        assert False
+
+    def setitem(self, state, elem):
+        assert state.iterator is self
+        # TODO: implement, presumably by writing elem into
+        # self._slice_at(state)
+        raise NotImplementedError("SliceIter.setitem")
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,7 +5,7 @@
 from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIter
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                             shape_agreement, 
shape_agreement_multiple)
 
@@ -186,15 +186,63 @@
             ndim = max(ndim, op.ndims())
     return ndim
 
-def coalexce_axes(it, space):
+def coalesce_axes(it, space):
+    '''
+    Merge the innermost axes of the nditer ``it`` for external_loop.
+
+    Each successful pass replaces every non-None operand iterator in
+    it.iters by a SliceIter that walks one axis fewer (that axis becomes
+    part of the slice returned per step) and shrinks it.shape by one
+    entry.  Returns at the first pass that cannot coalesce.  ``space`` is
+    currently unused.
+    '''
     # Copy logic from npyiter_coalesce_axes, used in ufunc iterators
     # and in nditer's with 'external_loop' flag
+    out_shape = it.shape[:]
     for idim in range(it.ndim - 1):
-        can_coalesce = 1
-        for op in it.seq:
-            stride = op.implementation.get_strides()
-            shape = op.get_shape()
-            pass
+        can_coalesce = True
+        for op_it, _ in it.iters:
+            if op_it is None:
+                continue
+            assert isinstance(op_it, ArrayIter)
+            if len(op_it.shape_m1) < 2:
+                # fewer than two axes left: nothing to merge
+                can_coalesce = False
+                break
+            if len(op_it.shape_m1) != len(it.shape):
+                can_coalesce = False
+                break
+            if op_it.strides[-1] * op_it.shape_m1[-1] != op_it.backstrides[-1]:
+                can_coalesce = False
+        if not can_coalesce:
+            return
+        if it.order == 'F':
+            # NOTE(review): the operand iterators below always drop their
+            # last axis, while this branch drops the first entry of
+            # out_shape -- confirm the two agree for F order
+            last = out_shape[0]
+            out_shape = out_shape[1:]
+            out_shape[0] *= last
+        else:
+            last = out_shape[-1]
+            out_shape = out_shape[:-1]
+            out_shape[-1] *= last
+        for i in range(len(it.iters)):
+            old_iter = it.iters[i][0]
+            if old_iter is None:
+                # operands without an iterator are skipped, as above
+                continue
+            assert isinstance(old_iter, ArrayIter)
+            shape = [s + 1 for s in old_iter.shape_m1]
+            # walking one axis fewer means shape[-1] times fewer steps
+            size = old_iter.size
+            if shape[-1] != 0:
+                size //= shape[-1]
+            new_iter = SliceIter(old_iter.array, size,
+                                 shape[:-1], old_iter.strides[:-1],
+                                 old_iter.backstrides[:-1])
+            it.iters[i] = (new_iter, new_iter.reset())
+        it.shape = out_shape

 class IndexIterator(object):
     def __init__(self, shape, backward=False):
@@ -272,7 +300,6 @@
             self.dtypes = []
 
         # handle None or writable operands, calculate my shape
-        self.iters = []
         outargs = [i for i in range(len(self.seq))
                    if self.seq[i] is None or self.op_flags[i].rw == 'w']
         if len(outargs) > 0:
@@ -332,15 +359,16 @@
             #copy them from seq
             self.dtypes = [s.get_dtype() for s in self.seq]
 
-        if self.external_loop:
-            coalexce_axes(self, space)
-
         # create an iterator for each operand
+        self.iters = []
         for i in range(len(self.seq)):
             it = get_iter(space, self.order, self.seq[i], self.shape, 
self.dtypes[i])
             it.contiguous = False
             self.iters.append((it, it.reset()))
 
+        if self.external_loop:
+            coalesce_axes(self, space)
+
     def set_op_axes(self, space, w_op_axes):
         if space.len_w(w_op_axes) != len(self.seq):
             raise oefmt(space.w_ValueError,
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to