Author: Stefan H. Muller <shmuell...@gmail.com>
Branch: pypy-pyarray
Changeset: r66348:f0b2849dfe98
Date: 2013-08-11 18:59 +0200
http://bitbucket.org/pypy/pypy/changeset/f0b2849dfe98/
Log: Implement ndarray.nonzero() diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -5,7 +5,7 @@ from pypy.module.micronumpy.base import W_NDimArray, convert_to_array,\ ArrayArgumentException, issequence_w, wrap_impl from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes,\ - interp_arrayops + interp_arrayops, iter from pypy.module.micronumpy.strides import find_shape_and_elems,\ get_shape_from_iterable, to_coords, shape_agreement, \ shape_agreement_multiple @@ -351,6 +351,31 @@ "order not implemented")) return self.descr_reshape(space, [space.wrap(-1)]) + def descr_nonzero(self, space): + impl = self.implementation + arr_iter = iter.MultiDimViewIterator(impl, impl.dtype, 0, + impl.strides, impl.backstrides, impl.shape) + + index_type = interp_dtype.get_dtype_cache(space).w_int64dtype + box = index_type.itemtype.box + + nd = len(impl.shape) + s = loop.count_all_true(self) + w_res = W_NDimArray.from_shape(space, [s, nd], index_type) + res_iter = w_res.create_iter() + + dims = range(nd) + while not arr_iter.done(): + if arr_iter.getitem_bool(): + for d in dims: + res_iter.setitem(box(arr_iter.indexes[d])) + res_iter.next() + arr_iter.next() + + w_res = w_res.implementation.swapaxes(space, w_res, 0, 1) + l_w = [w_res.descr_getitem(space, space.wrap(d)) for d in dims] + return space.newtuple(l_w) + def descr_take(self, space, w_obj, w_axis=None, w_out=None): # if w_axis is None and w_out is Nont this is an equivalent to # fancy indexing @@ -707,7 +732,7 @@ descr_conj = _unaryop_impl('conjugate') - def descr_nonzero(self, space): + def descr___nonzero__(self, space): if self.get_size() > 1: raise OperationError(space.w_ValueError, space.wrap( "The truth value of an array with more than one element is ambiguous. 
Use a.any() or a.all()")) @@ -995,7 +1020,7 @@ __neg__ = interp2app(W_NDimArray.descr_neg), __abs__ = interp2app(W_NDimArray.descr_abs), __invert__ = interp2app(W_NDimArray.descr_invert), - __nonzero__ = interp2app(W_NDimArray.descr_nonzero), + __nonzero__ = interp2app(W_NDimArray.descr___nonzero__), __add__ = interp2app(W_NDimArray.descr_add), __sub__ = interp2app(W_NDimArray.descr_sub), @@ -1070,6 +1095,7 @@ tolist = interp2app(W_NDimArray.descr_tolist), flatten = interp2app(W_NDimArray.descr_flatten), ravel = interp2app(W_NDimArray.descr_ravel), + nonzero = interp2app(W_NDimArray.descr_nonzero), take = interp2app(W_NDimArray.descr_take), compress = interp2app(W_NDimArray.descr_compress), repeat = interp2app(W_NDimArray.descr_repeat), diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -2300,6 +2300,13 @@ assert (arange(6).reshape(2, 3).ravel() == arange(6)).all() assert (arange(6).reshape(2, 3).T.ravel() == [0, 3, 1, 4, 2, 5]).all() + def test_nonzero(self): + from numpypy import array + a = array([[1, 0, 3], [2, 0, 4]]) + nz = a.nonzero() + assert (nz[0] == array([0, 0, 1, 1])).all() + assert (nz[1] == array([0, 2, 0, 2])).all() + def test_take(self): from numpypy import arange try: _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit