Author: Matti Picus <[email protected]>
Branch: numpypy-nditer
Changeset: r64057:cc555ec312e2
Date: 2013-05-13 22:55 +0300
http://bitbucket.org/pypy/pypy/changeset/cc555ec312e2/
Log: shuffle things around to handle op_flags, readwrite almost works
diff --git a/pypy/module/micronumpy/interp_nditer.py b/pypy/module/micronumpy/interp_nditer.py
--- a/pypy/module/micronumpy/interp_nditer.py
+++ b/pypy/module/micronumpy/interp_nditer.py
@@ -2,19 +2,118 @@
from pypy.interpreter.typedef import TypeDef, GetSetProperty,
make_weakref_descr
from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
from pypy.interpreter.error import OperationError
-from pypy.module.micronumpy.base import convert_to_array
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.strides import calculate_broadcast_strides
from pypy.module.micronumpy.iter import MultiDimViewIterator
from pypy.module.micronumpy import support
+def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
+    """Parse a per-operand nditer argument (e.g. op_flags) into a list.
+
+    w_op_flags may be:
+      * None -- every one of the n operands gets a default OpFlag();
+      * a tuple/list of flag names -- parsed once and shared by all n
+        operands;
+      * a tuple/list of n tuples/lists -- each parsed separately.
+    parse_one_arg(space, lst) always receives an interp-level list of
+    wrapped items.  Returns a list of n parsed values.
+
+    Raises ValueError (message built from `name`) when w_op_flags is not a
+    tuple/list, is empty, or holds per-operand sequences whose count is
+    not n.
+    """
+    ret = []
+    if space.is_w(w_op_flags, space.w_None):
+        for i in range(n):
+            ret.append(OpFlag())
+        return ret
+    if not space.isinstance_w(w_op_flags, space.w_tuple) and not \
+            space.isinstance_w(w_op_flags, space.w_list):
+        raise OperationError(space.w_ValueError, space.wrap(
+            '%s must be a tuple or array of per-op flag-tuples' % name))
+    w_lst = space.listview(w_op_flags)
+    if len(w_lst) == 0:
+        # Guard the w_lst[0] lookup below instead of indexing an empty list.
+        raise OperationError(space.w_ValueError, space.wrap(
+            '%s must be a tuple or array of per-op flag-tuples' % name))
+    if space.isinstance_w(w_lst[0], space.w_tuple) or \
+            space.isinstance_w(w_lst[0], space.w_list):
+        # One flag-sequence per operand.
+        if len(w_lst) != n:
+            raise OperationError(space.w_ValueError, space.wrap(
+                '%s must be a tuple or array of per-op flag-tuples' % name))
+        for w_item in w_lst:
+            # w_lst is already an interp-level list; it is each inner
+            # wrapped sequence that must be unwrapped, so parse_one_arg
+            # gets the same kind of argument as in the shared branch.
+            ret.append(parse_one_arg(space, space.listview(w_item)))
+    else:
+        # A single flag-sequence shared by every operand.
+        op_flag = parse_one_arg(space, w_lst)
+        for i in range(n):
+            ret.append(op_flag)
+    return ret
-def handle_sequence_args(space, cls, w_seq, w_op_flags, w_op_types, w_op_axes):
-    '''
-    Make sure that len(args) == 1 or len(w_seq)
-    and set attribs on cls appropriately
-    '''
-    raise OperationError(space.w_NotImplementedError, space.wrap(
-        'not implemented yet'))
+class OpFlag(object):
+    """Per-operand nditer flags, as filled in by parse_op_flag.
+
+    The values set in __init__ are what an operand gets when no flag
+    overrides them: readonly access, broadcasting allowed, no contig /
+    aligned / nbo request, no temporary copy and no allocation.
+    """
+    def __init__(self):
+        # Access mode: 'r' (readonly), 'rw' (readwrite) or 'w' (writeonly).
+        self.rw = 'r'
+        # Cleared by the 'no_broadcast' flag.
+        self.broadcast = True
+        # Set by 'contig', 'aligned' and 'nbo' respectively.  Nothing in
+        # this patch reads them yet.
+        self.force_contig = False
+        self.force_align = False
+        self.native_byte_order = False
+        # '' (no copy), 'r' (set by 'copy') or 'rw' (set by 'updateifcopy').
+        self.tmp_copy = ''
+        # Set by the 'allocate' flag.
+        self.allocate = False
+        # Callable (space, it) -> wrapped item.  parse_op_flag replaces it
+        # according to self.rw; the iterator calls it once per operand when
+        # producing each step's items.
+        self.get_it_item = get_readonly_item
+
+def get_readonly_item(space, it):
+    """Return the current element of iterator `it`, wrapped via the space."""
+    value = it.getitem()
+    return space.wrap(value)
+
+def get_readwrite_item(space, it):
+    """Return the current element of `it` stored in a new 1-element array.
+
+    NOTE(review): `res` is freshly created with W_NDimArray.from_shape and
+    the current value is copied into it.  It therefore looks like writes
+    through the returned array never reach the iterated operand.  TODO:
+    confirm, and return a view onto the operand instead.
+    """
+    res = W_NDimArray.from_shape([1], it.dtype, it.array.order)
+    it.dtype.setitem(res.implementation, 0, it.getitem())
+    return res
+
+def parse_op_flag(space, lst):
+    """Parse one operand's op_flags into an OpFlag.
+
+    `lst` is an interp-level list of wrapped flag names.  Each recognised
+    name sets the matching OpFlag attribute.  'no_subtype', 'arraymask' and
+    'writemask' raise NotImplementedError; any other unknown name raises
+    ValueError.  The item getter is then chosen from the access mode:
+    get_readonly_item for readonly (the default) and get_readwrite_item for
+    both readwrite and writeonly.
+    """
+    op_flag = OpFlag()
+    for w_item in lst:
+        item = space.str_w(w_item)
+        if item == 'readonly':
+            op_flag.rw = 'r'
+        elif item == 'readwrite':
+            op_flag.rw = 'rw'
+        elif item == 'writeonly':
+            op_flag.rw = 'w'
+        elif item == 'no_broadcast':
+            op_flag.broadcast = False
+        elif item == 'contig':
+            op_flag.force_contig = True
+        elif item == 'aligned':
+            op_flag.force_align = True
+        elif item == 'nbo':
+            op_flag.native_byte_order = True
+        elif item == 'copy':
+            op_flag.tmp_copy = 'r'
+        elif item == 'updateifcopy':
+            op_flag.tmp_copy = 'rw'
+        elif item == 'allocate':
+            op_flag.allocate = True
+        elif item == 'no_subtype':
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                '"no_subtype" op_flag not implemented yet'))
+        elif item == 'arraymask':
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                '"arraymask" op_flag not implemented yet'))
+        elif item == 'writemask':
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                '"writemask" op_flag not implemented yet'))
+        else:
+            raise OperationError(space.w_ValueError, space.wrap(
+                'op_flags must be a tuple or array of per-op flag-tuples'))
+    if op_flag.rw == 'r':
+        op_flag.get_it_item = get_readonly_item
+    else:
+        # rw is 'rw' or 'w'.  A writeonly operand must be writable too, so
+        # it cannot keep the read-only getter that OpFlag() starts with.
+        op_flag.get_it_item = get_readwrite_item
+    return op_flag
+
+def get_iter(space, order, imp, backward):
+    """Build a MultiDimViewIterator over the array implementation `imp`.
+
+    `order` 'K', or 'C' when imp.order == 'C', selects backward=False.
+    'F' when imp.order == 'C' selects backward=True.  Every other
+    combination raises NotImplementedError.
+
+    NOTE(review): the `backward` argument is ignored; it is always
+    overwritten from `order` by the first if/elif below.
+    """
+    if order == 'K' or (order == 'C' and imp.order == 'C'):
+        backward = False
+    elif order =='F' and imp.order == 'C':
+        backward = True
+    else:
+        raise OperationError(space.w_NotImplementedError, space.wrap(
+            'not implemented yet'))
+    # Reverse strides, backstrides and shape together when the first axis
+    # has the smaller stride (forward) or the larger stride (backward).
+    if (imp.strides[0] < imp.strides[-1] and not backward) or \
+       (imp.strides[0] > imp.strides[-1] and backward):
+        # flip the strides. Is this always true for multidimension?
+        strides = [s for s in imp.strides[::-1]]
+        backstrides = [s for s in imp.backstrides[::-1]]
+        shape = [s for s in imp.shape[::-1]]
+    else:
+        strides = imp.strides
+        backstrides = imp.backstrides
+        shape = imp.shape
+    # Single-dimension target shape holding the total element count.
+    shape1d = [support.product(imp.shape),]
+    # NOTE(review): r[0] / r[1] are presumably the broadcast strides and
+    # backstrides computed by calculate_broadcast_strides -- confirm.
+    r = calculate_broadcast_strides(strides, backstrides, shape,
+                                    shape1d, backward)
+    return MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1], shape)
 class W_NDIter(W_Root):
@@ -22,33 +121,18 @@
     def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting,
                  w_op_axes, w_itershape, w_buffersize, order):
+        """Set up an nditer over w_seq.
+
+        w_seq is either one operand or a tuple/list of operands.  Each one
+        is converted with convert_to_array into self.seq.  w_op_flags is
+        parsed into one OpFlag per operand (self.op_flags), and get_iter
+        builds one iterator per operand (self.iters).  w_flags,
+        w_op_dtypes, w_casting, w_op_axes, w_itershape and w_buffersize
+        are accepted but not used by this constructor.
+        """
         self.order = order
-        if space.isinstance_w(w_seq, space.w_tuple) or space.isinstance_w(w_seq, space.w_list):
-            handle_sequence_args(space, self, w_seq, w_op_flags, w_op_dtypes, w_op_axes)
+        if space.isinstance_w(w_seq, space.w_tuple) or \
+           space.isinstance_w(w_seq, space.w_list):
+            w_seq_as_list = space.listview(w_seq)
+            self.seq = [convert_to_array(space, w_elem)
+                        for w_elem in w_seq_as_list]
         else:
             self.seq =[convert_to_array(space, w_seq)]
-        if order == 'K' or (order == 'C' and self.seq[0].get_order() == 'C'):
-            backward = False
-        elif order =='F' and self.seq[0].get_order() == 'C':
-            backward = True
-        else:
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                'not implemented yet'))
-        imp = self.seq[0].implementation
-        if (imp.strides[0] < imp.strides[-1] and not backward) or \
-           (imp.strides[0] > imp.strides[-1] and backward):
-            # flip the strides. Is this always true for multidimension?
-            strides = [s for s in imp.strides[::-1]]
-            backstrides = [s for s in imp.backstrides[::-1]]
-            shape = [s for s in imp.shape[::-1]]
-        else:
-            strides = imp.strides
-            backstrides = imp.backstrides
-            shape = imp.shape
-        shape1d = [support.product(imp.shape),]
-        r = calculate_broadcast_strides(strides, backstrides, shape,
-                                        shape1d, backward)
-        self.iters = [MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1],
-                                           shape)]
+        self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
+                                     len(self.seq), parse_op_flag)
+        self.iters = []
+        for i in range(len(self.seq)):
+            # get_iter's last parameter is `backward`, which it recomputes
+            # from self.order.  Pass a plain False: operand i's OpFlag does
+            # not belong in that slot.
+            self.iters.append(get_iter(space, self.order,
+                                       self.seq[i].implementation, False))
     def descr_iter(self, space):
+        """Return the nditer object itself, wrapped."""
         return space.wrap(self)
@@ -72,9 +156,9 @@
         else:
             raise OperationError(space.w_StopIteration, space.w_None)
         res = []
-        for it in self.iters:
-            res.append(space.wrap(it.getitem()))
-            it.next()
+        # Fetch each operand's current item with the getter its op_flags
+        # selected (see parse_op_flag), then advance that operand's iterator.
+        for i in range(len(self.iters)):
+            res.append(self.op_flags[i].get_it_item(space, self.iters[i]))
+            self.iters[i].next()
+        # A single operand's item is returned bare; several form a tuple.
         if len(res) <2:
             return res[0]
         return space.newtuple(res)
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -39,6 +39,7 @@
         from numpypy import arange, nditer
         a = arange(6).reshape(2,3)
         for x in nditer(a, op_flags=['readwrite']):
+            # Writing through x must update `a` in place (asserted below).
             x[...] = 2 * x
         assert (a == [[0, 2, 4], [6, 8, 10]]).all()
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit