Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r73868:188aa764d2e7 Date: 2014-10-09 15:18 -0400 http://bitbucket.org/pypy/pypy/changeset/188aa764d2e7/
Log: fix searchsorted with multidim targets diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py --- a/pypy/module/micronumpy/ndarray.py +++ b/pypy/module/micronumpy/ndarray.py @@ -738,8 +738,6 @@ if len(self.get_shape()) > 1: raise oefmt(space.w_ValueError, "a must be a 1-d array") v = convert_to_array(space, w_v) - if len(v.get_shape()) > 1: - raise oefmt(space.w_ValueError, "v must be a 1-d array-like") ret = W_NDimArray.from_shape( space, v.get_shape(), descriptor.get_dtype_cache(space).w_longdtype) app_searchsort(space, self, v, space.wrap(side), ret) diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py --- a/pypy/module/micronumpy/selection.py +++ b/pypy/module/micronumpy/selection.py @@ -375,9 +375,8 @@ op = operator.lt else: op = operator.le - if v.size < 2: - result[...] = _searchsort(a, op, v) - else: - for i in range(v.size): - result[i] = _searchsort(a, op, v[i]) + v = v.flat + result = result.flat + for i in xrange(len(v)): + result[i] = _searchsort(a, op, v[i]) """, filename=__file__).interphook('searchsort') diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py --- a/pypy/module/micronumpy/test/test_selection.py +++ b/pypy/module/micronumpy/test/test_selection.py @@ -354,25 +354,36 @@ import numpy as np import sys a = np.arange(1, 6) + ret = a.searchsorted(3) assert ret == 2 assert isinstance(ret, np.generic) + ret = a.searchsorted(np.array(3)) assert ret == 2 assert isinstance(ret, np.generic) + ret = a.searchsorted(np.array([3])) assert ret == 2 assert isinstance(ret, np.ndarray) + + ret = a.searchsorted(np.array([[2, 3]])) + assert (ret == [1, 2]).all() + assert ret.shape == (1, 2) + ret = a.searchsorted(3, side='right') assert ret == 3 assert isinstance(ret, np.generic) + exc = raises(ValueError, a.searchsorted, 3, side=None) assert str(exc.value) == "expected nonempty string for keyword 'side'" exc = raises(ValueError, a.searchsorted, 3, side='') assert str(exc.value) == "expected nonempty string for keyword 'side'" exc = raises(ValueError, a.searchsorted, 3, side=2) assert str(exc.value) == "expected nonempty string for keyword 'side'" + ret = a.searchsorted([-10, 10, 2, 3]) assert (ret == [0, 5, 1, 2]).all() + if '__pypy__' in sys.builtin_module_names: raises(NotImplementedError, "a.searchsorted(3, sorter=range(6))") _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit