Author: mattip <matti.pi...@gmail.com> Branch: nditer-external_loop Changeset: r74285:4267ee5fd6ed Date: 2014-10-20 22:03 +0200 http://bitbucket.org/pypy/pypy/changeset/4267ee5fd6ed/
Log: start to implement external_loop diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py --- a/pypy/module/micronumpy/nditer.py +++ b/pypy/module/micronumpy/nditer.py @@ -174,6 +174,26 @@ shape, backward) return ArrayIter(imp, imp.get_size(), shape, r[0], r[1]) +def calculate_ndim(op_in, oa_ndim): + if oa_ndim >=0: + return oa_ndim + else: + ndim = 0 + for op in op_in: + if op is None: + continue + assert isinstance(op, W_NDimArray) + ndim = max(ndim, op.ndims()) + return ndim + +def coalexce_axes(iter, space): + # Copy logic from npyiter_coalesce_axes, used in ufunc iterators + # and in nditer's with 'external_loop' flag + import pdb;pdb.set_trace() + for idim in range(iter.ndim - 1): + can_coalesce = 1 + for op in self.ops: + pass class IndexIterator(object): def __init__(self, shape, backward=False): @@ -203,6 +223,7 @@ class W_NDIter(W_Root): + _immutable_fields_ = ['ndim', ] def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting, w_op_axes, w_itershape, w_buffersize, order): self.order = order @@ -234,8 +255,10 @@ self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags, len(self.seq), parse_op_flag) # handle w_op_axes + oa_ndim = -1 if not space.is_none(w_op_axes): - self.set_op_axes(space, w_op_axes) + oa_ndim = self.set_op_axes(space, w_op_axes) + self.ndim = calculate_ndim(self.seq, oa_ndim) # handle w_op_dtypes part 1: creating self.dtypes list from input if not space.is_none(w_op_dtypes): @@ -255,7 +278,7 @@ out_shape = shape_agreement_multiple(space, [self.seq[i] for i in outargs]) else: out_shape = None - self.shape = iter_shape = shape_agreement_multiple(space, self.seq, + self.shape = shape_agreement_multiple(space, self.seq, shape=out_shape) if len(outargs) > 0: # Make None operands writeonly and flagged for allocation @@ -274,11 +297,11 @@ for i in outargs: if self.seq[i] is None: # XXX can we postpone allocation to later? - self.seq[i] = W_NDimArray.from_shape(space, iter_shape, out_dtype) + self.seq[i] = W_NDimArray.from_shape(space, self.shape, out_dtype) else: if not self.op_flags[i].broadcast: # Raises if ooutput cannot be broadcast - shape_agreement(space, iter_shape, self.seq[i], False) + shape_agreement(space, self.shape, self.seq[i], False) if self.tracked_index != "": if self.order == "K": @@ -287,7 +310,7 @@ backward = False else: backward = self.order != self.tracked_index - self.index_iter = IndexIterator(iter_shape, backward=backward) + self.index_iter = IndexIterator(self.shape, backward=backward) # handle w_op_dtypes part 2: copy where needed if possible if len(self.dtypes) > 0: @@ -308,9 +331,12 @@ #copy them from seq self.dtypes = [s.get_dtype() for s in self.seq] + if self.external_loop: + coalexce_axes(self, space) + # create an iterator for each operand for i in range(len(self.seq)): - it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i]) + it = get_iter(space, self.order, self.seq[i], self.shape, self.dtypes[i]) it.contiguous = False self.iters.append((it, it.reset())) @@ -319,18 +345,18 @@ raise oefmt(space.w_ValueError, "op_axes must be a tuple/list matching the number of ops") op_axes = space.listview(w_op_axes) - l = -1 + oa_ndim = -1 for w_axis in op_axes: if not space.is_none(w_axis): axis_len = space.len_w(w_axis) - if l == -1: - l = axis_len - elif axis_len != l: + if oa_ndim == -1: + oa_ndim = axis_len + elif axis_len != oa_ndim: raise oefmt(space.w_ValueError, "Each entry of op_axes must have the same size") self.op_axes.append([space.int_w(x) if not space.is_none(x) else -1 for x in space.listview(w_axis)]) - if l == -1: + if oa_ndim == -1: raise oefmt(space.w_ValueError, "If op_axes is provided, at least one list of axes " "must be contained within it") @@ -340,6 +366,7 @@ # ValueError: Iterator input op_axes[0][3] (==3) is not a valid axis of op[0], which has 2 dimensions # - no repeat axis # ValueError: The 'op_axes' provided to the iterator constructor for operand 1 contained duplicate value 0 + return oa_ndim def descr_iter(self, space): return space.wrap(self) @@ -475,7 +502,7 @@ raise oefmt(space.w_NotImplementedError, "not implemented yet") def descr_get_ndim(self, space): - raise oefmt(space.w_NotImplementedError, "not implemented yet") + return space.wrap(self.ndim) def descr_get_nop(self, space): raise oefmt(space.w_NotImplementedError, "not implemented yet") diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py --- a/pypy/module/micronumpy/test/test_nditer.py +++ b/pypy/module/micronumpy/test/test_nditer.py @@ -63,9 +63,6 @@ from numpy import arange, nditer, array a = arange(24).reshape(2, 3, 4) import sys - if '__pypy__' in sys.builtin_module_names: - raises(NotImplementedError, nditer, a, flags=['external_loop']) - skip('nditer external_loop not implmented') r = [] n = 0 for x in nditer(a, flags=['external_loop']): @@ -222,9 +219,6 @@ def test_outarg(self): from numpy import nditer, zeros, arange import sys - if '__pypy__' in sys.builtin_module_names: - raises(NotImplementedError, nditer, [1, 2], flags=['external_loop']) - skip('nditer external_loop not implmented') def square1(a): it = nditer([a, None]) @@ -233,6 +227,9 @@ return it.operands[1] assert (square1([1, 2, 3]) == [1, 4, 9]).all() + if '__pypy__' in sys.builtin_module_names: + raises(NotImplementedError, nditer, [1, 2], flags=['buffered']) + skip('nditer buffered not implmented') def square2(a, out=None): it = nditer([a, out], flags=['external_loop', 'buffered'], op_flags=[['readonly'], @@ -252,9 +249,6 @@ from numpy import nditer, arange a = arange(3) import sys - if '__pypy__' in sys.builtin_module_names: - raises(NotImplementedError, nditer, a, flags=['external_loop']) - skip('nditer external_loop not implmented') b = arange(8).reshape(2,4) it = nditer([a, b, None], flags=['external_loop'], op_axes=[[0, -1, -1], [-1, 0, 1], None]) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit