Author: Maciej Fijalkowski <fij...@gmail.com>
Branch: numpy-reintroduce-jit-drivers
Changeset: r57669:8415d52651ee
Date: 2012-09-29 20:13 +0200
http://bitbucket.org/pypy/pypy/changeset/8415d52651ee/

Log:    more jitdrivers (add JitDrivers and jit_merge_points to the micronumpy loops)

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -514,7 +514,7 @@
             if self.get_size() == 0:
                 raise OperationError(space.w_ValueError,
                     space.wrap("Can't call %s on zero-size arrays" % op_name))
-            return space.wrap(loop.argmin_argmax(op_name, self))
+            return space.wrap(getattr(loop, op_name)(self))
         return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
 
     descr_argmax = _reduce_argmax_argmin_impl("max")
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,7 +4,6 @@
 over all the array elements.
 """
 
-from pypy.rlib.objectmodel import specialize
 from pypy.rlib.rstring import StringBuilder
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
@@ -182,23 +181,41 @@
         out_iter.next()
     return out
 
-@specialize.arg(0)
-def argmin_argmax(op_name, arr):
-    result = 0
-    idx = 1
-    dtype = arr.get_dtype()
-    iter = arr.create_iter(arr.get_shape())
-    cur_best = iter.getitem()
-    iter.next()
-    while not iter.done():
-        w_val = iter.getitem()
-        new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
-        if dtype.itemtype.ne(new_best, cur_best):
-            result = idx
-            cur_best = new_best
+
+def _new_argmin_argmax(op_name):
+    arg_driver = jit.JitDriver(name='numpy_' + op_name,
+                               greens = ['shapelen', 'dtype'],
+                               reds = ['result', 'idx', 'cur_best', 'arr',
+                                       'iter'])
+    
+    def argmin_argmax(arr):
+        result = 0
+        idx = 1
+        dtype = arr.get_dtype()
+        iter = arr.create_iter(arr.get_shape())
+        cur_best = iter.getitem()
         iter.next()
-        idx += 1
-    return result
+        shapelen = len(arr.get_shape())
+        while not iter.done():
+            arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                       result=result, idx=idx,
+                                       cur_best=cur_best, arr=arr, iter=iter)
+            w_val = iter.getitem()
+            new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+            if dtype.itemtype.ne(new_best, cur_best):
+                result = idx
+                cur_best = new_best
+            iter.next()
+            idx += 1
+        return result
+    return argmin_argmax
+argmin = _new_argmin_argmax('argmin')
+argmax = _new_argmin_argmax('argmax')
+
+# note that shapelen == 2 always
+dot_driver = jit.JitDriver(name = 'numpy_dot',
+                           greens = ['dtype'],
+                           reds = ['outi', 'lefti', 'righti', 'result'])
 
 def multidim_dot(space, left, right, result, dtype, right_critical_dim):
     ''' assumes left, right are concrete arrays
@@ -225,6 +242,8 @@
     lefti = left.create_dot_iter(broadcast_shape, left_skip)
     righti = right.create_dot_iter(broadcast_shape, right_skip)
     while not outi.done():
+        dot_driver.jit_merge_point(dtype=dtype, outi=outi, lefti=lefti,
+                                   righti=righti, result=result)
         lval = lefti.getitem().convert_to(dtype) 
         rval = righti.getitem().convert_to(dtype) 
         outval = outi.getitem().convert_to(dtype) 
@@ -236,21 +255,45 @@
         lefti.next()
     return result
 
+count_all_true_driver = jit.JitDriver(name = 'numpy_count',
+                                      greens = ['shapelen', 'dtype'],
+                                      reds = ['s', 'iter'])
+
 def count_all_true(arr):
     s = 0
     if arr.is_scalar():
         return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
     iter = arr.create_iter()
+    shapelen = len(arr.get_shape())
+    dtype = arr.get_dtype()
     while not iter.done():
+        count_all_true_driver.jit_merge_point(shapelen=shapelen, iter=iter,
+                                              s=s, dtype=dtype)
         s += iter.getitem_bool()
         iter.next()
     return s
 
+getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool',
+                                      greens = ['shapelen', 'arr_dtype',
+                                                'index_dtype'],
+                                      reds = ['res', 'index_iter', 'res_iter',
+                                              'arr_iter'])
+
 def getitem_filter(res, arr, index):
     res_iter = res.create_iter()
     index_iter = index.create_iter()
     arr_iter = arr.create_iter()
+    shapelen = len(arr.get_shape())
+    arr_dtype = arr.get_dtype()
+    index_dtype = index.get_dtype()
+    # XXX length of shape of index as well?
     while not index_iter.done():
+        getitem_filter_driver.jit_merge_point(shapelen=shapelen,
+                                              index_dtype=index_dtype,
+                                              arr_dtype=arr_dtype,
+                                              res=res, index_iter=index_iter,
+                                              res_iter=res_iter,
+                                              arr_iter=arr_iter)
         if index_iter.getitem_bool():
             res_iter.setitem(arr_iter.getitem())
             res_iter.next()
@@ -258,31 +301,63 @@
         arr_iter.next()
     return res
 
+setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
+                                      greens = ['shapelen', 'arr_dtype',
+                                                'index_dtype'],
+                                      reds = ['index_iter', 'value_iter',
+                                              'arr_iter'])
+
 def setitem_filter(arr, index, value):
     arr_iter = arr.create_iter()
     index_iter = index.create_iter()
     value_iter = value.create_iter()
+    shapelen = len(arr.get_shape())
+    index_dtype = index.get_dtype()
+    arr_dtype = arr.get_dtype()
     while not index_iter.done():
+        setitem_filter_driver.jit_merge_point(shapelen=shapelen,
+                                              index_dtype=index_dtype,
+                                              arr_dtype=arr_dtype,
+                                              index_iter=index_iter,
+                                              value_iter=value_iter,
+                                              arr_iter=arr_iter)
         if index_iter.getitem_bool():
             arr_iter.setitem(value_iter.getitem())
             value_iter.next()
         arr_iter.next()
         index_iter.next()
 
+flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
+                                        greens = ['dtype'],
+                                        reds = ['step', 'ri', 'res',
+                                                'base_iter'])
+
 def flatiter_getitem(res, base_iter, step):
     ri = res.create_iter()
+    dtype = res.get_dtype()
     while not ri.done():
+        flatiter_getitem_driver.jit_merge_point(dtype=dtype,
+                                                base_iter=base_iter,
+                                                ri=ri, res=res, step=step)
         ri.setitem(base_iter.getitem())
         base_iter.next_skip_x(step)
         ri.next()
     return res
 
+flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
+                                        greens = ['dtype'],
+                                        reds = ['length', 'step', 'arr_iter',
+                                                'val_iter'])
+
 def flatiter_setitem(arr, val, start, step, length):
     dtype = arr.get_dtype()
     arr_iter = arr.create_iter()
     val_iter = val.create_iter()
     arr_iter.next_skip_x(start)
     while length > 0:
+        flatiter_setitem_driver.jit_merge_point(dtype=dtype, length=length,
+                                                step=step, arr_iter=arr_iter,
+                                                val_iter=val_iter)
         arr_iter.setitem(val_iter.getitem().convert_to(dtype))
         # need to repeat i_nput values until all assignments are done
         arr_iter.next_skip_x(step)
@@ -291,10 +366,15 @@
         # WTF numpy?
         val_iter.reset()
 
+fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
+                                  greens = ['dtype'],
+                                  reds = ['s', 'ai', 'i'])
+
 def fromstring_loop(a, dtype, itemsize, s):
     i = 0
     ai = a.create_iter()
     while not ai.done():
+        fromstring_driver.jit_merge_point(dtype=dtype, s=s, ai=ai, i=i)
         val = dtype.itemtype.runpack_str(s[i*itemsize:i*itemsize + itemsize])
         ai.setitem(val)
         ai.next()
@@ -345,9 +425,21 @@
     def get_index(self, space):
         return [space.wrap(i) for i in self.indexes]
 
+getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
+                                   greens = ['shapelen', 'indexlen', 'dtype'],
+                                   reds = ['arr', 'res', 'iter', 'indexes_w',
+                                           'prefix_w'])
+
 def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w):
+    shapelen = len(iter_shape)
+    indexlen = len(indexes_w)
+    dtype = arr.get_dtype()
     iter = PureShapeIterator(iter_shape, indexes_w)
     while not iter.done():
+        getitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
+                                           dtype=dtype, arr=arr, res=res,
+                                           iter=iter, indexes_w=indexes_w,
+                                           prefix_w=prefix_w)
         # prepare the index
         index_w = [None] * len(indexes_w)
         for i in range(len(indexes_w)):
@@ -361,10 +453,22 @@
         iter.next()
     return res
 
+setitem_int_driver = jit.JitDriver(name = 'numpy_setitem_int',
+                                   greens = ['shapelen', 'indexlen', 'dtype'],
+                                   reds = ['arr', 'iter', 'indexes_w',
+                                           'prefix_w', 'val_arr'])
+
 def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr,
                       prefix_w):
+    shapelen = len(iter_shape)
+    indexlen = len(indexes_w)
+    dtype = arr.get_dtype()
     iter = PureShapeIterator(iter_shape, indexes_w)
     while not iter.done():
+        setitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
+                                           dtype=dtype, arr=arr,
+                                           iter=iter, indexes_w=indexes_w,
+                                           prefix_w=prefix_w, val_arr=val_arr)
         # prepare the index
         index_w = [None] * len(indexes_w)
         for i in range(len(indexes_w)):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to