Author: mattip <[email protected]>
Branch: nditer-external_loop
Changeset: r74289:476916a3e563
Date: 2014-10-24 17:36 +0300
http://bitbucket.org/pypy/pypy/changeset/476916a3e563/

Log:    coalescing almost works

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -41,6 +41,16 @@
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.flagsobj import _update_contiguous_flags
 
+class OpFlag(object):
+    def __init__(self):
+        self.rw = ''
+        self.broadcast = True
+        self.force_contig = False
+        self.force_align = False
+        self.native_byte_order = False
+        self.tmp_copy = ''
+        self.allocate = False
+
 
 class PureShapeIter(object):
     def __init__(self, shape, idx_w):
@@ -89,11 +99,12 @@
 class ArrayIter(object):
     _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]', 'factors[*]',
-                          'track_index']
+                          'track_index', 'operand_type']
 
     track_index = True
 
-    def __init__(self, array, size, shape, strides, backstrides):
+    def __init__(self, array, size, shape, strides, backstrides, op_flags=OpFlag()):
+        from pypy.module.micronumpy import concrete
         assert len(shape) == len(strides) == len(backstrides)
         _update_contiguous_flags(array)
         self.contiguous = (array.flags & NPY.ARRAY_C_CONTIGUOUS and
@@ -114,6 +125,10 @@
             else:
                 factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
         self.factors = factors
+        if op_flags.rw == 'r':
+            self.operand_type = concrete.ConcreteNonWritableArrayWithBase
+        else:
+            self.operand_type = concrete.ConcreteArrayWithBase
 
     @jit.unroll_safe
     def reset(self, state=None):
@@ -223,8 +238,8 @@
     view into the original array
     '''
 
-    def __init__(self, array, size, shape, strides, backstrides):
-        ArrayIter.__init__(self, array, size, shape, strides, backstrides)
+    def __init__(self, array, size, shape, strides, backstrides, op_flags):
+        ArrayIter.__init__(self, array, size, shape, strides, backstrides, op_flags)
         self.slice_shape = array.get_shape()[len(shape):]
         self.slice_strides = array.strides[len(shape):]
         self.slice_backstrides = array.backstrides[len(shape):]
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,7 +5,7 @@
 from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIter
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIter, OpFlag
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                             shape_agreement, shape_agreement_multiple)
 
@@ -35,17 +35,6 @@
     return ret
 
 
-class OpFlag(object):
-    def __init__(self):
-        self.rw = ''
-        self.broadcast = True
-        self.force_contig = False
-        self.force_align = False
-        self.native_byte_order = False
-        self.tmp_copy = ''
-        self.allocate = False
-
-
 def parse_op_flag(space, lst):
     op_flag = OpFlag()
     for w_item in lst:
@@ -153,11 +142,11 @@
         raise NotImplementedError('not implemented yet')
 
 
-def get_iter(space, order, arr, shape, dtype):
+def get_iter(space, order, arr, shape, dtype, op_flags):
     imp = arr.implementation
     backward = is_backward(imp, order)
     if arr.is_scalar():
-        return ArrayIter(imp, 1, [], [], [])
+        return ArrayIter(imp, 1, [], [], [], op_flags=op_flags)
     if (imp.strides[0] < imp.strides[-1] and not backward) or \
        (imp.strides[0] > imp.strides[-1] and backward):
         # flip the strides. Is this always true for multidimension?
@@ -172,7 +161,7 @@
         backstrides = imp.backstrides
     r = calculate_broadcast_strides(strides, backstrides, imp.shape,
                                     shape, backward)
-    return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
+    return ArrayIter(imp, imp.get_size(), shape, r[0], r[1], op_flags=op_flags)
 
 def calculate_ndim(op_in, oa_ndim):
     if oa_ndim >=0:
@@ -208,17 +197,15 @@
             if it.order == 'F':
                 last = out_shape[0]
                 out_shape = out_shape[1:]
-                out_tshape[0] *= last
             else:
                 last = out_shape[-1]
                 out_shape = out_shape[:-1]
-                out_shape[-1] *= last
             for i in range(len(it.iters)):
                 old_iter = it.iters[i][0]
                 shape = [s+1 for s in old_iter.shape_m1]
-                new_iter = SliceIter(old_iter.array, old_iter.size,
+                new_iter = SliceIter(old_iter.array, old_iter.size / last,
                                 shape[:-1], old_iter.strides[:-1],
-                                old_iter.backstrides[:-1])
+                                old_iter.backstrides[:-1], it.op_flags[i])
                 it.iters[i] = (new_iter, new_iter.reset())
             it.shape = out_shape
         else:
@@ -362,7 +349,8 @@
         # create an iterator for each operand
         self.iters = []
         for i in range(len(self.seq)):
-            it = get_iter(space, self.order, self.seq[i], self.shape, self.dtypes[i])
+            it = get_iter(space, self.order, self.seq[i], self.shape,
+                          self.dtypes[i], self.op_flags[i])
             it.contiguous = False
             self.iters.append((it, it.reset()))
 
@@ -400,11 +388,8 @@
     def descr_iter(self, space):
         return space.wrap(self)
 
-    def getitem(self, it, st, op_flags):
-        if op_flags.rw == 'r':
-            impl = concrete.ConcreteNonWritableArrayWithBase
-        else:
-            impl = concrete.ConcreteArrayWithBase
+    def getitem(self, it, st):
+        impl = it.operand_type
         res = impl([], it.array.dtype, it.array.order, [], [],
                    it.array.storage, self)
         res.start = st.offset
@@ -417,7 +402,7 @@
         except IndexError:
             raise oefmt(space.w_IndexError,
                         "Iterator operand index %d is out of bounds", idx)
-        return self.getitem(it, st, self.op_flags[idx])
+        return self.getitem(it, st)
 
     def descr_setitem(self, space, w_idx, w_value):
         raise oefmt(space.w_NotImplementedError, "not implemented yet")
@@ -426,6 +411,7 @@
         space.wrap(len(self.iters))
 
     def descr_next(self, space):
+        import pdb;pdb.set_trace()
         for it, st in self.iters:
             if not it.done(st):
                 break
@@ -439,7 +425,7 @@
             else:
                 self.first_next = False
         for i, (it, st) in enumerate(self.iters):
-            res.append(self.getitem(it, st, self.op_flags[i]))
+            res.append(self.getitem(it, st))
             self.iters[i] = (it, it.next(st))
         if len(res) < 2:
             return res[0]
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to