Author: mattip <[email protected]>
Branch: nditer-external_loop
Changeset: r74289:476916a3e563
Date: 2014-10-24 17:36 +0300
http://bitbucket.org/pypy/pypy/changeset/476916a3e563/
Log: coalescing almost works
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -41,6 +41,16 @@
from pypy.module.micronumpy.base import W_NDimArray
from pypy.module.micronumpy.flagsobj import _update_contiguous_flags
+class OpFlag(object):
+ def __init__(self):
+ self.rw = ''
+ self.broadcast = True
+ self.force_contig = False
+ self.force_align = False
+ self.native_byte_order = False
+ self.tmp_copy = ''
+ self.allocate = False
+
class PureShapeIter(object):
def __init__(self, shape, idx_w):
@@ -89,11 +99,12 @@
class ArrayIter(object):
_immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1',
'shape_m1[*]',
'strides[*]', 'backstrides[*]', 'factors[*]',
- 'track_index']
+ 'track_index', 'operand_type']
track_index = True
- def __init__(self, array, size, shape, strides, backstrides):
+ def __init__(self, array, size, shape, strides, backstrides, op_flags=OpFlag()):
+ from pypy.module.micronumpy import concrete
assert len(shape) == len(strides) == len(backstrides)
_update_contiguous_flags(array)
self.contiguous = (array.flags & NPY.ARRAY_C_CONTIGUOUS and
@@ -114,6 +125,10 @@
else:
factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
self.factors = factors
+ if op_flags.rw == 'r':
+ self.operand_type = concrete.ConcreteNonWritableArrayWithBase
+ else:
+ self.operand_type = concrete.ConcreteArrayWithBase
@jit.unroll_safe
def reset(self, state=None):
@@ -223,8 +238,8 @@
view into the original array
'''
- def __init__(self, array, size, shape, strides, backstrides):
- ArrayIter.__init__(self, array, size, shape, strides, backstrides)
+ def __init__(self, array, size, shape, strides, backstrides, op_flags):
+ ArrayIter.__init__(self, array, size, shape, strides, backstrides, op_flags)
self.slice_shape = array.get_shape()[len(shape):]
self.slice_strides = array.strides[len(shape):]
self.slice_backstrides = array.backstrides[len(shape):]
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,7 +5,7 @@
from pypy.module.micronumpy import ufuncs, support, concrete
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIter
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIter, OpFlag
from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
shape_agreement,
shape_agreement_multiple)
@@ -35,17 +35,6 @@
return ret
-class OpFlag(object):
- def __init__(self):
- self.rw = ''
- self.broadcast = True
- self.force_contig = False
- self.force_align = False
- self.native_byte_order = False
- self.tmp_copy = ''
- self.allocate = False
-
-
def parse_op_flag(space, lst):
op_flag = OpFlag()
for w_item in lst:
@@ -153,11 +142,11 @@
raise NotImplementedError('not implemented yet')
-def get_iter(space, order, arr, shape, dtype):
+def get_iter(space, order, arr, shape, dtype, op_flags):
imp = arr.implementation
backward = is_backward(imp, order)
if arr.is_scalar():
- return ArrayIter(imp, 1, [], [], [])
+ return ArrayIter(imp, 1, [], [], [], op_flags=op_flags)
if (imp.strides[0] < imp.strides[-1] and not backward) or \
(imp.strides[0] > imp.strides[-1] and backward):
# flip the strides. Is this always true for multidimension?
@@ -172,7 +161,7 @@
backstrides = imp.backstrides
r = calculate_broadcast_strides(strides, backstrides, imp.shape,
shape, backward)
- return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
+ return ArrayIter(imp, imp.get_size(), shape, r[0], r[1], op_flags=op_flags)
def calculate_ndim(op_in, oa_ndim):
if oa_ndim >=0:
@@ -208,17 +197,15 @@
if it.order == 'F':
last = out_shape[0]
out_shape = out_shape[1:]
- out_tshape[0] *= last
else:
last = out_shape[-1]
out_shape = out_shape[:-1]
- out_shape[-1] *= last
for i in range(len(it.iters)):
old_iter = it.iters[i][0]
shape = [s+1 for s in old_iter.shape_m1]
- new_iter = SliceIter(old_iter.array, old_iter.size,
+ new_iter = SliceIter(old_iter.array, old_iter.size / last,
shape[:-1], old_iter.strides[:-1],
- old_iter.backstrides[:-1])
+ old_iter.backstrides[:-1], it.op_flags[i])
it.iters[i] = (new_iter, new_iter.reset())
it.shape = out_shape
else:
@@ -362,7 +349,8 @@
# create an iterator for each operand
self.iters = []
for i in range(len(self.seq)):
- it = get_iter(space, self.order, self.seq[i], self.shape, self.dtypes[i])
+ it = get_iter(space, self.order, self.seq[i], self.shape,
+ self.dtypes[i], self.op_flags[i])
it.contiguous = False
self.iters.append((it, it.reset()))
@@ -400,11 +388,8 @@
def descr_iter(self, space):
return space.wrap(self)
- def getitem(self, it, st, op_flags):
- if op_flags.rw == 'r':
- impl = concrete.ConcreteNonWritableArrayWithBase
- else:
- impl = concrete.ConcreteArrayWithBase
+ def getitem(self, it, st):
+ impl = it.operand_type
res = impl([], it.array.dtype, it.array.order, [], [],
it.array.storage, self)
res.start = st.offset
@@ -417,7 +402,7 @@
except IndexError:
raise oefmt(space.w_IndexError,
"Iterator operand index %d is out of bounds", idx)
- return self.getitem(it, st, self.op_flags[idx])
+ return self.getitem(it, st)
def descr_setitem(self, space, w_idx, w_value):
raise oefmt(space.w_NotImplementedError, "not implemented yet")
@@ -426,6 +411,7 @@
space.wrap(len(self.iters))
def descr_next(self, space):
+ import pdb;pdb.set_trace()
for it, st in self.iters:
if not it.done(st):
break
@@ -439,7 +425,7 @@
else:
self.first_next = False
for i, (it, st) in enumerate(self.iters):
- res.append(self.getitem(it, st, self.op_flags[i]))
+ res.append(self.getitem(it, st))
self.iters[i] = (it, it.next(st))
if len(res) < 2:
return res[0]
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit