Author: Antonio Cuni <[email protected]>
Branch: 
Changeset: r65996:891ccccca71d
Date: 2013-08-07 16:17 +0200
http://bitbucket.org/pypy/pypy/changeset/891ccccca71d/

Log:    don't include 'identity' in the greens of numpy_axis_reduce: it is
        useless, because it is used only in the first iteration of the loop,
        and harmful, because we get a different W_*Box instance every time we
        run it, which means that we compile the same loop again and again

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -215,8 +215,7 @@
 
 axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
                                     greens=['shapelen',
-                                            'func', 'dtype',
-                                            'identity'],
+                                            'func', 'dtype'],
                                     reds='auto')
 
 def do_axis_reduce(shape, func, arr, dtype, axis, out, identity, cumultative,
@@ -232,8 +231,7 @@
     shapelen = len(shape)
     while not out_iter.done():
         axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
-                                            dtype=dtype, identity=identity,
-                                            )
+                                            dtype=dtype)
         w_val = arr_iter.getitem().convert_to(dtype)
         if out_iter.first_line:
             if identity is not None:
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -6,7 +6,7 @@
 import py
 from rpython.jit.metainterp import pyjitpl
 from rpython.jit.metainterp.test.support import LLJitMixin
-from rpython.jit.metainterp.warmspot import reset_stats
+from rpython.jit.metainterp.warmspot import reset_stats, get_stats
 from pypy.module.micronumpy import interp_boxes
 from pypy.module.micronumpy.compile import FakeSpace, Parser, InterpreterState
 from pypy.module.micronumpy.base import W_NDimArray
@@ -35,9 +35,10 @@
         cls.code_mapping = d
         cls.codes = allcodes
 
-    def run(self, name):
+    def compile_graph(self):
+        if self.graph is not None:
+            return
         space = FakeSpace()
-        i = self.code_mapping[name]
         codes = self.codes
 
         def f(i):
@@ -57,14 +58,18 @@
             raise TypeError(w_res)
 
         if self.graph is None:
-            interp, graph = self.meta_interp(f, [i],
+            interp, graph = self.meta_interp(f, [0],
                                              listops=True,
                                              backendopt=True,
                                              graph_and_interp_only=True)
             self.__class__.interp = interp
             self.__class__.graph = graph
+
+    def run(self, name):
+        self.compile_graph()
         reset_stats()
         pyjitpl._warmrunnerdesc.memory_manager.alive_loops.clear()
+        i = self.code_mapping[name]
         retval = self.interp.eval_graph(self.graph, [i])
         py.test.skip("don't run for now")
         return retval
@@ -134,6 +139,29 @@
                                 'int_add': 3,
                                 })
 
+    def test_reduce_compile_only_once(self):
+        self.compile_graph()
+        reset_stats()
+        pyjitpl._warmrunnerdesc.memory_manager.alive_loops.clear()
+        i = self.code_mapping['sum']
+        # run it twice
+        retval = self.interp.eval_graph(self.graph, [i])
+        retval = self.interp.eval_graph(self.graph, [i])
+        # check that we got only one loop
+        assert len(get_stats().loops) == 1
+
+    def test_reduce_axis_compile_only_once(self):
+        self.compile_graph()
+        reset_stats()
+        pyjitpl._warmrunnerdesc.memory_manager.alive_loops.clear()
+        i = self.code_mapping['axissum']
+        # run it twice
+        retval = self.interp.eval_graph(self.graph, [i])
+        retval = self.interp.eval_graph(self.graph, [i])
+        # check that we got only one loop
+        assert len(get_stats().loops) == 1
+
+
     def define_prod():
         return """
         a = |30|
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to