Author: Antonio Cuni <[email protected]>
Branch: autoreds
Changeset: r58957:9db0af7d3894
Date: 2012-11-16 17:09 +0100
http://bitbucket.org/pypy/pypy/changeset/9db0af7d3894/
Log: add support for calling the @jitdriver.inline()d function multiple
times
diff --git a/pypy/jit/metainterp/test/test_warmspot.py
b/pypy/jit/metainterp/test/test_warmspot.py
--- a/pypy/jit/metainterp/test/test_warmspot.py
+++ b/pypy/jit/metainterp/test/test_warmspot.py
@@ -412,7 +412,6 @@
assert res == 1000 + 1002
self.check_resops(int_add=4)
-
def test_jitdriver_inline(self):
myjitdriver = JitDriver(greens = [], reds = 'auto')
class MyRange(object):
@@ -445,6 +444,36 @@
self.check_resops(int_eq=2, int_add=4)
self.check_trace_count(1)
+ def test_jitdriver_inline_twice(self):
+ # The same @jitdriver.inline()d function (add) is called from two
+ # different loops (in one() and two()); each calling function is
+ # expected to become its own portal.
+ myjitdriver = JitDriver(greens = [], reds = 'auto')
+
+ # function holding the actual jit_merge_point; passed to inline()
+ def jit_merge_point(a, b):
+ myjitdriver.jit_merge_point()
+
+ @myjitdriver.inline(jit_merge_point)
+ def add(a, b):
+ return a+b
+
+ # first callsite of add
+ def one(n):
+ res = 0
+ while res < 1000:
+ res = add(n, res)
+ return res
+
+ # second, independent callsite of add
+ def two(n):
+ res = 0
+ while res < 2000:
+ res = add(n, res)
+ return res
+
+ def f(n):
+ return one(n) + two(n)
+
+ res = self.meta_interp(f, [1])
+ # with n=1, one() stops at exactly 1000 and two() at exactly 2000
+ assert res == 3000
+ self.check_resops(int_add=4)
+ # presumably one trace per portal (one() and two()) -- see
+ # new_portals in warmspot.py
+ self.check_trace_count(2)
+
class TestLLWarmspot(WarmspotTests, LLJitMixin):
CPUClass = runner.LLtypeCPU
diff --git a/pypy/jit/metainterp/warmspot.py b/pypy/jit/metainterp/warmspot.py
--- a/pypy/jit/metainterp/warmspot.py
+++ b/pypy/jit/metainterp/warmspot.py
@@ -253,6 +253,39 @@
from pypy.translator.backendopt.inline import (
get_funcobj, inlinable_static_callers, auto_inlining)
+ # cache: @jitdriver.inline()d graph -> (the original call op to the
+ # jit_merge_point function, the graph of that function). Filled the
+ # first time a graph is seen so the op is fished/removed only once.
+ jmp_calls = {}
+ def get_jmp_call(graph, _inline_jit_merge_point_):
+ # Return a fresh copy of the op that calls the jit_merge_point
+ # function of the @inline()d 'graph', together with that function's
+ # graph.
+ # there might be multiple calls to the @inlined function: the
+ # first time we see it, we remove the call to the jit_merge_point
+ # and we remember the corresponding op. Then, we create a new call
+ # to it every time we need a new one (i.e., for each callsite
+ # which becomes a new portal)
+ try:
+ op, jmp_graph = jmp_calls[graph]
+ except KeyError:
+ op, jmp_graph = fish_jmp_call(graph, _inline_jit_merge_point_)
+ jmp_calls[graph] = op, jmp_graph
+ #
+ # clone the op: every callsite gets its own SpaceOperation with a
+ # copied args list and a fresh result Variable of the same
+ # concretetype, so the caller can overwrite args[1:] without touching
+ # the cached original.
+ # NOTE(review): relies on Variable and SpaceOperation being imported
+ # in warmspot.py -- not visible in this patch, confirm.
+ newargs = op.args[:]
+ newresult = Variable()
+ newresult.concretetype = op.result.concretetype
+ op = SpaceOperation(op.opname, newargs, newresult)
+ return op, jmp_graph
+
+ def fish_jmp_call(graph, _inline_jit_merge_point_):
+ # graph is a function which has been decorated with
+ # @jitdriver.inline, so its very first op is a call to the
+ # function which contains the actual jit_merge_point: fish it!
+ # The op is removed from graph; returns (op, graph of the called
+ # function). Only get_jmp_call should call this, since it caches the
+ # result and fishing twice would fail the assertions below.
+ #
+ # Use the 'graph' argument, not the enclosing loop's 'callee': the
+ # latter only worked by accident because get_jmp_call happens to be
+ # invoked with callee.
+ jmp_block, op_jmp_call = next(graph.iterblockops())
+ msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
+ "a direct_call to the function passed to @jitdriver.inline()")
+ assert op_jmp_call.opname == 'direct_call', msg
+ jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
+ assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+ jmp_block.operations.remove(op_jmp_call)
+ return op_jmp_call, jmp_funcobj.graph
+
# find all the graphs which call an @inline_in_portal function
callgraph = inlinable_static_callers(self.translator.graphs,
store_calls=True)
new_callgraph = []
@@ -261,23 +294,13 @@
func = getattr(callee, 'func', None)
_inline_jit_merge_point_ = getattr(func,
'_inline_jit_merge_point_', None)
if _inline_jit_merge_point_:
- # we are calling a function which has been decorated with
- # @jitdriver.inline: the very first op of the callee graph is
- # a call to the function which contains the actual
- # jit_merge_point: fish it!
- jmp_block, op_jmp_call = next(callee.iterblockops())
- msg = ("The first operation of an _inline_jit_merge_point_
graph must be "
- "a direct_call to the function passed to
@jitdriver.inline()")
- assert op_jmp_call.opname == 'direct_call', msg
- jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
- assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+ op_jmp_call, jmp_graph = get_jmp_call(callee,
_inline_jit_merge_point_)
#
# now we move the op_jmp_call from callee to caller, just
# before op_call. We assume that the args passed to
# op_jmp_call are the very same which are received by callee
# (i.e., the one passed to op_call)
assert len(op_call.args) == len(op_jmp_call.args)
- jmp_block.operations.remove(op_jmp_call)
op_jmp_call.args[1:] = op_call.args[1:]
idx = block.operations.index(op_call)
block.operations.insert(idx, op_jmp_call)
@@ -285,7 +308,7 @@
# finally, we signal that we want to inline op_jmp_call into
# caller, so that finally the actuall call to
# driver.jit_merge_point will be seen there
- new_callgraph.append((caller, jmp_funcobj.graph))
+ new_callgraph.append((caller, jmp_graph))
new_portals.add(caller)
# inline them!
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit