Author: Brian Kearns <[email protected]>
Branch:
Changeset: r70761:667ad75d7ce9
Date: 2014-04-18 14:51 -0400
http://bitbucket.org/pypy/pypy/changeset/667ad75d7ce9/
Log: simplify nditer
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -164,38 +164,6 @@
self.array.setitem(state.offset, elem)
-class SliceIterator(ArrayIter):
- def __init__(self, arr, strides, backstrides, shape, order="C",
- backward=False, dtype=None):
- if dtype is None:
- dtype = arr.implementation.dtype
- self.dtype = dtype
- self.arr = arr
- if backward:
- self.slicesize = shape[0]
- self.gap = [support.product(shape[1:]) * dtype.elsize]
- strides = strides[1:]
- backstrides = backstrides[1:]
- shape = shape[1:]
- strides.reverse()
- backstrides.reverse()
- shape.reverse()
- size = support.product(shape)
- else:
- shape = [support.product(shape)]
- strides, backstrides = calc_strides(shape, dtype, order)
- size = 1
- self.slicesize = support.product(shape)
- self.gap = strides
- ArrayIter.__init__(self, arr.implementation, size, shape, strides, backstrides)
-
- def getslice(self):
- from pypy.module.micronumpy.concrete import SliceArray
- return SliceArray(self.offset, self.gap, self.backstrides,
- [self.slicesize], self.arr.implementation,
- self.arr, self.dtype)
-
-
def AxisIter(array, shape, axis, cumulative):
strides = array.get_strides()
backstrides = array.get_backstrides()
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,56 +5,33 @@
from pypy.module.micronumpy import ufuncs, support, concrete
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
+from pypy.module.micronumpy.iterators import ArrayIter
from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
shape_agreement,
shape_agreement_multiple)
-class Iterator(object):
- def __init__(self, nditer, index, it, op_flags):
- self.nditer = nditer
- self.index = index
- self.it = it
- self.st = it.reset()
- self.op_flags = op_flags
-
- def done(self):
- return self.it.done(self.st)
-
- def next(self):
- self.st = self.it.next(self.st)
-
- def getitem(self, space, array):
- return self.op_flags.get_it_item[self.index](space, self.nditer, self.it, self.st)
-
- def setitem(self, space, array, val):
- xxx
-
-
def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
- ret = []
if space.is_w(w_op_flags, space.w_None):
- for i in range(n):
- ret.append(OpFlag())
- elif not space.isinstance_w(w_op_flags, space.w_tuple) and not \
+ w_op_flags = space.newtuple([space.wrap('readonly')])
+ if not space.isinstance_w(w_op_flags, space.w_tuple) and not \
space.isinstance_w(w_op_flags, space.w_list):
raise oefmt(space.w_ValueError,
'%s must be a tuple or array of per-op flag-tuples',
name)
+ ret = []
+ w_lst = space.listview(w_op_flags)
+ if space.isinstance_w(w_lst[0], space.w_tuple) or \
+ space.isinstance_w(w_lst[0], space.w_list):
+ if len(w_lst) != n:
+ raise oefmt(space.w_ValueError,
+ '%s must be a tuple or array of per-op flag-tuples',
+ name)
+ for item in w_lst:
+ ret.append(parse_one_arg(space, space.listview(item)))
else:
- w_lst = space.listview(w_op_flags)
- if space.isinstance_w(w_lst[0], space.w_tuple) or \
- space.isinstance_w(w_lst[0], space.w_list):
- if len(w_lst) != n:
- raise oefmt(space.w_ValueError,
- '%s must be a tuple or array of per-op flag-tuples',
- name)
- for item in w_lst:
- ret.append(parse_one_arg(space, space.listview(item)))
- else:
- op_flag = parse_one_arg(space, w_lst)
- for i in range(n):
- ret.append(op_flag)
+ op_flag = parse_one_arg(space, w_lst)
+ for i in range(n):
+ ret.append(op_flag)
return ret
@@ -67,29 +44,6 @@
self.native_byte_order = False
self.tmp_copy = ''
self.allocate = False
- self.get_it_item = (get_readonly_item, get_readonly_slice)
-
-
-def get_readonly_item(space, nditer, it, st):
- res = concrete.ConcreteNonWritableArrayWithBase(
- [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
- res.start = st.offset
- return W_NDimArray(res)
-
-
-def get_readwrite_item(space, nditer, it, st):
- res = concrete.ConcreteArrayWithBase(
- [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
- res.start = st.offset
- return W_NDimArray(res)
-
-
-def get_readonly_slice(space, array, it):
- return W_NDimArray(it.getslice().readonly())
-
-
-def get_readwrite_slice(space, array, it):
- return W_NDimArray(it.getslice())
def parse_op_flag(space, lst):
@@ -128,17 +82,10 @@
else:
raise OperationError(space.w_ValueError, space.wrap(
'op_flags must be a tuple or array of per-op flag-tuples'))
- if op_flag.rw == '':
- raise oefmt(space.w_ValueError,
- "None of the iterator flags READWRITE, READONLY, or "
- "WRITEONLY were specified for an operand")
- elif op_flag.rw == 'r':
- op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
- elif op_flag.rw == 'rw':
- op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
- elif op_flag.rw == 'w':
- # XXX Extra logic needed to make sure writeonly
- op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
+ if op_flag.rw == '':
+ raise oefmt(space.w_ValueError,
+ "None of the iterator flags READWRITE, READONLY, or "
+ "WRITEONLY were specified for an operand")
return op_flag
@@ -230,12 +177,6 @@
return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
-def get_external_loop_iter(space, order, arr, shape):
- imp = arr.implementation
- backward = is_backward(imp, order)
- return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward)
-
-
class IndexIterator(object):
def __init__(self, shape, backward=False):
self.shape = shape
@@ -326,8 +267,6 @@
out_dtype = None
for i in range(len(self.seq)):
if self.seq[i] is None:
- self.op_flags[i].get_it_item = (get_readwrite_item,
- get_readwrite_slice)
self.op_flags[i].allocate = True
continue
if self.op_flags[i].rw == 'w':
@@ -372,20 +311,9 @@
self.dtypes = [s.get_dtype() for s in self.seq]
# create an iterator for each operand
- if self.external_loop:
- for i in range(len(self.seq)):
- self.iters.append(Iterator(
- self, 1,
- get_external_loop_iter(
- space, self.order, self.seq[i], iter_shape),
- self.op_flags[i]))
- else:
- for i in range(len(self.seq)):
- self.iters.append(Iterator(
- self, 0,
- get_iter(
- space, self.order, self.seq[i], iter_shape, self.dtypes[i]),
- self.op_flags[i]))
+ for i in range(len(self.seq)):
+ it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i])
+ self.iters.append((it, it.reset()))
def set_op_axes(self, space, w_op_axes):
if space.len_w(w_op_axes) != len(self.seq):
@@ -417,14 +345,24 @@
def descr_iter(self, space):
return space.wrap(self)
+ def getitem(self, it, st, op_flags):
+ if op_flags.rw == 'r':
+ impl = concrete.ConcreteNonWritableArrayWithBase
+ else:
+ impl = concrete.ConcreteArrayWithBase
+ res = impl([], it.array.dtype, it.array.order, [], [],
+ it.array.storage, self)
+ res.start = st.offset
+ return W_NDimArray(res)
+
def descr_getitem(self, space, w_idx):
idx = space.int_w(w_idx)
try:
- ret = space.wrap(self.iters[idx].getitem(space, self.seq[idx]))
+ it, st = self.iters[idx]
except IndexError:
raise oefmt(space.w_IndexError,
"Iterator operand index %d is out of bounds", idx)
- return ret
+ return self.getitem(it, st, self.op_flags[idx])
def descr_setitem(self, space, w_idx, w_value):
raise oefmt(space.w_NotImplementedError, "not implemented yet")
@@ -433,8 +371,8 @@
space.wrap(len(self.iters))
def descr_next(self, space):
- for it in self.iters:
- if not it.done():
+ for it, st in self.iters:
+ if not it.done(st):
break
else:
self.done = True
@@ -445,9 +383,9 @@
self.index_iter.next()
else:
self.first_next = False
- for i in range(len(self.iters)):
- res.append(self.iters[i].getitem(space, self.seq[i]))
- self.iters[i].next()
+ for i, (it, st) in enumerate(self.iters):
+ res.append(self.getitem(it, st, self.op_flags[i]))
+ self.iters[i] = (it, it.next(st))
if len(res) < 2:
return res[0]
return space.newtuple(res)
@@ -455,10 +393,10 @@
def iternext(self):
if self.index_iter:
self.index_iter.next()
- for i in range(len(self.iters)):
- self.iters[i].next()
- for it in self.iters:
- if not it.done():
+ for i, (it, st) in enumerate(self.iters):
+ self.iters[i] = (it, it.next(st))
+ for it, st in self.iters:
+ if not it.done(st):
break
else:
self.done = True
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit