Author: Stefan H. Muller <shmuell...@gmail.com> Branch: pypy-pyarray Changeset: r66350:b70301c90922 Date: 2013-08-12 23:49 +0200 http://bitbucket.org/pypy/pypy/changeset/b70301c90922/
Log: Split nonzero() between scalar.py, concrete.py and loop.py. - Separate implementations for the 1D and ND cases. Try to reunify them in the next commit. diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py --- a/pypy/module/micronumpy/arrayimpl/concrete.py +++ b/pypy/module/micronumpy/arrayimpl/concrete.py @@ -279,6 +279,22 @@ return W_NDimArray.new_slice(space, self.start, strides, backstrides, shape, self, orig_arr) + def nonzero(self, space, index_type): + s = loop.count_all_true_concrete(self) + box = index_type.itemtype.box + nd = len(self.shape) + + if nd == 1: + w_res = W_NDimArray.from_shape(space, [s], index_type) + loop.nonzero_onedim(w_res, self, box) + return space.newtuple([w_res]) + else: + w_res = W_NDimArray.from_shape(space, [s, nd], index_type) + loop.nonzero_multidim(w_res, self, box) + w_res = w_res.implementation.swapaxes(space, w_res, 0, 1) + l_w = [w_res.descr_getitem(space, space.wrap(d)) for d in range(nd)] + return space.newtuple(l_w) + def get_storage_as_int(self, space): return rffi.cast(lltype.Signed, self.storage) + self.start diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py --- a/pypy/module/micronumpy/arrayimpl/scalar.py +++ b/pypy/module/micronumpy/arrayimpl/scalar.py @@ -155,6 +155,13 @@ def swapaxes(self, space, orig_array, axis1, axis2): raise Exception("should not be called") + def nonzero(self, space, index_type): + s = self.dtype.itemtype.bool(self.value) + w_res = W_NDimArray.from_shape(space, [s], index_type) + if s == 1: + w_res.implementation.setitem(0, index_type.itemtype.box(0)) + return space.newtuple([w_res]) + def fill(self, w_value): self.value = w_value diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -332,6 +332,10 @@ return self return self.implementation.swapaxes(space, 
self, axis1, axis2) + def descr_nonzero(self, space): + index_type = interp_dtype.get_dtype_cache(space).w_int64dtype + return self.implementation.nonzero(space, index_type) + def descr_tolist(self, space): if len(self.get_shape()) == 0: return self.get_scalar_value().item(space) @@ -351,37 +355,6 @@ "order not implemented")) return self.descr_reshape(space, [space.wrap(-1)]) - def descr_nonzero(self, space): - s = loop.count_all_true(self) - index_type = interp_dtype.get_dtype_cache(space).w_int64dtype - box = index_type.itemtype.box - - if self.is_scalar(): - w_res = W_NDimArray.from_shape(space, [s], index_type) - if s == 1: - w_res.implementation.setitem(0, box(0)) - return space.newtuple([w_res]) - - impl = self.implementation - arr_iter = iter.MultiDimViewIterator(impl, impl.dtype, 0, - impl.strides, impl.backstrides, impl.shape) - - nd = len(impl.shape) - w_res = W_NDimArray.from_shape(space, [s, nd], index_type) - res_iter = w_res.create_iter() - - dims = range(nd) - while not arr_iter.done(): - if arr_iter.getitem_bool(): - for d in dims: - res_iter.setitem(box(arr_iter.indexes[d])) - res_iter.next() - arr_iter.next() - - w_res = w_res.implementation.swapaxes(space, w_res, 0, 1) - l_w = [w_res.descr_getitem(space, space.wrap(d)) for d in dims] - return space.newtuple(l_w) - def descr_take(self, space, w_obj, w_axis=None, w_out=None): # if w_axis is None and w_out is Nont this is an equivalent to # fancy indexing @@ -1101,11 +1074,11 @@ tolist = interp2app(W_NDimArray.descr_tolist), flatten = interp2app(W_NDimArray.descr_flatten), ravel = interp2app(W_NDimArray.descr_ravel), - nonzero = interp2app(W_NDimArray.descr_nonzero), take = interp2app(W_NDimArray.descr_take), compress = interp2app(W_NDimArray.descr_compress), repeat = interp2app(W_NDimArray.descr_repeat), swapaxes = interp2app(W_NDimArray.descr_swapaxes), + nonzero = interp2app(W_NDimArray.descr_nonzero), flat = GetSetProperty(W_NDimArray.descr_get_flatiter), item = 
interp2app(W_NDimArray.descr_item), real = GetSetProperty(W_NDimArray.descr_get_real, diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -9,7 +9,8 @@ from rpython.rlib import jit from rpython.rtyper.lltypesystem import lltype, rffi from pypy.module.micronumpy.base import W_NDimArray -from pypy.module.micronumpy.iter import PureShapeIterator +from pypy.module.micronumpy.iter import PureShapeIterator, OneDimViewIterator, \ + MultiDimViewIterator from pypy.module.micronumpy import constants from pypy.module.micronumpy.support import int_w @@ -323,19 +324,61 @@ greens = ['shapelen', 'dtype'], reds = 'auto') -def count_all_true(arr): +def count_all_true_concrete(impl): s = 0 - if arr.is_scalar(): - return arr.get_dtype().itemtype.bool(arr.get_scalar_value()) - iter = arr.create_iter() - shapelen = len(arr.get_shape()) - dtype = arr.get_dtype() + iter = impl.create_iter() + shapelen = len(impl.shape) + dtype = impl.dtype while not iter.done(): count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype) s += iter.getitem_bool() iter.next() return s +def count_all_true(arr): + if arr.is_scalar(): + return arr.get_dtype().itemtype.bool(arr.get_scalar_value()) + else: + return count_all_true_concrete(arr.implementation) + +nonzero_driver_onedim = jit.JitDriver(name = 'numpy_nonzero_onedim', + greens = ['shapelen', 'dtype'], + reds = 'auto') + +def nonzero_onedim(res, arr, box): + res_iter = res.create_iter() + arr_iter = OneDimViewIterator(arr, arr.dtype, 0, + arr.strides, arr.shape) + shapelen = 1 + dtype = arr.dtype + while not arr_iter.done(): + nonzero_driver_onedim.jit_merge_point(shapelen=shapelen, dtype=dtype) + if arr_iter.getitem_bool(): + res_iter.setitem(box(arr_iter.index)) + res_iter.next() + arr_iter.next() + return res + +nonzero_driver_multidim = jit.JitDriver(name = 'numpy_nonzero_onedim', + greens = ['shapelen', 'dims', 'dtype'], + reds = 'auto') 
+ +def nonzero_multidim(res, arr, box): + res_iter = res.create_iter() + arr_iter = MultiDimViewIterator(arr, arr.dtype, 0, + arr.strides, arr.backstrides, arr.shape) + shapelen = len(arr.shape) + dtype = arr.dtype + dims = range(shapelen) + while not arr_iter.done(): + nonzero_driver_multidim.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype) + if arr_iter.getitem_bool(): + for d in dims: + res_iter.setitem(box(arr_iter.indexes[d])) + res_iter.next() + arr_iter.next() + return res + getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool', greens = ['shapelen', 'arr_dtype', 'index_dtype'], diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1348,7 +1348,7 @@ for i in xrange(5): assert c[i] == func(b[i], 3) - def test_nonzero(self): + def test___nonzero__(self): from numpypy import array a = array([1, 2]) raises(ValueError, bool, a) @@ -2306,11 +2306,14 @@ assert nz[0].size == 0 nz = array(2).nonzero() - assert (nz[0] == array([0])).all() + assert (nz[0] == [0]).all() + + nz = array([1, 0, 3]).nonzero() + assert (nz[0] == [0, 2]).all() nz = array([[1, 0, 3], [2, 0, 4]]).nonzero() - assert (nz[0] == array([0, 0, 1, 1])).all() - assert (nz[1] == array([0, 2, 0, 2])).all() + assert (nz[0] == [0, 0, 1, 1]).all() + assert (nz[1] == [0, 2, 0, 2]).all() def test_take(self): from numpypy import arange _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit