Author: Antonio Cuni <[email protected]>
Branch:
Changeset: r58966:29f51cb83169
Date: 2012-11-16 22:48 +0100
http://bitbucket.org/pypy/pypy/changeset/29f51cb83169/
Log: merge again the autoreds branch, which now uses an approach which
seems to work for the upcoming space.iteriterable
diff --git a/pypy/jit/metainterp/test/test_warmspot.py
b/pypy/jit/metainterp/test/test_warmspot.py
--- a/pypy/jit/metainterp/test/test_warmspot.py
+++ b/pypy/jit/metainterp/test/test_warmspot.py
@@ -383,7 +383,36 @@
assert res == expected
self.check_resops(int_sub=2, int_mul=0, int_add=2)
- def test_inline_in_portal(self):
+ def test_inline_jit_merge_point(self):
+ # test that the machinery to inline jit_merge_points in callers
+ # works. The final user does not need to mess manually with the
+ # _inline_jit_merge_point_ attribute and similar, it is all nicely
+ # handled by @JitDriver.inline() (see next tests)
+ myjitdriver = JitDriver(greens = ['a'], reds = 'auto')
+
+ def jit_merge_point(a, b):
+ myjitdriver.jit_merge_point(a=a)
+
+ def add(a, b):
+ jit_merge_point(a, b)
+ return a+b
+ add._inline_jit_merge_point_ = jit_merge_point
+ myjitdriver.inline_jit_merge_point = True
+
+ def calc(n):
+ res = 0
+ while res < 1000:
+ res = add(n, res)
+ return res
+
+ def f():
+ return calc(1) + calc(3)
+
+ res = self.meta_interp(f, [])
+ assert res == 1000 + 1002
+ self.check_resops(int_add=4)
+
+ def test_jitdriver_inline(self):
myjitdriver = JitDriver(greens = [], reds = 'auto')
class MyRange(object):
def __init__(self, n):
@@ -393,35 +422,102 @@
def __iter__(self):
return self
- @myjitdriver.inline_in_portal
+ def jit_merge_point(self):
+ myjitdriver.jit_merge_point()
+
+ @myjitdriver.inline(jit_merge_point)
def next(self):
- myjitdriver.jit_merge_point()
if self.cur == self.n:
raise StopIteration
self.cur += 1
return self.cur
- def one():
+ def f(n):
res = 0
- for i in MyRange(10):
+ for i in MyRange(n):
res += i
return res
- def two():
+ expected = f(21)
+ res = self.meta_interp(f, [21])
+ assert res == expected
+ self.check_resops(int_eq=2, int_add=4)
+ self.check_trace_count(1)
+
+ def test_jitdriver_inline_twice(self):
+ myjitdriver = JitDriver(greens = [], reds = 'auto')
+
+ def jit_merge_point(a, b):
+ myjitdriver.jit_merge_point()
+
+ @myjitdriver.inline(jit_merge_point)
+ def add(a, b):
+ return a+b
+
+ def one(n):
res = 0
- for i in MyRange(13):
- res += i * 2
+ while res < 1000:
+ res = add(n, res)
return res
- def f(n, m):
- res = one() * 100
- res += two()
+ def two(n):
+ res = 0
+ while res < 2000:
+ res = add(n, res)
return res
- expected = f(21, 5)
- res = self.meta_interp(f, [21, 5])
+
+ def f(n):
+ return one(n) + two(n)
+
+ res = self.meta_interp(f, [1])
+ assert res == 3000
+ self.check_resops(int_add=4)
+ self.check_trace_count(2)
+
+ def test_jitdriver_inline_exception(self):
+ # this simulates what happens in a real case scenario: inside the next
+ # we have a call which we cannot inline (e.g. space.next in the case
+ # of W_InterpIterable), but we need to put it in a try/except block.
+ # With the first "inline_in_portal" approach, this case crashed
+ myjitdriver = JitDriver(greens = [], reds = 'auto')
+
+ def inc(x, n):
+ if x == n:
+ raise OverflowError
+ return x+1
+ inc._dont_inline_ = True
+
+ class MyRange(object):
+ def __init__(self, n):
+ self.cur = 0
+ self.n = n
+
+ def __iter__(self):
+ return self
+
+ def jit_merge_point(self):
+ myjitdriver.jit_merge_point()
+
+ @myjitdriver.inline(jit_merge_point)
+ def next(self):
+ try:
+ self.cur = inc(self.cur, self.n)
+ except OverflowError:
+ raise StopIteration
+ return self.cur
+
+ def f(n):
+ res = 0
+ for i in MyRange(n):
+ res += i
+ return res
+
+ expected = f(21)
+ res = self.meta_interp(f, [21])
assert res == expected
- self.check_resops(int_eq=4, int_add=8)
- self.check_trace_count(2)
+ self.check_resops(int_eq=2, int_add=4)
+ self.check_trace_count(1)
+
class TestLLWarmspot(WarmspotTests, LLJitMixin):
CPUClass = runner.LLtypeCPU
diff --git a/pypy/jit/metainterp/warmspot.py b/pypy/jit/metainterp/warmspot.py
--- a/pypy/jit/metainterp/warmspot.py
+++ b/pypy/jit/metainterp/warmspot.py
@@ -244,27 +244,80 @@
def inline_inlineable_portals(self):
"""
- Find all the graphs which have been decorated with
- @jitdriver.inline_in_portal and inline them in the callers, making
- them JIT portals. Then, create a fresh copy of the jitdriver for each
- of those new portals, because they cannot share the same one. See
- test_ajit::test_inline_in_portal.
+ Find all the graphs which have been decorated with @jitdriver.inline
+ and inline them in the callers, making them JIT portals. Then, create
+ a fresh copy of the jitdriver for each of those new portals, because
+ they cannot share the same one. See
+ test_ajit::test_inline_jit_merge_point
"""
- from pypy.translator.backendopt import inline
- lltype_to_classdef =
self.translator.rtyper.lltype_to_classdef_mapping()
- raise_analyzer = inline.RaiseAnalyzer(self.translator)
- callgraph = inline.inlinable_static_callers(self.translator.graphs)
+ from pypy.translator.backendopt.inline import (
+ get_funcobj, inlinable_static_callers, auto_inlining)
+
+ jmp_calls = {}
+ def get_jmp_call(graph, _inline_jit_merge_point_):
+ # there might be multiple calls to the @inlined function: the
+ # first time we see it, we remove the call to the jit_merge_point
+ # and we remember the corresponding op. Then, we create a new call
+ # to it every time we need a new one (i.e., for each callsite
+ # which becomes a new portal)
+ try:
+ op, jmp_graph = jmp_calls[graph]
+ except KeyError:
+ op, jmp_graph = fish_jmp_call(graph, _inline_jit_merge_point_)
+ jmp_calls[graph] = op, jmp_graph
+ #
+ # clone the op
+ newargs = op.args[:]
+ newresult = Variable()
+ newresult.concretetype = op.result.concretetype
+ op = SpaceOperation(op.opname, newargs, newresult)
+ return op, jmp_graph
+
+ def fish_jmp_call(graph, _inline_jit_merge_point_):
+     # graph is a function which has been decorated with
+     # @jitdriver.inline, so its very first op is a call to the
+     # function which contains the actual jit_merge_point: fish it!
+     jmp_block, op_jmp_call = next(graph.iterblockops())
+     msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
+            "a direct_call to the function passed to @jitdriver.inline()")
+     assert op_jmp_call.opname == 'direct_call', msg
+     jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
+     assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+     jmp_block.operations.remove(op_jmp_call)
+     return op_jmp_call, jmp_funcobj.graph
+
+ # find all the graphs which call a @jitdriver.inline function
+ callgraph = inlinable_static_callers(self.translator.graphs,
store_calls=True)
+ new_callgraph = []
new_portals = set()
- for caller, callee in callgraph:
+ for caller, block, op_call, callee in callgraph:
func = getattr(callee, 'func', None)
- _inline_in_portal_ = getattr(func, '_inline_in_portal_', False)
- if _inline_in_portal_:
- count = inline.inline_function(self.translator, callee, caller,
- lltype_to_classdef,
raise_analyzer)
- assert count > 0, ('The function has been decorated with '
- '@inline_in_portal, but it is not possible '
- 'to inline it')
+ _inline_jit_merge_point_ = getattr(func,
'_inline_jit_merge_point_', None)
+ if _inline_jit_merge_point_:
+ _inline_jit_merge_point_._always_inline_ = True
+ op_jmp_call, jmp_graph = get_jmp_call(callee,
_inline_jit_merge_point_)
+ #
+ # now we move the op_jmp_call from callee to caller, just
+ # before op_call. We assume that the args passed to
+ # op_jmp_call are the very same which are received by callee
+ # (i.e., the one passed to op_call)
+ assert len(op_call.args) == len(op_jmp_call.args)
+ op_jmp_call.args[1:] = op_call.args[1:]
+ idx = block.operations.index(op_call)
+ block.operations.insert(idx, op_jmp_call)
+ #
+ # finally, we signal that we want to inline op_jmp_call into
+ # caller, so that finally the actual call to
+ # driver.jit_merge_point will be seen there
+ new_callgraph.append((caller, jmp_graph))
new_portals.add(caller)
+
+ # inline them!
+ inline_threshold = 0.1 # we rely on the _always_inline_ set above
+ auto_inlining(self.translator, inline_threshold, new_callgraph)
+
+ # make a fresh copy of the JitDriver in all newly created
+ # jit_merge_points
self.clone_inlined_jit_merge_points(new_portals)
def clone_inlined_jit_merge_points(self, graphs):
@@ -277,7 +330,10 @@
for graph, block, pos in find_jit_merge_points(graphs):
op = block.operations[pos]
v_driver = op.args[1]
- new_driver = v_driver.value.clone()
+ driver = v_driver.value
+ if not driver.inline_jit_merge_point:
+ continue
+ new_driver = driver.clone()
c_new_driver = Constant(new_driver, v_driver.concretetype)
op.args[1] = c_new_driver
@@ -320,6 +376,7 @@
alive_v.add(op1.result)
greens_v = op.args[2:]
reds_v = alive_v - set(greens_v)
+ reds_v = [v for v in reds_v if v.concretetype is not
lltype.Void]
reds_v = support.sort_vars(reds_v)
op.args.extend(reds_v)
if jitdriver.numreds is None:
diff --git a/pypy/rlib/jit.py b/pypy/rlib/jit.py
--- a/pypy/rlib/jit.py
+++ b/pypy/rlib/jit.py
@@ -6,6 +6,7 @@
from pypy.rlib.objectmodel import CDefinedIntSymbolic, keepalive_until_here,
specialize
from pypy.rlib.unroll import unrolling_iterable
from pypy.rpython.extregistry import ExtRegistryEntry
+from pypy.tool.sourcetools import rpython_wrapper
DEBUG_ELIDABLE_FUNCTIONS = False
@@ -443,7 +444,7 @@
active = True # if set to False, this JitDriver is ignored
virtualizables = []
name = 'jitdriver'
- inlined_in_portal = False
+ inline_jit_merge_point = False
def __init__(self, greens=None, reds=None, virtualizables=None,
get_jitcell_at=None, set_jitcell_at=None,
@@ -551,14 +552,26 @@
# special-cased by ExtRegistryEntry
pass
- def inline_in_portal(self, func):
- assert self.autoreds, "inline_in_portal works only with reds='auto'"
- func._inline_in_portal_ = True
- self.inlined_in_portal = True
- return func
+ def inline(self, call_jit_merge_point):
+ assert self.autoreds, "@inline works only with reds='auto'"
+ self.inline_jit_merge_point = True
+ def decorate(func):
+ template = """
+ def {name}({arglist}):
+ {call_jit_merge_point}({arglist})
+ return {original}({arglist})
+ """
+ templateargs = {'call_jit_merge_point':
call_jit_merge_point.__name__}
+ globaldict = {call_jit_merge_point.__name__: call_jit_merge_point}
+ result = rpython_wrapper(func, template, templateargs,
**globaldict)
+ result._inline_jit_merge_point_ = call_jit_merge_point
+ return result
+
+ return decorate
+
def clone(self):
- assert self.inlined_in_portal, 'JitDriver.clone works only after
@inline_in_portal'
+ assert self.inline_jit_merge_point, 'JitDriver.clone works only after
@inline'
newdriver = object.__new__(self.__class__)
newdriver.__dict__ = self.__dict__.copy()
return newdriver
diff --git a/pypy/rlib/objectmodel.py b/pypy/rlib/objectmodel.py
--- a/pypy/rlib/objectmodel.py
+++ b/pypy/rlib/objectmodel.py
@@ -8,6 +8,7 @@
import types
import math
import inspect
+from pypy.tool.sourcetools import rpython_wrapper
# specialize is a decorator factory for attaching _annspecialcase_
# attributes to functions: for example
@@ -170,9 +171,16 @@
f.func_name, srcargs[i], expected_type)
raise TypeError, msg
#
- # we cannot simply wrap the function using *args, **kwds, because it's
- # not RPython. Instead, we generate a function with exactly the same
- # argument list
+ template = """
+ def {name}({arglist}):
+ if not we_are_translated():
+ typecheck({arglist}) # pypy.rlib.objectmodel
+ return {original}({arglist})
+ """
+ result = rpython_wrapper(f, template,
+ typecheck=typecheck,
+ we_are_translated=we_are_translated)
+ #
srcargs, srcvarargs, srckeywords, defaults = inspect.getargspec(f)
if kwds:
types = tuple([kwds.get(arg) for arg in srcargs])
@@ -181,28 +189,11 @@
assert len(srcargs) == len(types), (
'not enough types provided: expected %d, got %d' %
(len(types), len(srcargs)))
- assert not srcvarargs, '*args not supported by enforceargs'
- assert not srckeywords, '**kwargs not supported by enforceargs'
- #
- arglist = ', '.join(srcargs)
- src = py.code.Source("""
- def %(name)s(%(arglist)s):
- if not we_are_translated():
- typecheck(%(arglist)s) # pypy.rlib.objectmodel
- return %(name)s_original(%(arglist)s)
- """ % dict(name=f.func_name, arglist=arglist))
- #
- mydict = {f.func_name + '_original': f,
- 'typecheck': typecheck,
- 'we_are_translated': we_are_translated}
- exec src.compile() in mydict
- result = mydict[f.func_name]
- result.func_defaults = f.func_defaults
- result.func_dict.update(f.func_dict)
result._annenforceargs_ = types
return result
return decorator
+
# ____________________________________________________________
class Symbolic(object):
diff --git a/pypy/rlib/test/test_jit.py b/pypy/rlib/test/test_jit.py
--- a/pypy/rlib/test/test_jit.py
+++ b/pypy/rlib/test/test_jit.py
@@ -37,16 +37,35 @@
assert driver.reds == ['a', 'b']
assert driver.numreds == 2
+def test_jitdriver_inline():
+ driver = JitDriver(greens=[], reds='auto')
+ calls = []
+ def foo(a, b):
+ calls.append(('foo', a, b))
+
+ @driver.inline(foo)
+ def bar(a, b):
+ calls.append(('bar', a, b))
+ return a+b
+
+ assert bar._inline_jit_merge_point_ is foo
+ assert driver.inline_jit_merge_point
+ assert bar(40, 2) == 42
+ assert calls == [
+ ('foo', 40, 2),
+ ('bar', 40, 2),
+ ]
+
def test_jitdriver_clone():
- def foo():
- pass
+ def bar(): pass
+ def foo(): pass
driver = JitDriver(greens=[], reds=[])
- py.test.raises(AssertionError, "driver.inline_in_portal(foo)")
+ py.test.raises(AssertionError, "driver.inline(bar)(foo)")
#
driver = JitDriver(greens=[], reds='auto')
py.test.raises(AssertionError, "driver.clone()")
- foo = driver.inline_in_portal(foo)
- assert foo._inline_in_portal_ == True
+ foo = driver.inline(bar)(foo)
+ assert foo._inline_jit_merge_point_ == bar
#
driver.foo = 'bar'
driver2 = driver.clone()
diff --git a/pypy/tool/sourcetools.py b/pypy/tool/sourcetools.py
--- a/pypy/tool/sourcetools.py
+++ b/pypy/tool/sourcetools.py
@@ -268,3 +268,30 @@
except AttributeError:
firstlineno = -1
return "(%s:%d)%s" % (mod or '?', firstlineno, name or 'UNKNOWN')
+
+
+def rpython_wrapper(f, template, templateargs=None, **globaldict):
+    """
+    We cannot simply wrap the function using *args, **kwds, because it's not
+    RPython. Instead, we generate a function from ``template`` with exactly
+    the same argument list.
+    """
+    # work on a copy, to avoid modifying the dict passed by the caller
+    templateargs = dict(templateargs or {})
+    srcargs, srcvarargs, srckeywords, defaults = inspect.getargspec(f)
+    assert not srcvarargs, '*args not supported by rpython_wrapper'
+    assert not srckeywords, '**kwargs not supported by rpython_wrapper'
+    #
+    arglist = ', '.join(srcargs)
+    templateargs.update(name=f.func_name,
+                        arglist=arglist,
+                        original=f.func_name+'_original')
+    src = template.format(**templateargs)
+    src = py.code.Source(src)
+    #
+    globaldict[f.func_name + '_original'] = f
+    exec src.compile() in globaldict
+    result = globaldict[f.func_name]
+    result.func_defaults = f.func_defaults
+    result.func_dict.update(f.func_dict)
+    return result
diff --git a/pypy/tool/test/test_sourcetools.py
b/pypy/tool/test/test_sourcetools.py
--- a/pypy/tool/test/test_sourcetools.py
+++ b/pypy/tool/test/test_sourcetools.py
@@ -1,4 +1,4 @@
-from pypy.tool.sourcetools import func_with_new_name, func_renamer
+from pypy.tool.sourcetools import func_with_new_name, func_renamer,
rpython_wrapper
def test_rename():
def f(x, y=5):
@@ -34,3 +34,25 @@
bar3 = func_with_new_name(bar, 'bar3')
assert bar3.func_doc == 'new doc'
assert bar2.func_doc != bar3.func_doc
+
+
+def test_rpython_wrapper():
+ calls = []
+
+ def bar(a, b):
+ calls.append(('bar', a, b))
+ return a+b
+
+ template = """
+ def {name}({arglist}):
+ calls.append(('decorated', {arglist}))
+ return {original}({arglist})
+ """
+ bar = rpython_wrapper(bar, template, calls=calls)
+ assert bar(40, 2) == 42
+ assert calls == [
+ ('decorated', 40, 2),
+ ('bar', 40, 2),
+ ]
+
+
diff --git a/pypy/translator/backendopt/inline.py
b/pypy/translator/backendopt/inline.py
--- a/pypy/translator/backendopt/inline.py
+++ b/pypy/translator/backendopt/inline.py
@@ -614,9 +614,15 @@
return (0.9999 * measure_median_execution_cost(graph) +
count), True
-def inlinable_static_callers(graphs):
+def inlinable_static_callers(graphs, store_calls=False):
ok_to_call = set(graphs)
result = []
+ def add(parentgraph, block, op, graph):
+ if store_calls:
+ result.append((parentgraph, block, op, graph))
+ else:
+ result.append((parentgraph, graph))
+ #
for parentgraph in graphs:
for block in parentgraph.iterblocks():
for op in block.operations:
@@ -627,12 +633,12 @@
if getattr(getattr(funcobj, '_callable', None),
'_dont_inline_', False):
continue
- result.append((parentgraph, graph))
+ add(parentgraph, block, op, graph)
if op.opname == "oosend":
meth = get_meth_from_oosend(op)
graph = getattr(meth, 'graph', None)
if graph is not None and graph in ok_to_call:
- result.append((parentgraph, graph))
+ add(parentgraph, block, op, graph)
return result
def instrument_inline_candidates(graphs, threshold):
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit