Author: Armin Rigo <[email protected]>
Branch: stm-thread-2
Changeset: r59967:20fe85f53d29
Date: 2013-01-11 20:06 +0100
http://bitbucket.org/pypy/pypy/changeset/20fe85f53d29/

Log:    Finish the STM jitdriver transformation, maybe. Hard to test! So far
        tested only by staring at the graphs :-(

diff --git a/pypy/rlib/rstm.py b/pypy/rlib/rstm.py
--- a/pypy/rlib/rstm.py
+++ b/pypy/rlib/rstm.py
@@ -82,7 +82,6 @@
         llcontainer = rffi.cast(CONTAINERP, llcontainer)
         try:
             res = func(llcontainer, retry_counter)
-            llcontainer.got_exception = lltype.nullptr(rclass.OBJECT)
         except Exception, e:
             res = 0     # stop perform_transaction() and returns
             lle = cast_instance_to_base_ptr(e)
diff --git a/pypy/translator/stm/jitdriver.py b/pypy/translator/stm/jitdriver.py
--- a/pypy/translator/stm/jitdriver.py
+++ b/pypy/translator/stm/jitdriver.py
@@ -8,6 +8,7 @@
 from pypy.rpython.annlowlevel import (MixLevelHelperAnnotator,
                                       cast_base_ptr_to_instance)
 from pypy.rlib import rstm
+from pypy.tool.sourcetools import compile2
 
 
 def find_jit_merge_point(graph):
@@ -83,17 +84,18 @@
         assert op_jitmarker.opname == 'jit_marker'
         assert op_jitmarker.args[0].value == 'jit_merge_point'
         jitdriver = op_jitmarker.args[1].value
-
-        assert not jitdriver.greens and not jitdriver.reds   # XXX
         assert not jitdriver.autoreds    # XXX
 
     def split_after_jit_merge_point(self, (portalblock, portalopindex)):
-        split_block(None, portalblock, portalopindex + 1)
+        link = split_block(None, portalblock, portalopindex + 1)
+        self.TYPES = [v.concretetype for v in link.args]
 
     def make_container_type(self):
+        args = [('a%d' % i, self.TYPES[i]) for i in range(len(self.TYPES))]
         self.CONTAINER = lltype.GcStruct('StmArgs',
                                          ('result_value', self.RESTYPE),
-                                         ('got_exception', rclass.OBJECTPTR))
+                                         ('got_exception', rclass.OBJECTPTR),
+                                         *args)
         self.CONTAINERP = lltype.Ptr(self.CONTAINER)
 
     def add_call_should_break_transaction(self, block):
@@ -130,7 +132,8 @@
         #
         # fill in blockf with a call to invoke_stm()
         v = varoftype(self.RESTYPE)
-        op = SpaceOperation('direct_call', [self.c_invoke_stm_func], v)
+        op = SpaceOperation('direct_call',
+                            [self.c_invoke_stm_func] + blockf.inputargs, v)
         blockf.operations.append(op)
         blockf.closeblock(Link([v], main_graph.returnblock))
         #
@@ -141,16 +144,28 @@
         callback = self.callback_function
         perform_transaction = rstm.make_perform_transaction(callback,
                                                             self.CONTAINERP)
-        #
-        def ll_invoke_stm():
+        irange = range(len(self.TYPES))
+        source = """if 1:
+        def ll_invoke_stm(%s):
             p = lltype.malloc(CONTAINER)
+            %s
             perform_transaction(p)
             if p.got_exception:
                 raise cast_base_ptr_to_instance(Exception, p.got_exception)
             return p.result_value
+"""     % (', '.join(['a%d' % i for i in irange]),
+           '\n            '.join(['p.a%d = a%d' % (i, i) for i in irange]))
+        d = {'CONTAINER': CONTAINER,
+             'lltype': lltype,
+             'perform_transaction': perform_transaction,
+             'cast_base_ptr_to_instance': cast_base_ptr_to_instance,
+             }
+        exec compile2(source) in d
+        ll_invoke_stm = d['ll_invoke_stm']
         #
         mix = self.mixlevelannotator
-        c_func = mix.constfunc(ll_invoke_stm, [],
+        c_func = mix.constfunc(ll_invoke_stm,
+                               map(lltype_to_annotation, self.TYPES),
                                lltype_to_annotation(self.RESTYPE))
         self.c_invoke_stm_func = c_func
 
@@ -159,6 +174,7 @@
         callback_graph = copygraph(self.main_graph)
         callback_graph.name += '_stm'
         self.callback_graph = callback_graph
+        self.stmtransformer.translator.graphs.append(callback_graph)
         #for v1, v2 in zip(
         #    self.main_graph.getargs() + [self.main_graph.getreturnvar()],
         #    callback_graph.getargs() + [callback_graph.getreturnvar()]):
@@ -179,17 +195,31 @@
         del block1.operations[i]
         [link] = block1.exits
         callback_graph.startblock = blockst
-        blockst.closeblock(Link([], link.target))
+        #
+        # fill in the operations of blockst: getfields reading all live vars
+        a_vars = []
+        for i in range(len(self.TYPES)):
+            c_a_i = Constant('a%d' % i, lltype.Void)
+            v_a_i = varoftype(self.TYPES[i])
+            blockst.operations.append(
+                SpaceOperation('getfield', [v_p, c_a_i], v_a_i))
+            a_vars.append(v_a_i)
+        blockst.closeblock(Link(a_vars, link.target))
         #
         # hack at the regular return block, to set the result into
         # 'p.result_value', clear 'p.got_exception', and return 0
         blockr = callback_graph.returnblock
         c_result_value = Constant('result_value', lltype.Void)
+        c_got_exception = Constant('got_exception', lltype.Void)
+        c_null = Constant(lltype.nullptr(self.CONTAINER.got_exception.TO),
+                          self.CONTAINER.got_exception)
         blockr.operations = [
             SpaceOperation('setfield',
                            [v_p, c_result_value, blockr.inputargs[0]],
                            varoftype(lltype.Void)),
-            #...
+            SpaceOperation('setfield',
+                           [v_p, c_got_exception, c_null],
+                           varoftype(lltype.Void)),
             ]
         v = varoftype(lltype.Signed)
         annotator.setbinding(v, s_Int)
@@ -201,7 +231,14 @@
         #
         # add 'should_break_transaction()' at the end of the loop
         blockf = self.add_call_should_break_transaction(block1)
-        # ...store stuff...
+        # store the variables again into v_p
+        for i in range(len(self.TYPES)):
+            c_a_i = Constant('a%d' % i, lltype.Void)
+            v_a_i = blockf.inputargs[i]
+            assert v_a_i.concretetype == self.TYPES[i]
+            blockf.operations.append(
+                SpaceOperation('setfield', [v_p, c_a_i, v_a_i],
+                               varoftype(lltype.Void)))
         blockf.closeblock(Link([Constant(1, lltype.Signed)], newblockr))
         #
         SSA_to_SSI(callback_graph)   # to pass 'p' everywhere
diff --git a/pypy/translator/stm/test/test_jitdriver.py b/pypy/translator/stm/test/test_jitdriver.py
--- a/pypy/translator/stm/test/test_jitdriver.py
+++ b/pypy/translator/stm/test/test_jitdriver.py
@@ -20,3 +20,18 @@
 
         res = self.interpret(f1, [])
         assert res == 'X'
+
+    def test_loop_args(self):
+        class X:
+            counter = 100
+        x = X()
+        myjitdriver = JitDriver(greens=['a'], reds=['b', 'c'])
+
+        def f1(a, b, c):
+            while x.counter > 0:
+                myjitdriver.jit_merge_point(a=a, b=b, c=c)
+                x.counter -= (ord(a) + rffi.cast(lltype.Signed, b) + c)
+            return 'X'
+
+        res = self.interpret(f1, ['\x03', rffi.cast(rffi.SHORT, 4), 2])
+        assert res == 'X'
diff --git a/pypy/translator/stm/test/transform2_support.py b/pypy/translator/stm/test/transform2_support.py
--- a/pypy/translator/stm/test/transform2_support.py
+++ b/pypy/translator/stm/test/transform2_support.py
@@ -56,6 +56,9 @@
         if option.view:
             self.translator.view()
         #
+        if self.do_jit_driver:
+            import py
+            py.test.skip("XXX how to test?")
         result = interp.eval_graph(self.graph, args)
         return result
 
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to