Author: Brian Kearns <[email protected]>
Branch:
Changeset: r70741:7ee3a18d1aca
Date: 2014-04-17 21:35 -0400
http://bitbucket.org/pypy/pypy/changeset/7ee3a18d1aca/
Log: cleanup nditer
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -172,14 +172,13 @@
size = 1
self.slicesize = support.product(shape)
self.gap = strides
-
ArrayIter.__init__(self, arr.implementation, size, shape, strides,
backstrides)
def getslice(self):
from pypy.module.micronumpy.concrete import SliceArray
- retVal = SliceArray(self.offset, self.gap, self.backstrides,
- [self.slicesize], self.arr.implementation, self.arr, self.dtype)
- return retVal
+ return SliceArray(self.offset, self.gap, self.backstrides,
+ [self.slicesize], self.arr.implementation,
+ self.arr, self.dtype)
def AxisIter(array, shape, axis, cumulative):
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -1,14 +1,14 @@
from pypy.interpreter.baseobjspace import W_Root
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
-from pypy.interpreter.error import OperationError
+from pypy.interpreter.error import OperationError, oefmt
+from pypy.module.micronumpy import ufuncs, support
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
-from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
- shape_agreement, shape_agreement_multiple)
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
from pypy.module.micronumpy.concrete import SliceArray
from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy import ufuncs, support
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
+from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
+ shape_agreement, shape_agreement_multiple)
class AbstractIterator(object):
@@ -21,8 +21,10 @@
def getitem(self, space, array):
raise NotImplementedError("Abstract Class")
+
class IteratorMixin(object):
_mixin_ = True
+
def __init__(self, it, op_flags):
self.it = it
self.op_flags = op_flags
@@ -39,28 +41,33 @@
def setitem(self, space, array, val):
xxx
+
class BoxIterator(IteratorMixin, AbstractIterator):
index = 0
+
class ExternalLoopIterator(IteratorMixin, AbstractIterator):
index = 1
+
def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
ret = []
if space.is_w(w_op_flags, space.w_None):
for i in range(n):
ret.append(OpFlag())
elif not space.isinstance_w(w_op_flags, space.w_tuple) and not \
- space.isinstance_w(w_op_flags, space.w_list):
- raise OperationError(space.w_ValueError, space.wrap(
- '%s must be a tuple or array of per-op flag-tuples' % name))
+ space.isinstance_w(w_op_flags, space.w_list):
+ raise oefmt(space.w_ValueError,
+ '%s must be a tuple or array of per-op flag-tuples',
+ name)
else:
w_lst = space.listview(w_op_flags)
if space.isinstance_w(w_lst[0], space.w_tuple) or \
space.isinstance_w(w_lst[0], space.w_list):
if len(w_lst) != n:
- raise OperationError(space.w_ValueError, space.wrap(
- '%s must be a tuple or array of per-op flag-tuples' % name))
+ raise oefmt(space.w_ValueError,
+ '%s must be a tuple or array of per-op flag-tuples',
+ name)
for item in w_lst:
ret.append(parse_one_arg(space, space.listview(item)))
else:
@@ -69,6 +76,7 @@
ret.append(op_flag)
return ret
+
class OpFlag(object):
def __init__(self):
self.rw = 'r'
@@ -80,21 +88,26 @@
self.allocate = False
self.get_it_item = (get_readonly_item, get_readonly_slice)
+
def get_readonly_item(space, array, it):
return space.wrap(it.getitem())
+
def get_readwrite_item(space, array, it):
#create a single-value view (since scalars are not views)
- res = SliceArray(it.array.start + it.offset, [0], [0], [1,], it.array, array)
+ res = SliceArray(it.array.start + it.offset, [0], [0], [1], it.array, array)
#it.dtype.setitem(res, 0, it.getitem())
return W_NDimArray(res)
+
def get_readonly_slice(space, array, it):
return W_NDimArray(it.getslice().readonly())
+
def get_readwrite_slice(space, array, it):
return W_NDimArray(it.getslice())
+
def parse_op_flag(space, lst):
op_flag = OpFlag()
for w_item in lst:
@@ -121,16 +134,16 @@
op_flag.allocate = True
elif item == 'no_subtype':
raise OperationError(space.w_NotImplementedError, space.wrap(
- '"no_subtype" op_flag not implemented yet'))
+ '"no_subtype" op_flag not implemented yet'))
elif item == 'arraymask':
raise OperationError(space.w_NotImplementedError, space.wrap(
- '"arraymask" op_flag not implemented yet'))
+ '"arraymask" op_flag not implemented yet'))
elif item == 'writemask':
raise OperationError(space.w_NotImplementedError, space.wrap(
- '"writemask" op_flag not implemented yet'))
+ '"writemask" op_flag not implemented yet'))
else:
raise OperationError(space.w_ValueError, space.wrap(
- 'op_flags must be a tuple or array of per-op flag-tuples'))
+ 'op_flags must be a tuple or array of per-op flag-tuples'))
if op_flag.rw == 'r':
op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
elif op_flag.rw == 'rw':
@@ -140,20 +153,22 @@
op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
return op_flag
+
def parse_func_flags(space, nditer, w_flags):
if space.is_w(w_flags, space.w_None):
return
elif not space.isinstance_w(w_flags, space.w_tuple) and not \
- space.isinstance_w(w_flags, space.w_list):
+ space.isinstance_w(w_flags, space.w_list):
raise OperationError(space.w_ValueError, space.wrap(
- 'Iter global flags must be a list or tuple of strings'))
+ 'Iter global flags must be a list or tuple of strings'))
lst = space.listview(w_flags)
for w_item in lst:
if not space.isinstance_w(w_item, space.w_str) and not \
- space.isinstance_w(w_item, space.w_unicode):
+ space.isinstance_w(w_item, space.w_unicode):
typename = space.type(w_item).getname(space)
- raise OperationError(space.w_TypeError, space.wrap(
- 'expected string or Unicode object, %s found' % typename))
+ raise oefmt(space.w_TypeError,
+ 'expected string or Unicode object, %s found',
+ typename)
item = space.str_w(w_item)
if item == 'external_loop':
raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -187,21 +202,24 @@
elif item == 'zerosize_ok':
nditer.zerosize_ok = True
else:
- raise OperationError(space.w_ValueError, space.wrap(
- 'Unexpected iterator global flag "%s"' % item))
+ raise oefmt(space.w_ValueError,
+ 'Unexpected iterator global flag "%s"',
+ item)
if nditer.tracked_index and nditer.external_loop:
- raise OperationError(space.w_ValueError, space.wrap(
- 'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
- 'multi-index is being tracked'))
+ raise OperationError(space.w_ValueError, space.wrap(
+ 'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
+ 'multi-index is being tracked'))
+
def is_backward(imp, order):
if order == 'K' or (order == 'C' and imp.order == 'C'):
return False
- elif order =='F' and imp.order == 'C':
+ elif order == 'F' and imp.order == 'C':
return True
else:
raise NotImplementedError('not implemented yet')
+
def get_iter(space, order, arr, shape, dtype):
imp = arr.implementation
backward = is_backward(imp, order)
@@ -223,11 +241,13 @@
shape, backward)
return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
+
def get_external_loop_iter(space, order, arr, shape):
imp = arr.implementation
backward = is_backward(imp, order)
return SliceIterator(arr, imp.strides, imp.backstrides, shape,
order=order, backward=backward)
+
def convert_to_array_or_none(space, w_elem):
'''
None will be passed through, all others will be converted
@@ -263,10 +283,10 @@
ret += self.index[i] * self.shape[i - 1]
return ret
+
class W_NDIter(W_Root):
-
def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes,
w_casting,
- w_op_axes, w_itershape, w_buffersize, order):
+ w_op_axes, w_itershape, w_buffersize, order):
self.order = order
self.external_loop = False
self.buffered = False
@@ -288,7 +308,7 @@
w_seq_as_list = space.listview(w_seq)
self.seq = [convert_to_array_or_none(space, w_elem) for w_elem in
w_seq_as_list]
else:
- self.seq =[convert_to_array(space, w_seq)]
+ self.seq = [convert_to_array(space, w_seq)]
parse_func_flags(space, self, w_flags)
self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
@@ -308,9 +328,9 @@
self.dtypes = []
# handle None or writable operands, calculate my shape
- self.iters=[]
- outargs = [i for i in range(len(self.seq)) \
- if self.seq[i] is None or self.op_flags[i].rw == 'w']
+ self.iters = []
+ outargs = [i for i in range(len(self.seq))
+ if self.seq[i] is None or self.op_flags[i].rw == 'w']
if len(outargs) > 0:
out_shape = shape_agreement_multiple(space, [self.seq[i] for i in
outargs])
else:
@@ -326,13 +346,13 @@
for i in range(len(self.seq)):
if self.seq[i] is None:
self.op_flags[i].get_it_item = (get_readwrite_item,
- get_readwrite_slice)
+ get_readwrite_slice)
self.op_flags[i].allocate = True
continue
if self.op_flags[i].rw == 'w':
continue
- out_dtype = ufuncs.find_binop_result_dtype(space,
- self.seq[i].get_dtype(), out_dtype)
+ out_dtype = ufuncs.find_binop_result_dtype(
+ space, self.seq[i].get_dtype(), out_dtype)
for i in outargs:
if self.seq[i] is None:
# XXX can we postpone allocation to later?
@@ -372,13 +392,17 @@
# create an iterator for each operand
if self.external_loop:
for i in range(len(self.seq)):
- self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
- self.seq[i], iter_shape), self.op_flags[i]))
+ self.iters.append(ExternalLoopIterator(
+ get_external_loop_iter(
+ space, self.order, self.seq[i], iter_shape),
+ self.op_flags[i]))
else:
for i in range(len(self.seq)):
- self.iters.append(BoxIterator(get_iter(space, self.order,
- self.seq[i], iter_shape, self.dtypes[i]),
- self.op_flags[i]))
+ self.iters.append(BoxIterator(
+ get_iter(
+ space, self.order, self.seq[i], iter_shape, self.dtypes[i]),
+ self.op_flags[i]))
+
def set_op_axes(self, space, w_op_axes):
if space.len_w(w_op_axes) != len(self.seq):
raise OperationError(space.w_ValueError, space.wrap("op_axes must be a tuple/list matching the number of ops"))
@@ -435,7 +459,7 @@
for i in range(len(self.iters)):
res.append(self.iters[i].getitem(space, self.seq[i]))
self.iters[i].next()
- if len(res) <2:
+ if len(res) < 2:
return res[0]
return space.newtuple(res)
@@ -551,14 +575,14 @@
'not implemented yet'))
-@unwrap_spec(w_flags = WrappedDefault(None), w_op_flags=WrappedDefault(None),
- w_op_dtypes = WrappedDefault(None), order=str,
+@unwrap_spec(w_flags=WrappedDefault(None), w_op_flags=WrappedDefault(None),
+ w_op_dtypes=WrappedDefault(None), order=str,
w_casting=WrappedDefault(None), w_op_axes=WrappedDefault(None),
w_itershape=WrappedDefault(None),
w_buffersize=WrappedDefault(None))
def nditer(space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting,
w_op_axes,
- w_itershape, w_buffersize, order='K'):
+ w_itershape, w_buffersize, order='K'):
return W_NDIter(space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting,
w_op_axes,
- w_itershape, w_buffersize, order)
+ w_itershape, w_buffersize, order)
W_NDIter.typedef = TypeDef(
'nditer',
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit