Author: Matti Picus <matti.pi...@gmail.com> Branch: numpypy-nditer Changeset: r70671:0dca5996f880 Date: 2014-04-16 23:10 +0300 http://bitbucket.org/pypy/pypy/changeset/0dca5996f880/
Log: implement op_dtypes diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py --- a/pypy/module/micronumpy/nditer.py +++ b/pypy/module/micronumpy/nditer.py @@ -7,6 +7,7 @@ shape_agreement, shape_agreement_multiple) from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator from pypy.module.micronumpy.concrete import SliceArray +from pypy.module.micronumpy.descriptor import decode_w_dtype from pypy.module.micronumpy import ufuncs @@ -201,8 +202,8 @@ else: raise NotImplementedError('not implemented yet') -def get_iter(space, order, arr, shape): - imp = arr.implementation +def get_iter(space, order, arr, shape, dtype): + imp = arr.implementation.astype(space, dtype) backward = is_backward(imp, order) if (imp.strides[0] < imp.strides[-1] and not backward) or \ (imp.strides[0] > imp.strides[-1] and backward): @@ -291,8 +292,13 @@ if not space.is_none(w_op_axes): self.set_op_axes(space, w_op_axes) if not space.is_none(w_op_dtypes): - raise OperationError(space.w_NotImplementedError, space.wrap( - 'nditer op_dtypes kwarg not implemented yet')) + w_seq_as_list = space.listview(w_op_dtypes) + self.dtypes = [decode_w_dtype(space, w_elem) for w_elem in w_seq_as_list] + if len(self.dtypes) != len(self.seq): + raise OperationError(space.w_ValueError, space.wrap( + "op_dtypes must be a tuple/list matching the number of ops")) + else: + self.dtypes = [] self.iters=[] outargs = [i for i in range(len(self.seq)) \ if self.seq[i] is None or self.op_flags[i].rw == 'w'] @@ -304,7 +310,7 @@ shape=out_shape) if len(outargs) > 0: # Make None operands writeonly and flagged for allocation - out_dtype = None + out_dtype = self.dtypes[0] if len(self.dtypes) > 0 else None for i in range(len(self.seq)): if self.seq[i] is None: self.op_flags[i].get_it_item = (get_readwrite_item, @@ -331,6 +337,19 @@ else: backward = self.order != self.tracked_index self.index_iter = IndexIterator(iter_shape, backward=backward) + if len(self.dtypes) > 0: + # Make sure dtypes make sense + for i in range(len(self.seq)): + selfd = self.dtypes[i] + seq_d = self.seq[i].get_dtype() + if not selfd: + self.dtypes[i] = seq_d + elif selfd != seq_d and not 'r' in self.op_flags[i].tmp_copy: + raise OperationError(space.w_TypeError, space.wrap( + "Iterator operand required copying or buffering")) + else: + #copy them from seq + self.dtypes = [s.get_dtype() for s in self.seq] if self.external_loop: for i in range(len(self.seq)): self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order, @@ -338,7 +357,8 @@ else: for i in range(len(self.seq)): self.iters.append(BoxIterator(get_iter(space, self.order, - self.seq[i], iter_shape), self.op_flags[i])) + self.seq[i], iter_shape, self.dtypes[i]), + self.op_flags[i])) def set_op_axes(self, space, w_op_axes): if space.len_w(w_op_axes) != len(self.seq): diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py --- a/pypy/module/micronumpy/test/test_nditer.py +++ b/pypy/module/micronumpy/test/test_nditer.py @@ -140,11 +140,7 @@ def test_op_dtype(self): from numpy import arange, nditer, sqrt, array - import sys a = arange(6).reshape(2,3) - 3 - if '__pypy__' in sys.builtin_module_names: - raises(NotImplementedError, nditer, a, op_dtypes=['complex']) - skip('nditer op_dtypes kwarg not implemented yet') exc = raises(TypeError, nditer, a, op_dtypes=['complex']) assert str(exc.value).startswith("Iterator operand required copying or buffering") r = [] @@ -154,7 +150,7 @@ assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j, 1+0j, 1.41421356237+0j]).sum()) < 1e-5 r = [] - for x in nditer(a, flags=['buffered'], + for x in nditer(a, op_flags=['copy'], op_dtypes=['complex128']): r.append(sqrt(x)) assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j, _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit