Author: Maciej Fijalkowski <[email protected]>
Branch:
Changeset: r44625:2a5057b89b75
Date: 2011-06-01 16:52 +0200
http://bitbucket.org/pypy/pypy/changeset/2a5057b89b75/
Log: merge jit-applevel-hook. This branch provides a hook, used like
this:
    import pypyjit
    pypyjit.set_compile_hook(a_callable)
which will invoke a_callable each time a loop or bridge is compiled.
Refer to the function's docstring for details.
diff --git a/pypy/annotation/annrpython.py b/pypy/annotation/annrpython.py
--- a/pypy/annotation/annrpython.py
+++ b/pypy/annotation/annrpython.py
@@ -228,7 +228,7 @@
# graph -- it's already low-level operations!
for a, s_newarg in zip(graph.getargs(), cells):
s_oldarg = self.binding(a)
- assert s_oldarg.contains(s_newarg)
+ assert annmodel.unionof(s_oldarg, s_newarg) == s_oldarg
else:
assert not self.frozen
for a in cells:
diff --git a/pypy/jit/metainterp/compile.py b/pypy/jit/metainterp/compile.py
--- a/pypy/jit/metainterp/compile.py
+++ b/pypy/jit/metainterp/compile.py
@@ -124,18 +124,21 @@
return old_loop_token
if loop.preamble.operations is not None:
- send_loop_to_backend(metainterp_sd, loop, "loop")
+ send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd, loop,
+ "loop")
record_loop_or_bridge(metainterp_sd, loop)
token = loop.preamble.token
if full_preamble_needed:
- send_loop_to_backend(metainterp_sd, loop.preamble, "entry bridge")
+ send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd,
+ loop.preamble, "entry bridge")
insert_loop_token(old_loop_tokens, loop.preamble.token)
jitdriver_sd.warmstate.attach_unoptimized_bridge_from_interp(
greenkey, loop.preamble.token)
record_loop_or_bridge(metainterp_sd, loop.preamble)
return token
else:
- send_loop_to_backend(metainterp_sd, loop, "loop")
+ send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd, loop,
+ "loop")
insert_loop_token(old_loop_tokens, loop_token)
jitdriver_sd.warmstate.attach_unoptimized_bridge_from_interp(
greenkey, loop.token)
@@ -150,7 +153,9 @@
# XXX do we still need a list?
old_loop_tokens.append(loop_token)
-def send_loop_to_backend(metainterp_sd, loop, type):
+def send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd, loop, type):
+ jitdriver_sd.on_compile(metainterp_sd.logger_ops, loop.token,
+ loop.operations, type, greenkey)
globaldata = metainterp_sd.globaldata
loop_token = loop.token
loop_token.number = n = globaldata.loopnumbering
@@ -186,8 +191,11 @@
if metainterp_sd.warmrunnerdesc is not None: # for tests
metainterp_sd.warmrunnerdesc.memory_manager.keep_loop_alive(loop.token)
-def send_bridge_to_backend(metainterp_sd, faildescr, inputargs, operations,
- original_loop_token):
+def send_bridge_to_backend(jitdriver_sd, metainterp_sd, faildescr, inputargs,
+ operations, original_loop_token):
+ n = metainterp_sd.cpu.get_fail_descr_number(faildescr)
+ jitdriver_sd.on_compile_bridge(metainterp_sd.logger_ops,
+ original_loop_token, operations, n)
if not we_are_translated():
show_loop(metainterp_sd)
TreeLoop.check_consistency_of(inputargs, operations)
@@ -204,7 +212,6 @@
metainterp_sd.stats.compiled()
metainterp_sd.log("compiled new bridge")
#
- n = metainterp_sd.cpu.get_fail_descr_number(faildescr)
metainterp_sd.logger_ops.log_bridge(inputargs, operations, n, ops_offset)
#
if metainterp_sd.warmrunnerdesc is not None: # for tests
@@ -390,8 +397,9 @@
inputargs = metainterp.history.inputargs
if not we_are_translated():
self._debug_suboperations = new_loop.operations
- send_bridge_to_backend(metainterp.staticdata, self, inputargs,
- new_loop.operations, new_loop.token)
+ send_bridge_to_backend(metainterp.jitdriver_sd, metainterp.staticdata,
+ self, inputargs, new_loop.operations,
+ new_loop.token)
def copy_all_attributes_into(self, res):
# XXX a bit ugly to have to list them all here
@@ -570,7 +578,8 @@
# to every guard in the loop.
new_loop_token = make_loop_token(len(redargs), jitdriver_sd)
new_loop.token = new_loop_token
- send_loop_to_backend(metainterp_sd, new_loop, "entry bridge")
+ send_loop_to_backend(self.original_greenkey, metainterp.jitdriver_sd,
+ metainterp_sd, new_loop, "entry bridge")
# send the new_loop to warmspot.py, to be called directly the next time
jitdriver_sd.warmstate.attach_unoptimized_bridge_from_interp(
self.original_greenkey,
diff --git a/pypy/jit/metainterp/jitdriver.py b/pypy/jit/metainterp/jitdriver.py
--- a/pypy/jit/metainterp/jitdriver.py
+++ b/pypy/jit/metainterp/jitdriver.py
@@ -20,6 +20,7 @@
# self.portal_finishtoken... pypy.jit.metainterp.pyjitpl
# self.index ... pypy.jit.codewriter.call
# self.mainjitcode ... pypy.jit.codewriter.call
+ # self.on_compile ... pypy.jit.metainterp.warmstate
# These attributes are read by the backend in CALL_ASSEMBLER:
# self.assembler_helper_adr
diff --git a/pypy/jit/metainterp/logger.py b/pypy/jit/metainterp/logger.py
--- a/pypy/jit/metainterp/logger.py
+++ b/pypy/jit/metainterp/logger.py
@@ -75,6 +75,40 @@
else:
return '?'
+ def repr_of_resop(self, memo, op, ops_offset=None):
+ if op.getopnum() == rop.DEBUG_MERGE_POINT:
+ loc = op.getarg(0)._get_str()
+ reclev = op.getarg(1).getint()
+ return "debug_merge_point('%s', %s)" % (loc, reclev)
+ if ops_offset is None:
+ offset = -1
+ else:
+ offset = ops_offset.get(op, -1)
+ if offset == -1:
+ s_offset = ""
+ else:
+ s_offset = "+%d: " % offset
+ args = ", ".join([self.repr_of_arg(memo, op.getarg(i)) for i in range(op.numargs())])
+ if op.result is not None:
+ res = self.repr_of_arg(memo, op.result) + " = "
+ else:
+ res = ""
+ is_guard = op.is_guard()
+ if op.getdescr() is not None:
+ descr = op.getdescr()
+ if is_guard and self.guard_number:
+ index = self.metainterp_sd.cpu.get_fail_descr_number(descr)
+ r = "<Guard%d>" % index
+ else:
+ r = self.repr_of_descr(descr)
+ args += ', descr=' + r
+ if is_guard and op.getfailargs() is not None:
+ fail_args = ' [' + ", ".join([self.repr_of_arg(memo, arg)
+ for arg in op.getfailargs()]) + ']'
+ else:
+ fail_args = ''
+ return s_offset + res + op.getopname() + '(' + args + ')' + fail_args
+
def _log_operations(self, inputargs, operations, ops_offset):
if not have_debug_prints():
return
@@ -86,37 +120,7 @@
debug_print('[' + args + ']')
for i in range(len(operations)):
op = operations[i]
- if op.getopnum() == rop.DEBUG_MERGE_POINT:
- loc = op.getarg(0)._get_str()
- reclev = op.getarg(1).getint()
- debug_print("debug_merge_point('%s', %s)" % (loc, reclev))
- continue
- offset = ops_offset.get(op, -1)
- if offset == -1:
- s_offset = ""
- else:
- s_offset = "+%d: " % offset
- args = ", ".join([self.repr_of_arg(memo, op.getarg(i)) for i in range(op.numargs())])
- if op.result is not None:
- res = self.repr_of_arg(memo, op.result) + " = "
- else:
- res = ""
- is_guard = op.is_guard()
- if op.getdescr() is not None:
- descr = op.getdescr()
- if is_guard and self.guard_number:
- index = self.metainterp_sd.cpu.get_fail_descr_number(descr)
- r = "<Guard%d>" % index
- else:
- r = self.repr_of_descr(descr)
- args += ', descr=' + r
- if is_guard and op.getfailargs() is not None:
- fail_args = ' [' + ", ".join([self.repr_of_arg(memo, arg)
- for arg in op.getfailargs()]) + ']'
- else:
- fail_args = ''
- debug_print(s_offset + res + op.getopname() +
- '(' + args + ')' + fail_args)
+ debug_print(self.repr_of_resop(memo, operations[i], ops_offset))
if ops_offset and None in ops_offset:
offset = ops_offset[None]
debug_print("+%d: --end of the loop--" % offset)
diff --git a/pypy/jit/metainterp/test/test_jitdriver.py b/pypy/jit/metainterp/test/test_jitdriver.py
--- a/pypy/jit/metainterp/test/test_jitdriver.py
+++ b/pypy/jit/metainterp/test/test_jitdriver.py
@@ -10,8 +10,59 @@
def getloc2(g):
return "in jitdriver2, with g=%d" % g
+class JitDriverTests(object):
+ def test_on_compile(self):
+ called = {}
+
+ class MyJitDriver(JitDriver):
+ def on_compile(self, logger, looptoken, operations, type, n, m):
+ called[(m, n, type)] = looptoken
-class MultipleJitDriversTests:
+ driver = MyJitDriver(greens = ['n', 'm'], reds = ['i'])
+
+ def loop(n, m):
+ i = 0
+ while i < n + m:
+ driver.can_enter_jit(n=n, m=m, i=i)
+ driver.jit_merge_point(n=n, m=m, i=i)
+ i += 1
+
+ self.meta_interp(loop, [1, 4])
+ assert sorted(called.keys()) == [(4, 1, "entry bridge"), (4, 1, "loop")]
+ self.meta_interp(loop, [2, 4])
+ assert sorted(called.keys()) == [(4, 1, "entry bridge"), (4, 1, "loop"),
+ (4, 2, "entry bridge"), (4, 2, "loop")]
+
+ def test_on_compile_bridge(self):
+ called = {}
+
+ class MyJitDriver(JitDriver):
+ def on_compile(self, logger, looptoken, operations, type, n, m):
+ called[(m, n, type)] = looptoken
+ def on_compile_bridge(self, logger, orig_token, operations, n):
+ assert 'bridge' not in called
+ called['bridge'] = orig_token
+
+ driver = MyJitDriver(greens = ['n', 'm'], reds = ['i'])
+
+ def loop(n, m):
+ i = 0
+ while i < n + m:
+ driver.can_enter_jit(n=n, m=m, i=i)
+ driver.jit_merge_point(n=n, m=m, i=i)
+ if i >= 4:
+ i += 2
+ i += 1
+
+ self.meta_interp(loop, [1, 10])
+ assert sorted(called.keys()) == ['bridge', (10, 1, "entry bridge"),
+ (10, 1, "loop")]
+
+
+class TestLLtypeSingle(JitDriverTests, LLJitMixin):
+ pass
+
+class MultipleJitDriversTests(object):
def test_simple(self):
myjitdriver1 = JitDriver(greens=[], reds=['n', 'm'],
diff --git a/pypy/jit/metainterp/warmstate.py b/pypy/jit/metainterp/warmstate.py
--- a/pypy/jit/metainterp/warmstate.py
+++ b/pypy/jit/metainterp/warmstate.py
@@ -566,6 +566,19 @@
return can_inline_greenargs(*greenargs)
self.can_inline_greenargs = can_inline_greenargs
self.can_inline_callable = can_inline_callable
+ if hasattr(jd.jitdriver, 'on_compile'):
+ def on_compile(logger, token, operations, type, greenkey):
+ greenargs = unwrap_greenkey(greenkey)
+ return jd.jitdriver.on_compile(logger, token, operations, type,
+ *greenargs)
+ def on_compile_bridge(logger, orig_token, operations, n):
+ return jd.jitdriver.on_compile_bridge(logger, orig_token,
+ operations, n)
+ jd.on_compile = on_compile
+ jd.on_compile_bridge = on_compile_bridge
+ else:
+ jd.on_compile = lambda *args: None
+ jd.on_compile_bridge = lambda *args: None
def get_assembler_token(greenkey, redboxes):
# 'redboxes' is only used to know the types of red arguments
diff --git a/pypy/module/pypyjit/__init__.py b/pypy/module/pypyjit/__init__.py
--- a/pypy/module/pypyjit/__init__.py
+++ b/pypy/module/pypyjit/__init__.py
@@ -7,13 +7,15 @@
interpleveldefs = {
'set_param': 'interp_jit.set_param',
'residual_call': 'interp_jit.residual_call',
+ 'set_compile_hook': 'interp_jit.set_compile_hook',
}
def setup_after_space_initialization(self):
# force the __extend__ hacks to occur early
- import pypy.module.pypyjit.interp_jit
+ from pypy.module.pypyjit.interp_jit import pypyjitdriver
# add the 'defaults' attribute
from pypy.rlib.jit import PARAMETERS
space = self.space
+ pypyjitdriver.space = space
w_obj = space.wrap(PARAMETERS)
space.setattr(space.wrap(self), space.wrap('defaults'), w_obj)
diff --git a/pypy/module/pypyjit/interp_jit.py b/pypy/module/pypyjit/interp_jit.py
--- a/pypy/module/pypyjit/interp_jit.py
+++ b/pypy/module/pypyjit/interp_jit.py
@@ -12,6 +12,8 @@
from pypy.interpreter.pycode import PyCode, CO_GENERATOR
from pypy.interpreter.pyframe import PyFrame
from pypy.interpreter.pyopcode import ExitFrame
+from pypy.interpreter.gateway import unwrap_spec
+from pypy.interpreter.baseobjspace import ObjSpace, W_Root
from opcode import opmap
from pypy.rlib.objectmodel import we_are_translated
@@ -49,6 +51,44 @@
greens = ['next_instr', 'is_being_profiled', 'pycode']
virtualizables = ['frame']
+ def on_compile(self, logger, looptoken, operations, type, next_instr,
+ is_being_profiled, ll_pycode):
+ from pypy.rpython.annlowlevel import cast_base_ptr_to_instance
+
+ space = self.space
+ cache = space.fromcache(Cache)
+ if space.is_true(cache.w_compile_hook):
+ memo = {}
+ list_w = [space.wrap(logger.repr_of_resop(memo, op))
+ for op in operations]
+ pycode = cast_base_ptr_to_instance(PyCode, ll_pycode)
+ try:
+ space.call_function(cache.w_compile_hook,
+ space.wrap('main'),
+ space.wrap(type),
+ space.newtuple([pycode,
+ space.wrap(next_instr),
+ space.wrap(is_being_profiled)]),
+ space.newlist(list_w))
+ except OperationError, e:
+ e.write_unraisable(space, "jit hook ", cache.w_compile_hook)
+
+ def on_compile_bridge(self, logger, orig_looptoken, operations, n):
+ space = self.space
+ cache = space.fromcache(Cache)
+ if space.is_true(cache.w_compile_hook):
+ memo = {}
+ list_w = [space.wrap(logger.repr_of_resop(memo, op))
+ for op in operations]
+ try:
+ space.call_function(cache.w_compile_hook,
+ space.wrap('main'),
+ space.wrap('bridge'),
+ space.wrap(n),
+ space.newlist(list_w))
+ except OperationError, e:
+ e.write_unraisable(space, "jit hook ", cache.w_compile_hook)
+
pypyjitdriver = PyPyJitDriver(get_printable_location = get_printable_location,
get_jitcell_at = get_jitcell_at,
set_jitcell_at = set_jitcell_at,
@@ -149,3 +189,28 @@
'''For testing. Invokes callable(...), but without letting
the JIT follow the call.'''
return space.call_args(w_callable, __args__)
+
+class Cache(object):
+ def __init__(self, space):
+ self.w_compile_hook = space.w_None
+
+@unwrap_spec(ObjSpace, W_Root)
+def set_compile_hook(space, w_hook):
+ """ set_compile_hook(hook)
+
+ Set a compiling hook that will be called each time a loop or bridge
+ is compiled. The hook will be called with the following signature:
+ hook(merge_point_type, loop_type, greenkey or guard_number, operations)
+
+ for now merge_point_type is always 'main'
+
+ loop_type is one of 'loop', 'entry bridge' or 'bridge'.
+ If loop_type is not 'bridge', the third argument is the greenkey; for
+ the 'main' merge point it is a tuple (code, offset, is_being_profiled).
+ If loop_type is 'bridge', it is the fail-descr number of the guard.
+ operations is a list of strings, one per operation. Pass None to
+ disable the hook; exceptions raised by it are printed, not propagated.
+ """
+ cache = space.fromcache(Cache)
+ cache.w_compile_hook = w_hook
+ return space.w_None
diff --git a/pypy/module/pypyjit/test/test_jit_hook.py b/pypy/module/pypyjit/test/test_jit_hook.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/pypyjit/test/test_jit_hook.py
@@ -0,0 +1,85 @@
+
+from pypy.conftest import gettestobjspace
+from pypy.interpreter.pycode import PyCode
+from pypy.interpreter.gateway import interp2app
+from pypy.jit.metainterp.history import LoopToken
+from pypy.jit.metainterp.resoperation import ResOperation, rop
+from pypy.jit.metainterp.logger import Logger
+from pypy.rpython.annlowlevel import (cast_instance_to_base_ptr,
+ cast_base_ptr_to_instance)
+from pypy.module.pypyjit.interp_jit import pypyjitdriver
+from pypy.jit.tool.oparser import parse
+from pypy.jit.metainterp.typesystem import llhelper
+
+class MockSD(object):
+ class cpu:
+ ts = llhelper
+
+class AppTestJitHook(object):
+ def setup_class(cls):
+ space = gettestobjspace(usemodules=('pypyjit',))
+ cls.space = space
+ w_f = space.appexec([], """():
+ def f():
+ pass
+ return f
+ """)
+ ll_code = cast_instance_to_base_ptr(w_f.code)
+ logger = Logger(MockSD())
+
+ oplist = parse("""
+ [i1, i2]
+ i3 = int_add(i1, i2)
+ guard_true(i3) []
+ """).operations
+
+ def interp_on_compile():
+ pypyjitdriver.on_compile(logger, LoopToken(), oplist, 'loop',
+ 0, False, ll_code)
+
+ def interp_on_compile_bridge():
+ pypyjitdriver.on_compile_bridge(logger, LoopToken(), oplist, 0)
+
+ cls.w_on_compile = space.wrap(interp2app(interp_on_compile))
+ cls.w_on_compile_bridge = space.wrap(interp2app(interp_on_compile_bridge))
+
+ def test_on_compile(self):
+ import pypyjit
+ all = []
+
+ def hook(*args):
+ assert args[0] == 'main'
+ assert args[1] in ['loop', 'bridge']
+ all.append(args[2:])
+
+ self.on_compile()
+ pypyjit.set_compile_hook(hook)
+ assert not all
+ self.on_compile()
+ assert len(all) == 1
+ assert all[0][0][0].co_name == 'f'
+ assert all[0][0][1] == 0
+ assert all[0][0][2] == False
+ assert len(all[0][1]) == 2
+ assert 'int_add' in all[0][1][0]
+ self.on_compile_bridge()
+ assert len(all) == 2
+ pypyjit.set_compile_hook(None)
+ self.on_compile()
+ assert len(all) == 2
+
+ def test_on_compile_exception(self):
+ import pypyjit, sys, cStringIO
+
+ def hook(*args):
+ 1/0
+
+ pypyjit.set_compile_hook(hook)
+ s = cStringIO.StringIO()
+ sys.stderr = s
+ try:
+ self.on_compile()
+ finally:
+ sys.stderr = sys.__stderr__
+ assert 'jit hook' in s.getvalue()
+ assert 'ZeroDivisionError' in s.getvalue()
diff --git a/pypy/rlib/jit.py b/pypy/rlib/jit.py
--- a/pypy/rlib/jit.py
+++ b/pypy/rlib/jit.py
@@ -370,6 +370,24 @@
raise
set_user_param._annspecialcase_ = 'specialize:arg(0)'
+
+ def on_compile(self, logger, looptoken, operations, type, *greenargs):
+ """ A hook called when a loop is compiled. Override
+ for your own jitdriver if you want to do something special, like
+ call applevel code
+ """
+
+ def on_compile_bridge(self, logger, orig_looptoken, operations, n):
+ """ A hook called when a bridge is compiled. Override
+ for your own jitdriver if you want to do something special
+ """
+
+ # note: if you override these functions with the above signatures it'll
+ # work, but *greenargs differs for each jitdriver, so we can't share
+ # the same methods; warmstate detects overrides with hasattr()
+ del on_compile
+ del on_compile_bridge
+
def _make_extregistryentries(self):
# workaround: we cannot declare ExtRegistryEntries for functions
# used as methods of a frozen object, but we can attach the
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit