Author: Maciej Fijalkowski <fij...@gmail.com>
Branch: missing-ndarray-attributes
Changeset: r58584:47b57e79e2fa
Date: 2012-10-29 14:59 +0100
http://bitbucket.org/pypy/pypy/changeset/47b57e79e2fa/
Log: implement ndarray.choose

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -13,12 +13,6 @@
 from pypy.module.micronumpy.arrayimpl.sort import argsort_array
 from pypy.rlib.debug import make_sure_not_resized
 
-def int_w(space, w_obj):
-    try:
-        return space.int_w(space.index(w_obj))
-    except OperationError:
-        return space.int_w(space.int(w_obj))
-
 class BaseConcreteArray(base.BaseArrayImplementation):
     start = 0
     parent = None
@@ -85,7 +79,7 @@
         for i, w_index in enumerate(view_w):
             if space.isinstance_w(w_index, space.w_slice):
                 raise IndexError
-            idx = int_w(space, w_index)
+            idx = support.int_w(space, w_index)
             if idx < 0:
                 idx = self.get_shape()[i] + idx
             if idx < 0 or idx >= self.get_shape()[i]:
@@ -159,7 +153,7 @@
             return self._lookup_by_index(space, view_w)
         if shape_len > 1:
             raise IndexError
-        idx = int_w(space, w_idx)
+        idx = support.int_w(space, w_idx)
         return self._lookup_by_index(space, [space.wrap(idx)])
 
     @jit.unroll_safe
diff --git a/pypy/module/micronumpy/constants.py b/pypy/module/micronumpy/constants.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/constants.py
@@ -0,0 +1,4 @@
+# Out-of-range index handling modes accepted by ndarray.choose(mode=...)
+MODE_WRAP, MODE_RAISE, MODE_CLIP = range(3)
+
+MODES = {'wrap': MODE_WRAP, 'raise': MODE_RAISE, 'clip': MODE_CLIP}
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -3,6 +3,7 @@
 from pypy.module.micronumpy import loop, interp_ufuncs
 from pypy.module.micronumpy.iter import Chunk, Chunks
 from pypy.module.micronumpy.strides import shape_agreement
+from pypy.module.micronumpy.constants import MODES
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.interpreter.gateway import unwrap_spec
 
@@ -153,3 +154,33 @@
 
 def count_nonzero(space, w_obj):
     return space.wrap(loop.count_all_true(convert_to_array(space, w_obj)))
+
+def choose(space, arr, w_choices, out, mode):
+    # out[...] = choices[arr[...]][...], with arr, every choice and
+    # out (when given) broadcast against each other
+    if mode not in MODES:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("mode %s not known" % (mode,)))
+    choices = [convert_to_array(space, w_item) for w_item
+               in space.listview(w_choices)]
+    if not choices:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("choices list cannot be empty"))
+    # find the shape agreement
+    shape = arr.get_shape()
+    for choice in choices:
+        shape = shape_agreement(space, shape, choice)
+    if out is not None:
+        shape = shape_agreement(space, shape, out)
+        # the result is stored into the caller's array, so every chosen
+        # element must be converted to that array's dtype
+        dtype = out.get_dtype()
+    else:
+        # find the correct dtype
+        dtype = choices[0].get_dtype()
+        for choice in choices[1:]:
+            dtype = interp_ufuncs.find_binop_result_dtype(space,
+                                        dtype, choice.get_dtype())
+        out = W_NDimArray.from_shape(shape, dtype)
+    loop.choose(space, arr, choices, shape, dtype, out, MODES[mode])
+    return out
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -4,7 +4,8 @@
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array,\
      ArrayArgumentException
-from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes
+from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes,\
+     interp_arrayops
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
      get_shape_from_iterable, to_coords, shape_agreement
 from pypy.module.micronumpy.interp_flatiter import W_FlatIterator
@@ -402,9 +403,14 @@
         return res
 
     @unwrap_spec(mode=str)
-    def descr_choose(self, space, w_choices, w_out=None, mode='raise'):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "choose not implemented yet"))
+    def descr_choose(self, space, w_choices, mode='raise', w_out=None):
+        if w_out is not None and space.is_w(w_out, space.w_None):
+            # an explicit app-level out=None means "allocate the result"
+            w_out = None
+        if w_out is not None and not isinstance(w_out, W_NDimArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                "return arrays must be of ArrayType"))
+        return interp_arrayops.choose(space, self, w_choices, w_out, mode)
 
     def descr_clip(self, space, w_min, w_max, w_out=None):
         raise OperationError(space.w_NotImplementedError, space.wrap(
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,11 +4,14 @@
 over all the array elements.
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.rlib.rstring import StringBuilder
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.iter import PureShapeIterator
+from pypy.module.micronumpy import constants
+from pypy.module.micronumpy.support import int_w
 
 call2_driver = jit.JitDriver(name='numpy_call2',
                              greens = ['shapelen', 'func', 'calc_dtype',
@@ -486,3 +489,36 @@
         to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
         to_iter.next()
         from_iter.next()
+
+choose_driver = jit.JitDriver(greens = ['shapelen', 'mode', 'dtype'],
+                              reds = ['shape', 'iterators', 'arr_iter',
+                                      'out_iter'])
+
+def choose(space, arr, choices, shape, dtype, out, mode):
+    shapelen = len(shape)
+    iterators = [a.create_iter(shape) for a in choices]
+    arr_iter = arr.create_iter(shape)
+    out_iter = out.create_iter(shape)
+    while not arr_iter.done():
+        choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                      mode=mode, shape=shape,
+                                      iterators=iterators, arr_iter=arr_iter,
+                                      out_iter=out_iter)
+        index = int_w(space, arr_iter.getitem())
+        if index < 0 or index >= len(iterators):
+            if mode == constants.MODE_RAISE:
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "invalid entry in choice array"))
+            elif mode == constants.MODE_WRAP:
+                index = index % len(iterators)
+            else:
+                assert mode == constants.MODE_CLIP
+                if index < 0:
+                    index = 0
+                else:
+                    index = len(iterators) - 1
+        out_iter.setitem(iterators[index].getitem().convert_to(dtype))
+        for it in iterators:
+            it.next()
+        out_iter.next()
+        arr_iter.next()
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -1,5 +1,12 @@
 from pypy.rlib import jit
+from pypy.interpreter.error import OperationError
 
+def int_w(space, w_obj):
+    """Unwrap w_obj to an int via space.index(), falling back to space.int()."""
+    try:
+        return space.int_w(space.index(w_obj))
+    except OperationError:
+        return space.int_w(space.int(w_obj))
 
 @jit.unroll_safe
 def product(s):
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -108,8 +108,27 @@
         a, b, c = array([1, 2, 3]), [4, 5, 6], 13
         raises(ValueError, "array([3, 1, 0]).choose([a, b, c])")
         raises(ValueError, "array([3, 1, 0]).choose([a, b, c], 'raises')")
+        raises(ValueError, "array([3, 1, 0]).choose([])")
+        raises(ValueError, "array([-1, -2, -3]).choose([a, b, c])")
         r = array([4, 1, 0]).choose([a, b, c], mode='clip')
         assert (r == [13, 5, 3]).all()
         r = array([4, 1, 0]).choose([a, b, c], mode='wrap')
         assert (r == [4, 5, 3]).all()
-        
+
+
+    def test_choose_dtype(self):
+        from _numpypy import array
+        a, b, c = array([1.2, 2, 3]), [4, 5, 6], 13
+        r = array([2, 1, 0]).choose([a, b, c])
+        assert r.dtype == float
+
+    def test_choose_out(self):
+        from _numpypy import array, zeros
+        a, b, c = array([1, 2, 3]), [4, 5, 6], 13
+        r = array([2, 1, 0]).choose([a, b, c], out=None)
+        assert (r == [13, 5, 3]).all()
+        out = zeros(3, dtype=float)
+        r = array([2, 1, 0]).choose([a, b, c], out=out)
+        assert r is out
+        assert (out == [13, 5, 3]).all()
+        raises(TypeError, "array([2, 1, 0]).choose([a, b, c], out=[0, 0, 0])")
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit