Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r73877:0462e4a83ff1 Date: 2014-10-09 18:03 -0400 http://bitbucket.org/pypy/pypy/changeset/0462e4a83ff1/
Log: implement searchsorted in rpython with jitdriver diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py --- a/pypy/module/micronumpy/compile.py +++ b/pypy/module/micronumpy/compile.py @@ -36,7 +36,7 @@ SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any", "unegative", "flat", "tostring","count_nonzero", "argsort"] -TWO_ARG_FUNCTIONS = ["dot", 'take'] +TWO_ARG_FUNCTIONS = ["dot", 'take', 'searchsorted'] TWO_ARG_FUNCTIONS_OR_NONE = ['view', 'astype'] THREE_ARG_FUNCTIONS = ['where'] @@ -109,6 +109,9 @@ if stop < 0: stop += size + 1 if step < 0: + start, stop = stop, start + start -= 1 + stop -= 1 lgt = (stop - start + 1) / step + 1 else: lgt = (stop - start - 1) / step + 1 @@ -475,7 +478,6 @@ class SliceConstant(Node): def __init__(self, start, stop, step): - # no negative support for now self.start = start self.stop = stop self.step = step @@ -582,6 +584,9 @@ w_res = arr.descr_dot(interp.space, arg) elif self.name == 'take': w_res = arr.descr_take(interp.space, arg) + elif self.name == "searchsorted": + w_res = arr.descr_searchsorted(interp.space, arg, + interp.space.wrap('left')) else: assert False # unreachable code elif self.name in THREE_ARG_FUNCTIONS: diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -700,3 +700,43 @@ out_iter.setitem(out_state, arr.getitem_index(space, indexes)) iter.next() out_state = out_iter.next(out_state) + +def _new_binsearch(side, op_name): + binsearch_driver = jit.JitDriver(name='numpy_binsearch_' + side, + greens=['dtype'], + reds='auto') + + def binsearch(space, arr, key, ret): + assert len(arr.get_shape()) == 1 + dtype = key.get_dtype() + op = getattr(dtype.itemtype, op_name) + key_iter, key_state = key.create_iter() + ret_iter, ret_state = ret.create_iter() + ret_iter.track_index = False + size = arr.get_size() + min_idx = 0 + max_idx = size + last_key_val = key_iter.getitem(key_state) + while not key_iter.done(key_state): + key_val = key_iter.getitem(key_state) + if dtype.itemtype.lt(last_key_val, key_val): + max_idx = size + else: + min_idx = 0 + max_idx = max_idx + 1 if max_idx < size else size + last_key_val = key_val + while min_idx < max_idx: + binsearch_driver.jit_merge_point(dtype=dtype) + mid_idx = min_idx + ((max_idx - min_idx) >> 1) + mid_val = arr.getitem(space, [mid_idx]).convert_to(space, dtype) + if op(mid_val, key_val): + min_idx = mid_idx + 1 + else: + max_idx = mid_idx + ret_iter.setitem(ret_state, ret.get_dtype().box(min_idx)) + ret_state = ret_iter.next(ret_state) + key_state = key_iter.next(key_state) + return binsearch + +binsearch_left = _new_binsearch('left', 'lt') +binsearch_right = _new_binsearch('right', 'le') diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py --- a/pypy/module/micronumpy/ndarray.py +++ b/pypy/module/micronumpy/ndarray.py @@ -20,7 +20,6 @@ from pypy.module.micronumpy.flagsobj import W_FlagsObject from pypy.module.micronumpy.strides import get_shape_from_iterable, \ shape_agreement, shape_agreement_multiple -from .selection import app_searchsort def _match_dot_shapes(space, left, right): @@ -740,7 +739,11 @@ v = convert_to_array(space, w_v) ret = W_NDimArray.from_shape( space, v.get_shape(), descriptor.get_dtype_cache(space).w_longdtype) - app_searchsort(space, self, v, space.wrap(side), ret) + if side == NPY.SEARCHLEFT: + binsearch = loop.binsearch_left + else: + binsearch = loop.binsearch_right + binsearch(space, self, v, ret) if ret.is_scalar(): return ret.get_scalar_value() return ret diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py --- a/pypy/module/micronumpy/selection.py +++ b/pypy/module/micronumpy/selection.py @@ -1,5 +1,4 @@ from pypy.interpreter.error import oefmt -from pypy.interpreter.gateway import applevel from rpython.rlib.listsort import make_timsort_class from rpython.rlib.objectmodel import specialize from rpython.rlib.rarithmetic import widen @@ -354,39 +353,3 @@ cache[cls] = make_sort_function(space, cls, it) self.cache = cache self._lookup = specialize.memo()(lambda tp: cache[tp[0]]) - - -app_searchsort = applevel(r""" - import operator - - def searchsort(arr, val, side, res): - val = val.flat - res = res.flat - if side == 0: - op = operator.lt - else: - op = operator.le - - size = arr.size - imin = 0 - imax = size - try: - last = val[0] - except IndexError: - return - for i in xrange(len(val)): - key = val[i] - if last < key: - imax = size - else: - imin = 0 - imax = imax + 1 if imax < size else size - last = key - while imin < imax: - imid = imin + ((imax - imin) >> 1) - if op(arr[imid], key): - imin = imid + 1 - else: - imax = imid - res[i] = imin -""", filename=__file__).interphook('searchsort') diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py --- a/pypy/module/micronumpy/test/test_compile.py +++ b/pypy/module/micronumpy/test/test_compile.py @@ -330,3 +330,12 @@ results = interp.results[0] assert isinstance(results, W_NDimArray) assert results.get_dtype().is_int() + + def test_searchsorted(self): + interp = self.run(''' + a = [1, 4, 5, 6, 9] + b = |30| -> ::-1 + c = searchsorted(a, b) + c -> -1 + ''') + assert interp.results[0].value == 0 diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py --- a/pypy/module/micronumpy/test/test_selection.py +++ b/pypy/module/micronumpy/test/test_selection.py @@ -382,6 +382,9 @@ assert ret == 3 assert isinstance(ret, np.generic) + assert a.searchsorted(3.1) == 3 + assert a.searchsorted(3.9) == 3 + exc = raises(ValueError, a.searchsorted, 3, side=None) assert str(exc.value) == "expected nonempty string for keyword 'side'" exc = raises(ValueError, a.searchsorted, 3, side='') diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py --- a/pypy/module/micronumpy/test/test_zjit.py +++ b/pypy/module/micronumpy/test/test_zjit.py @@ -51,7 +51,9 @@ w_res = i.getitem(s) if isinstance(w_res, boxes.W_Float64Box): return w_res.value - if isinstance(w_res, boxes.W_Int64Box): + elif isinstance(w_res, boxes.W_Int64Box): + return float(w_res.value) + elif isinstance(w_res, boxes.W_LongBox): return float(w_res.value) elif isinstance(w_res, boxes.W_BoolBox): return float(w_res.value) @@ -660,3 +662,30 @@ 'raw_load': 2, 'raw_store': 1, }) + + def define_searchsorted(): + return """ + a = [1, 4, 5, 6, 9] + b = |30| -> ::-1 + c = searchsorted(a, b) + c -> -1 + """ + + def test_searchsorted(self): + result = self.run("searchsorted") + assert result == 0 + self.check_trace_count(6) + self.check_simple_loop({ + 'float_lt': 1, + 'guard_false': 2, + 'guard_not_invalidated': 1, + 'guard_true': 2, + 'int_add': 3, + 'int_ge': 1, + 'int_lt': 2, + 'int_mul': 1, + 'int_rshift': 1, + 'int_sub': 1, + 'jump': 1, + 'raw_load': 1, + }) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit