Author: Maciej Fijalkowski <[email protected]>
Branch: 
Changeset: r44625:2a5057b89b75
Date: 2011-06-01 16:52 +0200
http://bitbucket.org/pypy/pypy/changeset/2a5057b89b75/

Log:    merge jit-applevel-hook. This branch provides a hook, used like
        this:

            import pypyjit
            pypyjit.set_compile_hook(a_callable)

        The callable will be invoked each time there is a loop to be
        compiled. Refer to the set_compile_hook docstring for details.

diff --git a/pypy/annotation/annrpython.py b/pypy/annotation/annrpython.py
--- a/pypy/annotation/annrpython.py
+++ b/pypy/annotation/annrpython.py
@@ -228,7 +228,7 @@
             # graph -- it's already low-level operations!
             for a, s_newarg in zip(graph.getargs(), cells):
                 s_oldarg = self.binding(a)
-                assert s_oldarg.contains(s_newarg)
+                assert annmodel.unionof(s_oldarg, s_newarg) == s_oldarg
         else:
             assert not self.frozen
             for a in cells:
diff --git a/pypy/jit/metainterp/compile.py b/pypy/jit/metainterp/compile.py
--- a/pypy/jit/metainterp/compile.py
+++ b/pypy/jit/metainterp/compile.py
@@ -124,18 +124,21 @@
         return old_loop_token
 
     if loop.preamble.operations is not None:
-        send_loop_to_backend(metainterp_sd, loop, "loop")
+        send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd, loop,
+                             "loop")
         record_loop_or_bridge(metainterp_sd, loop)
         token = loop.preamble.token
         if full_preamble_needed:
-            send_loop_to_backend(metainterp_sd, loop.preamble, "entry bridge")
+            send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd,
+                                 loop.preamble, "entry bridge")
             insert_loop_token(old_loop_tokens, loop.preamble.token)
             jitdriver_sd.warmstate.attach_unoptimized_bridge_from_interp(
                 greenkey, loop.preamble.token)
             record_loop_or_bridge(metainterp_sd, loop.preamble)
         return token
     else:
-        send_loop_to_backend(metainterp_sd, loop, "loop")
+        send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd, loop,
+                             "loop")
         insert_loop_token(old_loop_tokens, loop_token)
         jitdriver_sd.warmstate.attach_unoptimized_bridge_from_interp(
             greenkey, loop.token)
@@ -150,7 +153,9 @@
     # XXX do we still need a list?
     old_loop_tokens.append(loop_token)
 
-def send_loop_to_backend(metainterp_sd, loop, type):
+def send_loop_to_backend(greenkey, jitdriver_sd, metainterp_sd, loop, type):
+    jitdriver_sd.on_compile(metainterp_sd.logger_ops, loop.token,
+                            loop.operations, type, greenkey)
     globaldata = metainterp_sd.globaldata
     loop_token = loop.token
     loop_token.number = n = globaldata.loopnumbering
@@ -186,8 +191,11 @@
     if metainterp_sd.warmrunnerdesc is not None:    # for tests
         metainterp_sd.warmrunnerdesc.memory_manager.keep_loop_alive(loop.token)
 
-def send_bridge_to_backend(metainterp_sd, faildescr, inputargs, operations,
-                           original_loop_token):
+def send_bridge_to_backend(jitdriver_sd, metainterp_sd, faildescr, inputargs,
+                           operations, original_loop_token):
+    n = metainterp_sd.cpu.get_fail_descr_number(faildescr)
+    jitdriver_sd.on_compile_bridge(metainterp_sd.logger_ops,
+                                   original_loop_token, operations, n)
     if not we_are_translated():
         show_loop(metainterp_sd)
         TreeLoop.check_consistency_of(inputargs, operations)
@@ -204,7 +212,6 @@
         metainterp_sd.stats.compiled()
     metainterp_sd.log("compiled new bridge")
     #
-    n = metainterp_sd.cpu.get_fail_descr_number(faildescr)
     metainterp_sd.logger_ops.log_bridge(inputargs, operations, n, ops_offset)
     #
     if metainterp_sd.warmrunnerdesc is not None:    # for tests
@@ -390,8 +397,9 @@
         inputargs = metainterp.history.inputargs
         if not we_are_translated():
             self._debug_suboperations = new_loop.operations
-        send_bridge_to_backend(metainterp.staticdata, self, inputargs,
-                               new_loop.operations, new_loop.token)
+        send_bridge_to_backend(metainterp.jitdriver_sd, metainterp.staticdata,
+                               self, inputargs, new_loop.operations,
+                               new_loop.token)
 
     def copy_all_attributes_into(self, res):
         # XXX a bit ugly to have to list them all here
@@ -570,7 +578,8 @@
         # to every guard in the loop.
         new_loop_token = make_loop_token(len(redargs), jitdriver_sd)
         new_loop.token = new_loop_token
-        send_loop_to_backend(metainterp_sd, new_loop, "entry bridge")
+        send_loop_to_backend(self.original_greenkey, metainterp.jitdriver_sd,
+                             metainterp_sd, new_loop, "entry bridge")
         # send the new_loop to warmspot.py, to be called directly the next time
         jitdriver_sd.warmstate.attach_unoptimized_bridge_from_interp(
             self.original_greenkey,
diff --git a/pypy/jit/metainterp/jitdriver.py b/pypy/jit/metainterp/jitdriver.py
--- a/pypy/jit/metainterp/jitdriver.py
+++ b/pypy/jit/metainterp/jitdriver.py
@@ -20,6 +20,7 @@
     #    self.portal_finishtoken... pypy.jit.metainterp.pyjitpl
     #    self.index             ... pypy.jit.codewriter.call
     #    self.mainjitcode       ... pypy.jit.codewriter.call
+    #    self.on_compile        ... pypy.jit.metainterp.warmstate
 
     # These attributes are read by the backend in CALL_ASSEMBLER:
     #    self.assembler_helper_adr
diff --git a/pypy/jit/metainterp/logger.py b/pypy/jit/metainterp/logger.py
--- a/pypy/jit/metainterp/logger.py
+++ b/pypy/jit/metainterp/logger.py
@@ -75,6 +75,40 @@
         else:
             return '?'
 
+    def repr_of_resop(self, memo, op, ops_offset=None):
+        if op.getopnum() == rop.DEBUG_MERGE_POINT:
+            loc = op.getarg(0)._get_str()
+            reclev = op.getarg(1).getint()
+            return "debug_merge_point('%s', %s)" % (loc, reclev)
+        if ops_offset is None:
+            offset = -1
+        else:
+            offset = ops_offset.get(op, -1)
+        if offset == -1:
+            s_offset = ""
+        else:
+            s_offset = "+%d: " % offset
+        args = ", ".join([self.repr_of_arg(memo, op.getarg(i)) for i in 
range(op.numargs())])
+        if op.result is not None:
+            res = self.repr_of_arg(memo, op.result) + " = "
+        else:
+            res = ""
+        is_guard = op.is_guard()
+        if op.getdescr() is not None:
+            descr = op.getdescr()
+            if is_guard and self.guard_number:
+                index = self.metainterp_sd.cpu.get_fail_descr_number(descr)
+                r = "<Guard%d>" % index
+            else:
+                r = self.repr_of_descr(descr)
+            args += ', descr=' +  r
+        if is_guard and op.getfailargs() is not None:
+            fail_args = ' [' + ", ".join([self.repr_of_arg(memo, arg)
+                                          for arg in op.getfailargs()]) + ']'
+        else:
+            fail_args = ''
+        return s_offset + res + op.getopname() + '(' + args + ')' + fail_args
+
     def _log_operations(self, inputargs, operations, ops_offset):
         if not have_debug_prints():
             return
@@ -86,37 +120,7 @@
             debug_print('[' + args + ']')
         for i in range(len(operations)):
             op = operations[i]
-            if op.getopnum() == rop.DEBUG_MERGE_POINT:
-                loc = op.getarg(0)._get_str()
-                reclev = op.getarg(1).getint()
-                debug_print("debug_merge_point('%s', %s)" % (loc, reclev))
-                continue
-            offset = ops_offset.get(op, -1)
-            if offset == -1:
-                s_offset = ""
-            else:
-                s_offset = "+%d: " % offset
-            args = ", ".join([self.repr_of_arg(memo, op.getarg(i)) for i in 
range(op.numargs())])
-            if op.result is not None:
-                res = self.repr_of_arg(memo, op.result) + " = "
-            else:
-                res = ""
-            is_guard = op.is_guard()
-            if op.getdescr() is not None:
-                descr = op.getdescr()
-                if is_guard and self.guard_number:
-                    index = self.metainterp_sd.cpu.get_fail_descr_number(descr)
-                    r = "<Guard%d>" % index
-                else:
-                    r = self.repr_of_descr(descr)
-                args += ', descr=' +  r
-            if is_guard and op.getfailargs() is not None:
-                fail_args = ' [' + ", ".join([self.repr_of_arg(memo, arg)
-                                              for arg in op.getfailargs()]) + 
']'
-            else:
-                fail_args = ''
-            debug_print(s_offset + res + op.getopname() +
-                        '(' + args + ')' + fail_args)
+            debug_print(self.repr_of_resop(memo, operations[i], ops_offset))
         if ops_offset and None in ops_offset:
             offset = ops_offset[None]
             debug_print("+%d: --end of the loop--" % offset)
diff --git a/pypy/jit/metainterp/test/test_jitdriver.py b/pypy/jit/metainterp/test/test_jitdriver.py
--- a/pypy/jit/metainterp/test/test_jitdriver.py
+++ b/pypy/jit/metainterp/test/test_jitdriver.py
@@ -10,8 +10,59 @@
 def getloc2(g):
     return "in jitdriver2, with g=%d" % g
 
+class JitDriverTests(object):
+    def test_on_compile(self):
+        called = {}
+        
+        class MyJitDriver(JitDriver):
+            def on_compile(self, logger, looptoken, operations, type, n, m):
+                called[(m, n, type)] = looptoken
 
-class MultipleJitDriversTests:
+        driver = MyJitDriver(greens = ['n', 'm'], reds = ['i'])
+
+        def loop(n, m):
+            i = 0
+            while i < n + m:
+                driver.can_enter_jit(n=n, m=m, i=i)
+                driver.jit_merge_point(n=n, m=m, i=i)
+                i += 1
+
+        self.meta_interp(loop, [1, 4])
+        assert sorted(called.keys()) == [(4, 1, "entry bridge"), (4, 1, 
"loop")]
+        self.meta_interp(loop, [2, 4])
+        assert sorted(called.keys()) == [(4, 1, "entry bridge"), (4, 1, 
"loop"),
+                                         (4, 2, "entry bridge"), (4, 2, 
"loop")]
+
+    def test_on_compile_bridge(self):
+        called = {}
+        
+        class MyJitDriver(JitDriver):
+            def on_compile(self, logger, looptoken, operations, type, n, m):
+                called[(m, n, type)] = loop
+            def on_compile_bridge(self, logger, orig_token, operations, n):
+                assert 'bridge' not in called
+                called['bridge'] = orig_token
+
+        driver = MyJitDriver(greens = ['n', 'm'], reds = ['i'])
+
+        def loop(n, m):
+            i = 0
+            while i < n + m:
+                driver.can_enter_jit(n=n, m=m, i=i)
+                driver.jit_merge_point(n=n, m=m, i=i)
+                if i >= 4:
+                    i += 2
+                i += 1
+
+        self.meta_interp(loop, [1, 10])
+        assert sorted(called.keys()) == ['bridge', (10, 1, "entry bridge"),
+                                         (10, 1, "loop")]
+
+
+class TestLLtypeSingle(JitDriverTests, LLJitMixin):
+    pass
+
+class MultipleJitDriversTests(object):
 
     def test_simple(self):
         myjitdriver1 = JitDriver(greens=[], reds=['n', 'm'],
diff --git a/pypy/jit/metainterp/warmstate.py b/pypy/jit/metainterp/warmstate.py
--- a/pypy/jit/metainterp/warmstate.py
+++ b/pypy/jit/metainterp/warmstate.py
@@ -566,6 +566,19 @@
             return can_inline_greenargs(*greenargs)
         self.can_inline_greenargs = can_inline_greenargs
         self.can_inline_callable = can_inline_callable
+        if hasattr(jd.jitdriver, 'on_compile'):
+            def on_compile(logger, token, operations, type, greenkey):
+                greenargs = unwrap_greenkey(greenkey)
+                return jd.jitdriver.on_compile(logger, token, operations, type,
+                                               *greenargs)
+            def on_compile_bridge(logger, orig_token, operations, n):
+                return jd.jitdriver.on_compile_bridge(logger, orig_token,
+                                                      operations, n)
+            jd.on_compile = on_compile
+            jd.on_compile_bridge = on_compile_bridge
+        else:
+            jd.on_compile = lambda *args: None
+            jd.on_compile_bridge = lambda *args: None
 
         def get_assembler_token(greenkey, redboxes):
             # 'redboxes' is only used to know the types of red arguments
diff --git a/pypy/module/pypyjit/__init__.py b/pypy/module/pypyjit/__init__.py
--- a/pypy/module/pypyjit/__init__.py
+++ b/pypy/module/pypyjit/__init__.py
@@ -7,13 +7,15 @@
     interpleveldefs = {
         'set_param':    'interp_jit.set_param',
         'residual_call': 'interp_jit.residual_call',
+        'set_compile_hook': 'interp_jit.set_compile_hook',
     }
 
     def setup_after_space_initialization(self):
         # force the __extend__ hacks to occur early
-        import pypy.module.pypyjit.interp_jit
+        from pypy.module.pypyjit.interp_jit import pypyjitdriver
         # add the 'defaults' attribute
         from pypy.rlib.jit import PARAMETERS
         space = self.space
+        pypyjitdriver.space = space
         w_obj = space.wrap(PARAMETERS)
         space.setattr(space.wrap(self), space.wrap('defaults'), w_obj)
diff --git a/pypy/module/pypyjit/interp_jit.py b/pypy/module/pypyjit/interp_jit.py
--- a/pypy/module/pypyjit/interp_jit.py
+++ b/pypy/module/pypyjit/interp_jit.py
@@ -12,6 +12,8 @@
 from pypy.interpreter.pycode import PyCode, CO_GENERATOR
 from pypy.interpreter.pyframe import PyFrame
 from pypy.interpreter.pyopcode import ExitFrame
+from pypy.interpreter.gateway import unwrap_spec
+from pypy.interpreter.baseobjspace import ObjSpace, W_Root
 from opcode import opmap
 from pypy.rlib.objectmodel import we_are_translated
 
@@ -49,6 +51,44 @@
     greens = ['next_instr', 'is_being_profiled', 'pycode']
     virtualizables = ['frame']
 
+    def on_compile(self, logger, looptoken, operations, type, next_instr,
+                   is_being_profiled, ll_pycode):
+        from pypy.rpython.annlowlevel import cast_base_ptr_to_instance
+        
+        space = self.space
+        cache = space.fromcache(Cache)
+        if space.is_true(cache.w_compile_hook):
+            memo = {}
+            list_w = [space.wrap(logger.repr_of_resop(memo, op))
+                      for op in operations]
+            pycode = cast_base_ptr_to_instance(PyCode, ll_pycode)
+            try:
+                space.call_function(cache.w_compile_hook,
+                                    space.wrap('main'),
+                                    space.wrap(type),
+                                    space.newtuple([pycode,
+                                    space.wrap(next_instr),
+                                    space.wrap(is_being_profiled)]),
+                                    space.newlist(list_w))
+            except OperationError, e:
+                e.write_unraisable(space, "jit hook ", cache.w_compile_hook)
+
+    def on_compile_bridge(self, logger, orig_looptoken, operations, n):
+        space = self.space
+        cache = space.fromcache(Cache)
+        if space.is_true(cache.w_compile_hook):
+            memo = {}
+            list_w = [space.wrap(logger.repr_of_resop(memo, op))
+                      for op in operations]
+            try:
+                space.call_function(cache.w_compile_hook,
+                                    space.wrap('main'),
+                                    space.wrap('bridge'),
+                                    space.wrap(n),
+                                    space.newlist(list_w))
+            except OperationError, e:
+                e.write_unraisable(space, "jit hook ", cache.w_compile_hook)
+
 pypyjitdriver = PyPyJitDriver(get_printable_location = get_printable_location,
                               get_jitcell_at = get_jitcell_at,
                               set_jitcell_at = set_jitcell_at,
@@ -149,3 +189,28 @@
     '''For testing.  Invokes callable(...), but without letting
     the JIT follow the call.'''
     return space.call_args(w_callable, __args__)
+
+class Cache(object):
+    def __init__(self, space):
+        self.w_compile_hook = space.w_None
+
+@unwrap_spec(ObjSpace, W_Root)
+def set_compile_hook(space, w_hook):
+    """ set_compile_hook(hook)
+
+    Set a compiling hook that will be called each time a loop is compiled.
+    The hook will be called with the following signature:
+    hook(merge_point_type, loop_type, greenkey or guard_number, operations)
+
+    for now merge point type is always `main`
+
+    loop_type can be either `loop` `entry_bridge` or `bridge`
+    in case loop is not `bridge`, greenkey will be a set of constants
+    for jit merge point. in case it's `main` it'll be a tuple
+    (code, offset, is_being_profiled)
+
+    XXX write down what else
+    """
+    cache = space.fromcache(Cache)
+    cache.w_compile_hook = w_hook
+    return space.w_None
diff --git a/pypy/module/pypyjit/test/test_jit_hook.py b/pypy/module/pypyjit/test/test_jit_hook.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/pypyjit/test/test_jit_hook.py
@@ -0,0 +1,85 @@
+
+from pypy.conftest import gettestobjspace    
+from pypy.interpreter.pycode import PyCode
+from pypy.interpreter.gateway import interp2app
+from pypy.jit.metainterp.history import LoopToken
+from pypy.jit.metainterp.resoperation import ResOperation, rop
+from pypy.jit.metainterp.logger import Logger
+from pypy.rpython.annlowlevel import (cast_instance_to_base_ptr,
+                                      cast_base_ptr_to_instance)
+from pypy.module.pypyjit.interp_jit import pypyjitdriver
+from pypy.jit.tool.oparser import parse
+from pypy.jit.metainterp.typesystem import llhelper
+
+class MockSD(object):
+    class cpu:
+        ts = llhelper
+
+class AppTestJitHook(object):
+    def setup_class(cls):
+        space = gettestobjspace(usemodules=('pypyjit',))
+        cls.space = space
+        w_f = space.appexec([], """():
+        def f():
+            pass
+        return f
+        """)
+        ll_code = cast_instance_to_base_ptr(w_f.code)
+        logger = Logger(MockSD())
+
+        oplist = parse("""
+        [i1, i2]
+        i3 = int_add(i1, i2)
+        guard_true(i3) []
+        """).operations
+
+        def interp_on_compile():
+            pypyjitdriver.on_compile(logger, LoopToken(), oplist, 'loop',
+                                     0, False, ll_code)
+
+        def interp_on_compile_bridge():
+            pypyjitdriver.on_compile_bridge(logger, LoopToken(), oplist, 0)
+        
+        cls.w_on_compile = space.wrap(interp2app(interp_on_compile))
+        cls.w_on_compile_bridge = 
space.wrap(interp2app(interp_on_compile_bridge))
+
+    def test_on_compile(self):
+        import pypyjit
+        all = []
+
+        def hook(*args):
+            assert args[0] == 'main'
+            assert args[1] in ['loop', 'bridge']
+            all.append(args[2:])
+        
+        self.on_compile()
+        pypyjit.set_compile_hook(hook)
+        assert not all
+        self.on_compile()
+        assert len(all) == 1
+        assert all[0][0][0].co_name == 'f'
+        assert all[0][0][1] == 0
+        assert all[0][0][2] == False
+        assert len(all[0][1]) == 2
+        assert 'int_add' in all[0][1][0]
+        self.on_compile_bridge()
+        assert len(all) == 2
+        pypyjit.set_compile_hook(None)
+        self.on_compile()
+        assert len(all) == 2
+
+    def test_on_compile_exception(self):
+        import pypyjit, sys, cStringIO
+
+        def hook(*args):
+            1/0
+
+        pypyjit.set_compile_hook(hook)
+        s = cStringIO.StringIO()
+        sys.stderr = s
+        try:
+            self.on_compile()
+        finally:
+            sys.stderr = sys.__stderr__
+        assert 'jit hook' in s.getvalue()
+        assert 'ZeroDivisionError' in s.getvalue()
diff --git a/pypy/rlib/jit.py b/pypy/rlib/jit.py
--- a/pypy/rlib/jit.py
+++ b/pypy/rlib/jit.py
@@ -370,6 +370,24 @@
                             raise
     set_user_param._annspecialcase_ = 'specialize:arg(0)'
 
+    
+    def on_compile(self, logger, looptoken, operations, type, *greenargs):
+        """ A hook called when loop is compiled. Overwrite
+        for your own jitdriver if you want to do something special, like
+        call applevel code
+        """
+
+    def on_compile_bridge(self, logger, orig_looptoken, operations, n):
+        """ A hook called when a bridge is compiled. Overwrite
+        for your own jitdriver if you want to do something special
+        """
+
+    # note: if you overwrite this functions with the above signature it'll
+    #       work, but the *greenargs is different for each jitdriver, so we
+    #       can't share the same methods
+    del on_compile
+    del on_compile_bridge
+
     def _make_extregistryentries(self):
         # workaround: we cannot declare ExtRegistryEntries for functions
         # used as methods of a frozen object, but we can attach the
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to