Author: Armin Rigo <ar...@tunes.org>
Branch: conditional_call_value_2
Changeset: r86998:1e11cddce2d6
Date: 2016-09-11 15:39 +0200
http://bitbucket.org/pypy/pypy/changeset/1e11cddce2d6/

Log:    more in-progress

diff --git a/rpython/jit/backend/x86/test/test_call.py b/rpython/jit/backend/x86/test/test_call.py
--- a/rpython/jit/backend/x86/test/test_call.py
+++ b/rpython/jit/backend/x86/test/test_call.py
@@ -1,7 +1,7 @@
 from rpython.jit.backend.x86.test.test_basic import Jit386Mixin
 from rpython.jit.metainterp.test import test_call
 
-class TestCall(Jit386Mixin, test_call.TestCall):
+class TestCall(Jit386Mixin, test_call.CallTest):
     # for the individual tests see
     # ====> ../../../metainterp/test/test_call.py
     pass
diff --git a/rpython/jit/metainterp/executor.py b/rpython/jit/metainterp/executor.py
--- a/rpython/jit/metainterp/executor.py
+++ b/rpython/jit/metainterp/executor.py
@@ -96,21 +96,21 @@
 do_call_may_force_f = do_call_f
 do_call_may_force_n = do_call_n
 
-def do_cond_call_i(cpu, metainterp, argboxes, descr):
+def do_cond_call_pure_i(cpu, metainterp, argboxes, descr):
     cond = argboxes[0].getint()
     specialval = argboxes[1].getint()
     if cond == specialval:
         return do_call_i(cpu, metainterp, argboxes[2:], descr)
     return cond
 
-def do_cond_call_r(cpu, metainterp, argboxes, descr):
+def do_cond_call_pure_r(cpu, metainterp, argboxes, descr):
     cond = argboxes[0].getref_base()
     specialval = argboxes[1].getref_base()
     if cond == specialval:
         return do_call_r(cpu, metainterp, argboxes[2:], descr)
     return cond
 
-def do_cond_call_n(cpu, metainterp, argboxes, descr):
+def do_cond_call(cpu, metainterp, argboxes, descr):
     cond = argboxes[0].getint()
     specialval = argboxes[1].getint()
     assert specialval == 1       # cond_call_n is only used with that case
diff --git a/rpython/jit/metainterp/optimizeopt/rewrite.py b/rpython/jit/metainterp/optimizeopt/rewrite.py
--- a/rpython/jit/metainterp/optimizeopt/rewrite.py
+++ b/rpython/jit/metainterp/optimizeopt/rewrite.py
@@ -516,28 +516,41 @@
     optimize_CALL_LOOPINVARIANT_F = optimize_CALL_LOOPINVARIANT_I
     optimize_CALL_LOOPINVARIANT_N = optimize_CALL_LOOPINVARIANT_I
 
-    def optimize_COND_CALL_N(self, op):
+    def optimize_COND_CALL(self, op):
         arg0 = self.get_box_replacement(op.getarg(0))
         arg1 = self.get_box_replacement(op.getarg(1))
-        if arg0.type == 'i':
+        equal = -1    # unknown
+        if isinstance(arg0, Const):
+            equal = arg0.same_constant(arg1)
+        elif arg0.type == 'i':
             b1 = self.getintbound(arg0)
             b2 = self.getintbound(arg1)
-            drop = b1.known_gt(b2) or b1.known_lt(b2)
-        elif arg0.type == 'r' and arg1.same_constant(CONST_NULL):
-            drop = self.getnullness(arg0) == INFO_NONNULL
-        else:
-            drop = False
-        if drop:
+            if b1.known_gt(b2) or b1.known_lt(b2):
+                equal = 0    # different
+        elif arg0.type == 'r':
+            if arg1.same_constant(CONST_NULL):
+                if self.getnullness(arg0) == INFO_NONNULL:
+                    equal = 0    # different
+            else:
+                info0 = self.getptrinfo(arg0)
+                if info0 and info0.is_virtual():
+                    equal = 0    # a virtual can't be equal to a constant
+        #
+        if equal == 1:
+            if op.type == 'v':
+                opnum = rop.CALL_N
+            else:
+                opnum = OpHelpers.call_pure_for_descr(op.getdescr())
+            op = self.replace_op_with(op, opnum, args=op.getarglist()[2:])
+            self.send_extra_operation(op)
+        elif equal == 0:
             if op.type != 'v':
                 self.make_equal_to(op, arg0)
             self.last_emitted_operation = REMOVED
-            return
-        if arg0.same_box(arg1):
-            opnum = OpHelpers.call_for_type(op.type)
-            op = self.replace_op_with(op, opnum, args=op.getarglist()[2:])
-        self.emit_operation(op)
-    optimize_COND_CALL_I = optimize_COND_CALL_N
-    optimize_COND_CALL_R = optimize_COND_CALL_N
+        else:
+            self.emit_operation(op)
+    optimize_COND_CALL_PURE_I = optimize_COND_CALL
+    optimize_COND_CALL_PURE_R = optimize_COND_CALL
 
     def _optimize_nullness(self, op, box, expect_nonnull):
         info = self.getnullness(box)
diff --git a/rpython/jit/metainterp/optimizeopt/test/test_optimizeopt.py b/rpython/jit/metainterp/optimizeopt/test/test_optimizeopt.py
--- a/rpython/jit/metainterp/optimizeopt/test/test_optimizeopt.py
+++ b/rpython/jit/metainterp/optimizeopt/test/test_optimizeopt.py
@@ -7592,7 +7592,7 @@
         ops = """
         [i0]
         p1 = new_with_vtable(descr=nodesize)
-        cond_call_n(i0, 1, 123, p1, descr=clear_vable)
+        cond_call(i0, 1, 123, p1, descr=clear_vable)
         jump(i0)
         """
         expected = """
@@ -8652,7 +8652,7 @@
     def test_cond_call_with_a_constant(self):
         ops = """
         [p1]
-        cond_call_n(1, 1, 123, p1, descr=plaincalldescr)
+        cond_call(1, 1, 123, p1, descr=plaincalldescr)
         jump(p1)
         """
         expected = """
@@ -8665,7 +8665,7 @@
     def test_cond_call_with_a_constant_2(self):
         ops = """
         [p1]
-        cond_call_n(0, 1, 123, p1, descr=plaincalldescr)
+        cond_call(0, 1, 123, p1, descr=plaincalldescr)
         jump(p1)
         """
         expected = """
@@ -8677,7 +8677,7 @@
     def test_cond_call_with_a_constant_i(self):
         ops = """
         [p1]
-        i2 = cond_call_i(12, 12, 123, p1, descr=plaincalldescr)
+        i2 = cond_call_pure_i(12, 12, 123, p1, descr=plaincalldescr)
         escape_n(i2)
         jump(p1)
         """
@@ -8692,7 +8692,7 @@
     def test_cond_call_with_a_constant_i2(self):
         ops = """
         [p1]
-        i2 = cond_call_i(12, 45, 123, p1, descr=plaincalldescr)
+        i2 = cond_call_pure_i(12, 45, 123, p1, descr=plaincalldescr)
         escape_n(i2)
         jump(p1)
         """
@@ -8708,7 +8708,7 @@
         [p1, i1]
         i0 = int_gt(i1, 100)
         guard_true(i0) []
-        i2 = cond_call_i(i1, 45, 123, p1, descr=plaincalldescr)
+        i2 = cond_call_pure_i(i1, 45, 123, p1, descr=plaincalldescr)
         i3 = escape_i(i2)
         jump(p1, i3)
         """
@@ -8724,7 +8724,7 @@
     def test_cond_call_r1(self):
         ops = """
         [p1]
-        p2 = cond_call_r(p1, NULL, 123, 45, descr=plaincalldescr)
+        p2 = cond_call_pure_r(p1, NULL, 123, p1, descr=plain_r_calldescr)
         jump(p2)
         """
         self.optimize_loop(ops, ops)
@@ -8733,7 +8733,7 @@
         ops = """
         [p1]
         guard_nonnull(p1) []
-        p2 = cond_call_r(p1, NULL, 123, 45, descr=plaincalldescr)
+        p2 = cond_call_pure_r(p1, NULL, 123, p1, descr=plain_r_calldescr)
         p3 = escape_r(p2)
         jump(p3)
         """
@@ -8745,6 +8745,24 @@
         """
         self.optimize_loop(ops, expected)
 
+    def test_cond_call_r3(self):
+        ops = """
+        [p0]
+        p4 = escape_r(4)
+        p1 = same_as_r(ConstPtr(myptr))
+        p2 = cond_call_pure_r(p1, ConstPtr(myptr), 123, p4, descr=plain_r_calldescr)
+        p3 = escape_r(p2)
+        jump(p3)
+        """
+        expected = """
+        [p0]
+        p4 = escape_r(4)
+        p2 = call_r(123, p4, descr=plain_r_calldescr)
+        p3 = escape_r(p2)
+        jump(p3)
+        """
+        self.optimize_loop(ops, expected)
+
     def test_hippyvm_unroll_bug(self):
         ops = """
         [p0, i1, i2]
diff --git a/rpython/jit/metainterp/optimizeopt/test/test_util.py b/rpython/jit/metainterp/optimizeopt/test/test_util.py
--- a/rpython/jit/metainterp/optimizeopt/test/test_util.py
+++ b/rpython/jit/metainterp/optimizeopt/test/test_util.py
@@ -436,6 +436,10 @@
                     oopspecindex=EffectInfo.OS_INT_PY_MOD)
     int_py_mod_descr = cpu.calldescrof(FUNC, FUNC.ARGS, FUNC.RESULT, ei)
 
+    FUNC = lltype.FuncType([], llmemory.GCREF)
+    ei = EffectInfo([], [], [], [], [], [], EffectInfo.EF_ELIDABLE_CAN_RAISE)
+    plain_r_calldescr = cpu.calldescrof(FUNC, FUNC.ARGS, FUNC.RESULT, ei)
+
     namespace = locals()
 
 
diff --git a/rpython/jit/metainterp/optimizeopt/virtualize.py b/rpython/jit/metainterp/optimizeopt/virtualize.py
--- a/rpython/jit/metainterp/optimizeopt/virtualize.py
+++ b/rpython/jit/metainterp/optimizeopt/virtualize.py
@@ -94,7 +94,7 @@
     optimize_CALL_MAY_FORCE_F = optimize_CALL_MAY_FORCE_I
     optimize_CALL_MAY_FORCE_N = optimize_CALL_MAY_FORCE_I
 
-    def optimize_COND_CALL_N(self, op):
+    def optimize_COND_CALL(self, op):
         effectinfo = op.getdescr().get_extra_info()
         oopspecindex = effectinfo.oopspecindex
         if oopspecindex == EffectInfo.OS_JIT_FORCE_VIRTUALIZABLE:
diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py
--- a/rpython/jit/metainterp/pyjitpl.py
+++ b/rpython/jit/metainterp/pyjitpl.py
@@ -1062,14 +1062,14 @@
                                      funcbox, argboxes, calldescr, pc):
         return self.do_conditional_call(condbox, specialvalbox,
                                         funcbox, argboxes, calldescr, pc,
-                                        rop.COND_CALL_I)
+                                        rop.COND_CALL_PURE_I)
 
     @arguments("box", "box", "box", "boxes2", "descr", "orgpc")
     def opimpl_conditional_call_ir_r(self, condbox, specialvalbox,
                                      funcbox, argboxes, calldescr, pc):
         return self.do_conditional_call(condbox, specialvalbox,
                                         funcbox, argboxes, calldescr, pc,
-                                        rop.COND_CALL_R)
+                                        rop.COND_CALL_PURE_R)
 
     @arguments("box", "box", "box", "boxes2", "descr", "orgpc")
     def opimpl_conditional_call_ir_v(self, condbox, specialvalbox,
@@ -1730,7 +1730,7 @@
                 assert False
 
     def do_conditional_call(self, condbox, specialvalbox,
-                            funcbox, argboxes, descr, pc, rop_num):
+                            funcbox, argboxes, descr, pc, opnum):
         if (isinstance(condbox, Const) and
                 not condbox.same_constant(specialvalbox)):
             return condbox  # so that the heapcache can keep argboxes virtual
@@ -1739,7 +1739,8 @@
         assert not effectinfo.check_forces_virtual_or_virtualizable()
         exc = effectinfo.check_can_raise()
         pure = effectinfo.check_is_elidable()
-        return self.execute_varargs(rop_num,
+        assert pure == (opnum != rop.COND_CALL)  # _PURE_ iff elidable
+        return self.execute_varargs(opnum,
                                     [condbox, specialvalbox] + allboxes,
                                     descr, exc, pure)
 
@@ -3081,6 +3082,10 @@
         """ Patch a CALL into a CALL_PURE.
         """
         resbox_as_const = executor.constant_from_op(op)
+        is_cond = (op.opnum == rop.COND_CALL_PURE_I or
+                   op.opnum == rop.COND_CALL_PURE_R)
+        if is_cond:
+            argboxes = argboxes[2:]
         for argbox in argboxes:
             if not isinstance(argbox, Const):
                 break
@@ -3093,6 +3098,8 @@
         # be either removed later by optimizeopt or turned back into CALL.
         arg_consts = [executor.constant_from_op(a) for a in argboxes]
         self.call_pure_results[arg_consts] = resbox_as_const
+        if is_cond:
+            return op    # there is no COND_CALL_I/R
         opnum = OpHelpers.call_pure_for_descr(descr)
         self.history.cut(patch_pos)
         newop = self.history.record_nospec(opnum, argboxes, descr)
diff --git a/rpython/jit/metainterp/resoperation.py b/rpython/jit/metainterp/resoperation.py
--- a/rpython/jit/metainterp/resoperation.py
+++ b/rpython/jit/metainterp/resoperation.py
@@ -1149,7 +1149,7 @@
     '_CANRAISE_FIRST', # ----- start of can_raise operations -----
     '_CALL_FIRST',
     'CALL/*d/rfin',
-    'COND_CALL/*d/rin',
+    'COND_CALL/*d/n',
     # a conditional call, with first argument as a condition
     'CALL_ASSEMBLER/*d/rfin',  # call already compiled assembler
     'CALL_MAY_FORCE/*d/rfin',
@@ -1157,6 +1157,7 @@
     'CALL_RELEASE_GIL/*d/fin',
     # release the GIL and "close the stack" for asmgcc
     'CALL_PURE/*d/rfin',             # removed before it's passed to the backend
+    'COND_CALL_PURE/*d/ri',
     'CHECK_MEMORY_ERROR/1/n',   # after a CALL: NULL => propagate MemoryError
     'CALL_MALLOC_NURSERY/1/r',  # nursery malloc, const number of bytes, zeroed
     'CALL_MALLOC_NURSERY_VARSIZE/3d/r',
diff --git a/rpython/jit/metainterp/test/test_call.py b/rpython/jit/metainterp/test/test_call.py
--- a/rpython/jit/metainterp/test/test_call.py
+++ b/rpython/jit/metainterp/test/test_call.py
@@ -2,7 +2,7 @@
 from rpython.jit.metainterp.test.support import LLJitMixin
 from rpython.rlib import jit
 
-class TestCall(LLJitMixin):
+class CallTest(object):
     def test_indirect_call(self):
         @jit.dont_look_inside
         def f1(x):
@@ -54,19 +54,21 @@
         self.check_resops(guard_no_exception=0)
 
     def test_cond_call_i(self):
+        @jit.elidable   # not really, for tests
         def f(l, n):
             l.append(n)
             return 1000
 
         def main(n):
             l = []
-            x = jit.conditional_call_value(n, 10, f, l, n)
+            x = jit.conditional_call_elidable(n, 10, f, l, n)
             return x + len(l)
 
         assert self.interp_operations(main, [10]) == 1001
         assert self.interp_operations(main, [5]) == 5
 
     def test_cond_call_r(self):
+        @jit.elidable
         def f(n):
             return [n]
 
@@ -75,8 +77,24 @@
                 l = []
             else:
                 l = None
-            l = jit.conditional_call_value(l, None, f, n)
+            l = jit.conditional_call_elidable(l, None, f, n)
             return len(l)
 
         assert self.interp_operations(main, [10]) == 0
         assert self.interp_operations(main, [5]) == 1
+
+    def test_cond_call_constant_in_pyjitpl(self):
+        @jit.elidable
+        def f(a, b):
+            return a + b
+        def main(n):
+            # this is completely constant-folded because the arguments
+            # to f() are constants.
+            return jit.conditional_call_elidable(n, 23, f, 40, 2)
+
+        assert main(12) == 12    # because 12 != 23
+        assert self.interp_operations(main, [12]) == 42   # == f(40, 2)
+
+
+class TestCall(LLJitMixin, CallTest):
+    pass
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -1181,11 +1181,20 @@
 
 def _jit_conditional_call(value, ignored, function, *args):
     """NOT_RPYTHON"""
-def _jit_conditional_call_value(value, special_constant, function, *args):
+def _jit_conditional_call_elidable(value, special_constant, function, *args):
     """NOT_RPYTHON"""
 
 @specialize.call_location()
 def conditional_call(condition, function, *args):
+    """Does the same as:
+        
+         if condition:
+             function(*args)
+     
+    but is better for the JIT, in case the condition is often false
+    but could be true occasionally.  It allows the JIT to always produce
+    bridge-free code.
+    """
     if we_are_jitted():
         _jit_conditional_call(condition, True, function, *args)
     else:
@@ -1194,11 +1203,25 @@
 conditional_call._always_inline_ = True
 
 @specialize.call_location()
-def conditional_call_value(value, special_constant, function, *args):
+def conditional_call_elidable(value, special_constant, function, *args):
+    """Does the same as:
+
+        return ONE OF function(*args) OR (value if value != special_constant)
+
+    whichever is better for the JIT.  Usually it first checks if 'value'
+    is equal to 'special_constant', and only if it is, it calls
+    'function(*args)'.  The 'function' must be marked as @elidable.  An
+    example of an "unusual" case is if, say, all arguments are constant.
+    In this case the JIT knows the result of the call in advance, and
+    so it always uses the 'function(*args)' path without comparing
+    'value' and 'special_constant' at all.
+    """
     if we_are_jitted():
-        return _jit_conditional_call_value(value, special_constant,
-                                           function, *args)
+        return _jit_conditional_call_elidable(value, special_constant,
+                                              function, *args)
     else:
+        if not we_are_translated():   # 'function' must be @elidable
+            assert getattr(function, '_elidable_function_', False)
         if not we_are_translated() or isinstance(value, int):
             if value == special_constant:
                 value = function(*args)
@@ -1206,16 +1229,20 @@
             if value is special_constant:
                 value = function(*args)
         return value
-conditional_call_value._always_inline_ = True
+conditional_call_elidable._always_inline_ = True
 
 class ConditionalCallEntry(ExtRegistryEntry):
-    _about_ = _jit_conditional_call_value, _jit_conditional_call
+    _about_ = _jit_conditional_call_elidable, _jit_conditional_call
 
     def compute_result_annotation(self, *args_s):
         from rpython.annotator import model as annmodel
         self.bookkeeper.emulate_pbc_call(self.bookkeeper.position_key,
                                          args_s[2], args_s[3:])
-        if self.instance is _jit_conditional_call_value:
+        if self.instance is _jit_conditional_call_elidable:
+            function = args_s[2].const
+            assert getattr(function, '_elidable_function_', False), (
+                "jit.conditional_call_elidable() must call an elidable "
+                "function, but got %r" % (function,))
             return args_s[0]
 
     def specialize_call(self, hop):
diff --git a/rpython/rlib/test/test_jit.py b/rpython/rlib/test/test_jit.py
--- a/rpython/rlib/test/test_jit.py
+++ b/rpython/rlib/test/test_jit.py
@@ -3,8 +3,8 @@
 from rpython.conftest import option
 from rpython.annotator.model import UnionError
 from rpython.rlib.jit import (hint, we_are_jitted, JitDriver, elidable_promote,
-    JitHintError, oopspec, isconstant, conditional_call, conditional_call_value,
-    elidable, unroll_safe, dont_look_inside,
+    JitHintError, oopspec, isconstant, conditional_call,
+    elidable, unroll_safe, dont_look_inside, conditional_call_elidable,
     enter_portal_frame, leave_portal_frame)
 from rpython.rlib.rarithmetic import r_uint
 from rpython.rtyper.test.tool import BaseRtypingTest
@@ -310,23 +310,27 @@
         t = Translation(g, [])
         t.compile_c() # does not crash
 
-    def test_conditional_call_value(self):
+    def test_conditional_call_elidable(self):
+        @elidable
         def g(m):
             return m + 42
         def f(n, m):
-            return conditional_call_value(n, -1, g, m)
+            return conditional_call_elidable(n, -1, g, m)
 
+        assert f(10, 200) == 10
+        assert f(-1, 200) == 242
         res = self.interpret(f, [10, 200])
         assert res == 10
         res = self.interpret(f, [-1, 200])
         assert res == 242
 
-    def test_compiled_conditional_call_value(self):
+    def test_compiled_conditional_call_elidable(self):
         from rpython.translator.c.test.test_genc import compile
+        @elidable
         def g(m):
             return m + 42
         def f(n, m):
-            return conditional_call_value(n, -1, g, m)
+            return conditional_call_elidable(n, -1, g, m)
         fn = compile(f, [int, int], backendopt=False)
         assert fn(10, 200) == 10
         assert fn(-1, 200) == 242
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to