Author: Maciej Fijalkowski <fij...@gmail.com> Branch: numpy-reintroduce-jit-drivers Changeset: r57669:8415d52651ee Date: 2012-09-29 20:13 +0200 http://bitbucket.org/pypy/pypy/changeset/8415d52651ee/
Log: more jitdrivers diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -514,7 +514,7 @@ if self.get_size() == 0: raise OperationError(space.w_ValueError, space.wrap("Can't call %s on zero-size arrays" % op_name)) - return space.wrap(loop.argmin_argmax(op_name, self)) + return space.wrap(getattr(loop, op_name)(self)) return func_with_new_name(impl, "reduce_arg%s_impl" % op_name) descr_argmax = _reduce_argmax_argmin_impl("max") diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -4,7 +4,6 @@ over all the array elements. """ -from pypy.rlib.objectmodel import specialize from pypy.rlib.rstring import StringBuilder from pypy.rlib import jit from pypy.rpython.lltypesystem import lltype, rffi @@ -182,23 +181,41 @@ out_iter.next() return out -@specialize.arg(0) -def argmin_argmax(op_name, arr): - result = 0 - idx = 1 - dtype = arr.get_dtype() - iter = arr.create_iter(arr.get_shape()) - cur_best = iter.getitem() - iter.next() - while not iter.done(): - w_val = iter.getitem() - new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val) - if dtype.itemtype.ne(new_best, cur_best): - result = idx - cur_best = new_best + +def _new_argmin_argmax(op_name): + arg_driver = jit.JitDriver(name='numpy_' + op_name, + greens = ['shapelen', 'dtype'], + reds = ['result', 'idx', 'cur_best', 'arr', + 'iter']) + + def argmin_argmax(arr): + result = 0 + idx = 1 + dtype = arr.get_dtype() + iter = arr.create_iter(arr.get_shape()) + cur_best = iter.getitem() iter.next() - idx += 1 - return result + shapelen = len(arr.get_shape()) + while not iter.done(): + arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype, + result=result, idx=idx, + cur_best=cur_best, arr=arr, iter=iter) + w_val = iter.getitem() + new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val) + if dtype.itemtype.ne(new_best, cur_best): + result = idx + cur_best = new_best + iter.next() + idx += 1 + return result + return argmin_argmax +argmin = _new_argmin_argmax('argmin') +argmax = _new_argmin_argmax('argmax') + +# note that shapelen == 2 always +dot_driver = jit.JitDriver(name = 'numpy_dot', + greens = ['dtype'], + reds = ['outi', 'lefti', 'righti', 'result']) def multidim_dot(space, left, right, result, dtype, right_critical_dim): ''' assumes left, right are concrete arrays @@ -225,6 +242,8 @@ lefti = left.create_dot_iter(broadcast_shape, left_skip) righti = right.create_dot_iter(broadcast_shape, right_skip) while not outi.done(): + dot_driver.jit_merge_point(dtype=dtype, outi=outi, lefti=lefti, + righti=righti, result=result) lval = lefti.getitem().convert_to(dtype) rval = righti.getitem().convert_to(dtype) outval = outi.getitem().convert_to(dtype) @@ -236,21 +255,45 @@ lefti.next() return result +count_all_true_driver = jit.JitDriver(name = 'numpy_count', + greens = ['shapelen', 'dtype'], + reds = ['s', 'iter']) + def count_all_true(arr): s = 0 if arr.is_scalar(): return arr.get_dtype().itemtype.bool(arr.get_scalar_value()) iter = arr.create_iter() + shapelen = len(arr.get_shape()) + dtype = arr.get_dtype() while not iter.done(): + count_all_true_driver.jit_merge_point(shapelen=shapelen, iter=iter, + s=s, dtype=dtype) s += iter.getitem_bool() iter.next() return s +getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool', + greens = ['shapelen', 'arr_dtype', + 'index_dtype'], + reds = ['res', 'index_iter', 'res_iter', + 'arr_iter']) + def getitem_filter(res, arr, index): res_iter = res.create_iter() index_iter = index.create_iter() arr_iter = arr.create_iter() + shapelen = len(arr.get_shape()) + arr_dtype = arr.get_dtype() + index_dtype = index.get_dtype() + # XXX length of shape of index as well? while not index_iter.done(): + getitem_filter_driver.jit_merge_point(shapelen=shapelen, + index_dtype=index_dtype, + arr_dtype=arr_dtype, + res=res, index_iter=index_iter, + res_iter=res_iter, + arr_iter=arr_iter) if index_iter.getitem_bool(): res_iter.setitem(arr_iter.getitem()) res_iter.next() @@ -258,31 +301,63 @@ arr_iter.next() return res +setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool', + greens = ['shapelen', 'arr_dtype', + 'index_dtype'], + reds = ['index_iter', 'value_iter', + 'arr_iter']) + def setitem_filter(arr, index, value): arr_iter = arr.create_iter() index_iter = index.create_iter() value_iter = value.create_iter() + shapelen = len(arr.get_shape()) + index_dtype = index.get_dtype() + arr_dtype = arr.get_dtype() while not index_iter.done(): + setitem_filter_driver.jit_merge_point(shapelen=shapelen, + index_dtype=index_dtype, + arr_dtype=arr_dtype, + index_iter=index_iter, + value_iter=value_iter, + arr_iter=arr_iter) if index_iter.getitem_bool(): arr_iter.setitem(value_iter.getitem()) value_iter.next() arr_iter.next() index_iter.next() +flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem', + greens = ['dtype'], + reds = ['step', 'ri', 'res', + 'base_iter']) + def flatiter_getitem(res, base_iter, step): ri = res.create_iter() + dtype = res.get_dtype() while not ri.done(): + flatiter_getitem_driver.jit_merge_point(dtype=dtype, + base_iter=base_iter, + ri=ri, res=res, step=step) ri.setitem(base_iter.getitem()) base_iter.next_skip_x(step) ri.next() return res +flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem', + greens = ['dtype'], + reds = ['length', 'step', 'arr_iter', + 'val_iter']) + def flatiter_setitem(arr, val, start, step, length): dtype = arr.get_dtype() arr_iter = arr.create_iter() val_iter = val.create_iter() arr_iter.next_skip_x(start) while length > 0: + flatiter_setitem_driver.jit_merge_point(dtype=dtype, length=length, + step=step, arr_iter=arr_iter, + val_iter=val_iter) arr_iter.setitem(val_iter.getitem().convert_to(dtype)) # need to repeat i_nput values until all assignments are done arr_iter.next_skip_x(step) @@ -291,10 +366,15 @@ # WTF numpy? val_iter.reset() +fromstring_driver = jit.JitDriver(name = 'numpy_fromstring', + greens = ['dtype'], + reds = ['s', 'ai', 'i']) + def fromstring_loop(a, dtype, itemsize, s): i = 0 ai = a.create_iter() while not ai.done(): + fromstring_driver.jit_merge_point(dtype=dtype, s=s, ai=ai, i=i) val = dtype.itemtype.runpack_str(s[i*itemsize:i*itemsize + itemsize]) ai.setitem(val) ai.next() @@ -345,9 +425,21 @@ def get_index(self, space): return [space.wrap(i) for i in self.indexes] +getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int', + greens = ['shapelen', 'indexlen', 'dtype'], + reds = ['arr', 'res', 'iter', 'indexes_w', + 'prefix_w']) + def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w): + shapelen = len(iter_shape) + indexlen = len(indexes_w) + dtype = arr.get_dtype() iter = PureShapeIterator(iter_shape, indexes_w) while not iter.done(): + getitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen, + dtype=dtype, arr=arr, res=res, + iter=iter, indexes_w=indexes_w, + prefix_w=prefix_w) # prepare the index index_w = [None] * len(indexes_w) for i in range(len(indexes_w)): @@ -361,10 +453,22 @@ iter.next() return res +setitem_int_driver = jit.JitDriver(name = 'numpy_setitem_int', + greens = ['shapelen', 'indexlen', 'dtype'], + reds = ['arr', 'iter', 'indexes_w', + 'prefix_w', 'val_arr']) + def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr, prefix_w): + shapelen = len(iter_shape) + indexlen = len(indexes_w) + dtype = arr.get_dtype() iter = PureShapeIterator(iter_shape, indexes_w) while not iter.done(): + setitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen, + dtype=dtype, arr=arr, + iter=iter, indexes_w=indexes_w, + prefix_w=prefix_w, val_arr=val_arr) # prepare the index index_w = [None] * len(indexes_w) for i in range(len(indexes_w)): _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit