Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-reintroduce-jit-drivers
Changeset: r57669:8415d52651ee
Date: 2012-09-29 20:13 +0200
http://bitbucket.org/pypy/pypy/changeset/8415d52651ee/
Log: more jitdrivers
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -514,7 +514,7 @@
if self.get_size() == 0:
raise OperationError(space.w_ValueError,
space.wrap("Can't call %s on zero-size arrays" % op_name))
- return space.wrap(loop.argmin_argmax(op_name, self))
+ return space.wrap(getattr(loop, 'arg' + op_name)(self))
return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
descr_argmax = _reduce_argmax_argmin_impl("max")
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,7 +4,6 @@
over all the array elements.
"""
-from pypy.rlib.objectmodel import specialize
from pypy.rlib.rstring import StringBuilder
from pypy.rlib import jit
from pypy.rpython.lltypesystem import lltype, rffi
@@ -182,23 +181,41 @@
out_iter.next()
return out
-@specialize.arg(0)
-def argmin_argmax(op_name, arr):
- result = 0
- idx = 1
- dtype = arr.get_dtype()
- iter = arr.create_iter(arr.get_shape())
- cur_best = iter.getitem()
- iter.next()
- while not iter.done():
- w_val = iter.getitem()
- new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
- if dtype.itemtype.ne(new_best, cur_best):
- result = idx
- cur_best = new_best
+
+def _new_argmin_argmax(op_name):
+ arg_driver = jit.JitDriver(name='numpy_arg' + op_name,
+ greens = ['shapelen', 'dtype'],
+ reds = ['result', 'idx', 'cur_best', 'arr',
+ 'iter'])
+
+ def argmin_argmax(arr):
+ result = 0
+ idx = 1
+ dtype = arr.get_dtype()
+ iter = arr.create_iter(arr.get_shape())
+ cur_best = iter.getitem()
iter.next()
- idx += 1
- return result
+ shapelen = len(arr.get_shape())
+ while not iter.done():
+ arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+ result=result, idx=idx,
+ cur_best=cur_best, arr=arr, iter=iter)
+ w_val = iter.getitem()
+ new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+ if dtype.itemtype.ne(new_best, cur_best):
+ result = idx
+ cur_best = new_best
+ iter.next()
+ idx += 1
+ return result
+ return argmin_argmax
+argmin = _new_argmin_argmax('min')
+argmax = _new_argmin_argmax('max')
+
+# note that shapelen == 2 always
+dot_driver = jit.JitDriver(name = 'numpy_dot',
+ greens = ['dtype'],
+ reds = ['outi', 'lefti', 'righti', 'result'])
def multidim_dot(space, left, right, result, dtype, right_critical_dim):
''' assumes left, right are concrete arrays
@@ -225,6 +242,8 @@
lefti = left.create_dot_iter(broadcast_shape, left_skip)
righti = right.create_dot_iter(broadcast_shape, right_skip)
while not outi.done():
+ dot_driver.jit_merge_point(dtype=dtype, outi=outi, lefti=lefti,
+ righti=righti, result=result)
lval = lefti.getitem().convert_to(dtype)
rval = righti.getitem().convert_to(dtype)
outval = outi.getitem().convert_to(dtype)
@@ -236,21 +255,45 @@
lefti.next()
return result
+count_all_true_driver = jit.JitDriver(name = 'numpy_count',
+ greens = ['shapelen', 'dtype'],
+ reds = ['s', 'iter'])
+
def count_all_true(arr):
s = 0
if arr.is_scalar():
return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
iter = arr.create_iter()
+ shapelen = len(arr.get_shape())
+ dtype = arr.get_dtype()
while not iter.done():
+ count_all_true_driver.jit_merge_point(shapelen=shapelen, iter=iter,
+ s=s, dtype=dtype)
s += iter.getitem_bool()
iter.next()
return s
+getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool',
+ greens = ['shapelen', 'arr_dtype',
+ 'index_dtype'],
+ reds = ['res', 'index_iter', 'res_iter',
+ 'arr_iter'])
+
def getitem_filter(res, arr, index):
res_iter = res.create_iter()
index_iter = index.create_iter()
arr_iter = arr.create_iter()
+ shapelen = len(arr.get_shape())
+ arr_dtype = arr.get_dtype()
+ index_dtype = index.get_dtype()
+ # XXX length of shape of index as well?
while not index_iter.done():
+ getitem_filter_driver.jit_merge_point(shapelen=shapelen,
+ index_dtype=index_dtype,
+ arr_dtype=arr_dtype,
+ res=res, index_iter=index_iter,
+ res_iter=res_iter,
+ arr_iter=arr_iter)
if index_iter.getitem_bool():
res_iter.setitem(arr_iter.getitem())
res_iter.next()
@@ -258,31 +301,63 @@
arr_iter.next()
return res
+setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
+ greens = ['shapelen', 'arr_dtype',
+ 'index_dtype'],
+ reds = ['index_iter', 'value_iter',
+ 'arr_iter'])
+
def setitem_filter(arr, index, value):
arr_iter = arr.create_iter()
index_iter = index.create_iter()
value_iter = value.create_iter()
+ shapelen = len(arr.get_shape())
+ index_dtype = index.get_dtype()
+ arr_dtype = arr.get_dtype()
while not index_iter.done():
+ setitem_filter_driver.jit_merge_point(shapelen=shapelen,
+ index_dtype=index_dtype,
+ arr_dtype=arr_dtype,
+ index_iter=index_iter,
+ value_iter=value_iter,
+ arr_iter=arr_iter)
if index_iter.getitem_bool():
arr_iter.setitem(value_iter.getitem())
value_iter.next()
arr_iter.next()
index_iter.next()
+flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
+ greens = ['dtype'],
+ reds = ['step', 'ri', 'res',
+ 'base_iter'])
+
def flatiter_getitem(res, base_iter, step):
ri = res.create_iter()
+ dtype = res.get_dtype()
while not ri.done():
+ flatiter_getitem_driver.jit_merge_point(dtype=dtype,
+ base_iter=base_iter,
+ ri=ri, res=res, step=step)
ri.setitem(base_iter.getitem())
base_iter.next_skip_x(step)
ri.next()
return res
+flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
+ greens = ['dtype'],
+ reds = ['length', 'step', 'arr_iter',
+ 'val_iter'])
+
def flatiter_setitem(arr, val, start, step, length):
dtype = arr.get_dtype()
arr_iter = arr.create_iter()
val_iter = val.create_iter()
arr_iter.next_skip_x(start)
while length > 0:
+ flatiter_setitem_driver.jit_merge_point(dtype=dtype, length=length,
+ step=step, arr_iter=arr_iter,
+ val_iter=val_iter)
arr_iter.setitem(val_iter.getitem().convert_to(dtype))
# need to repeat i_nput values until all assignments are done
arr_iter.next_skip_x(step)
@@ -291,10 +366,15 @@
# WTF numpy?
val_iter.reset()
+fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
+ greens = ['itemsize', 'dtype'],
+ reds = ['s', 'ai', 'i'])
+
def fromstring_loop(a, dtype, itemsize, s):
i = 0
ai = a.create_iter()
while not ai.done():
+ fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize, s=s, ai=ai, i=i)
val = dtype.itemtype.runpack_str(s[i*itemsize:i*itemsize + itemsize])
ai.setitem(val)
ai.next()
@@ -345,9 +425,21 @@
def get_index(self, space):
return [space.wrap(i) for i in self.indexes]
+getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
+ greens = ['shapelen', 'indexlen', 'dtype'],
+ reds = ['arr', 'res', 'iter', 'indexes_w',
+ 'prefix_w'])
+
def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w):
+ shapelen = len(iter_shape)
+ indexlen = len(indexes_w)
+ dtype = arr.get_dtype()
iter = PureShapeIterator(iter_shape, indexes_w)
while not iter.done():
+ getitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
+ dtype=dtype, arr=arr, res=res,
+ iter=iter, indexes_w=indexes_w,
+ prefix_w=prefix_w)
# prepare the index
index_w = [None] * len(indexes_w)
for i in range(len(indexes_w)):
@@ -361,10 +453,22 @@
iter.next()
return res
+setitem_int_driver = jit.JitDriver(name = 'numpy_setitem_int',
+ greens = ['shapelen', 'indexlen', 'dtype'],
+ reds = ['arr', 'iter', 'indexes_w',
+ 'prefix_w', 'val_arr'])
+
def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr,
prefix_w):
+ shapelen = len(iter_shape)
+ indexlen = len(indexes_w)
+ dtype = arr.get_dtype()
iter = PureShapeIterator(iter_shape, indexes_w)
while not iter.done():
+ setitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
+ dtype=dtype, arr=arr,
+ iter=iter, indexes_w=indexes_w,
+ prefix_w=prefix_w, val_arr=val_arr)
# prepare the index
index_w = [None] * len(indexes_w)
for i in range(len(indexes_w)):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit