Author: Armin Rigo <arigo@tunes.org>
Branch: conditional_call_value
Changeset: r79454:f010addba075
Date: 2015-09-05 12:33 +0200
http://bitbucket.org/pypy/pypy/changeset/f010addba075/
Log: in-progress: tweak the user API
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -1106,67 +1106,64 @@
return hop.genop('jit_record_known_class', [v_inst, v_cls],
resulttype=lltype.Void)
-def _jit_conditional_call(condition, function, *args):
- pass
-@specialize.call_location()
-def conditional_call(condition, function, *args):
+def conditional_call(condition, function, *args, **kwds):
+ default = kwds.pop('default', None)
+ assert not kwds
+ if condition:
+ return function(*args)
+ return default
+
+def _ll_cond_call(condition, ll_default, ll_function, *ll_args):
if we_are_jitted():
- _jit_conditional_call(condition, function, *args)
+ from rpython.rtyper.lltypesystem import lltype
+ from rpython.rtyper.lltypesystem.lloperation import llop
+ RESTYPE = lltype.typeOf(ll_default)
+ return llop.jit_conditional_call(RESTYPE, condition, ll_default,
+ ll_function, *ll_args)
else:
if condition:
- return function(*args)
-conditional_call._always_inline_ = True
+ return ll_function(*ll_args)
+ return ll_default
+_ll_cond_call._always_inline_ = True
class ConditionalCallEntry(ExtRegistryEntry):
- _about_ = _jit_conditional_call
+ _about_ = conditional_call
- def compute_result_annotation(self, *args_s):
- self.bookkeeper.emulate_pbc_call(self.bookkeeper.position_key,
- args_s[1], args_s[2:])
+ def compute_result_annotation(self, *args_s, **kwds_s):
+ from rpython.annotator import model as annmodel
- def specialize_call(self, hop):
+ s_res = self.bookkeeper.emulate_pbc_call(self.bookkeeper.position_key,
+ args_s[1], args_s[2:])
+ if 's_default' in kwds_s:
+ assert kwds_s.keys() == ['s_default']
+ return annmodel.unionof(s_res, kwds_s['s_default'])
+ else:
+ assert not kwds_s
+ return None
+
+ def specialize_call(self, hop, i_default=None):
from rpython.rtyper.lltypesystem import lltype
- args_v = hop.inputargs(lltype.Bool, lltype.Void, *hop.args_r[2:])
+ end = len(hop.args_r) - (i_default is not None)
+ inputargs = [lltype.Bool, lltype.Void] + hop.args_r[2:end]
+ if i_default is not None:
+ assert i_default == end
+ inputargs.append(hop.r_result)
+
+ args_v = hop.inputargs(*inputargs)
args_v[1] = hop.args_r[1].get_concrete_llfn(hop.args_s[1],
- hop.args_s[2:],
hop.spaceop)
+ hop.args_s[2:end],
+ hop.spaceop)
+ if i_default is not None:
+ v_default = args_v.pop()
+ else:
+ v_default = hop.inputconst(lltype.Void, None)
+ args_v.insert(1, v_default)
+
hop.exception_is_here()
- return hop.genop('jit_conditional_call', args_v)
+ return hop.gendirectcall(_ll_cond_call, *args_v)
-def _jit_conditional_call_value(condition, function, default_value, *args):
- return default_value
-
-@specialize.call_location()
-def conditional_call_value(condition, function, default_value, *args):
- if we_are_jitted():
- return _jit_conditional_call_value(condition, function, default_value,
- *args)
- else:
- if condition:
- return function(*args)
- return default_value
-conditional_call._always_inline_ = True
-
-class ConditionalCallValueEntry(ExtRegistryEntry):
- _about_ = _jit_conditional_call_value
-
- def compute_result_annotation(self, *args_s):
- s_result = self.bookkeeper.emulate_pbc_call(
- self.bookkeeper.position_key, args_s[1], args_s[3:],
- callback = self.bookkeeper.position_key)
- return s_result
-
- def specialize_call(self, hop):
- from rpython.rtyper.lltypesystem import lltype
-
- args_v = hop.inputargs(lltype.Bool, lltype.Void, *hop.args_r[2:])
- args_v[1] = hop.args_r[1].get_concrete_llfn(hop.args_s[1],
- hop.args_s[3:],
hop.spaceop)
- hop.exception_is_here()
- resulttype = hop.r_result
- return hop.genop('jit_conditional_call_value', args_v,
- resulttype=resulttype)
class Counters(object):
counters="""
diff --git a/rpython/rlib/test/test_jit.py b/rpython/rlib/test/test_jit.py
--- a/rpython/rlib/test/test_jit.py
+++ b/rpython/rlib/test/test_jit.py
@@ -300,3 +300,38 @@
mix = MixLevelHelperAnnotator(t.rtyper)
mix.getgraph(later, [annmodel.s_Bool], annmodel.s_None)
mix.finish()
+
+ def test_conditional_call_value(self):
+ def g(m):
+ return m + 42
+ def f(n, m):
+ return conditional_call(n >= 0, g, m, default=678)
+
+ res = self.interpret(f, [10, 20])
+ assert res == 20 + 42
+ res = self.interpret(f, [-10, 20])
+ assert res == 678
+
+ def test_conditional_call_void(self):
+ class X:
+ pass
+ glob = X()
+ #
+ def g(m):
+ glob.x += m
+ #
+ def h():
+ glob.x += 2
+ #
+ def f(n, m):
+ glob.x = 0
+ conditional_call(n >= 0, g, m)
+ conditional_call(n >= 5, h)
+ return glob.x
+
+ res = self.interpret(f, [10, 20])
+ assert res == 22
+ res = self.interpret(f, [2, 20])
+ assert res == 20
+ res = self.interpret(f, [-2, 20])
+ assert res == 0
diff --git a/rpython/rtyper/llinterp.py b/rpython/rtyper/llinterp.py
--- a/rpython/rtyper/llinterp.py
+++ b/rpython/rtyper/llinterp.py
@@ -548,9 +548,6 @@
def op_jit_conditional_call(self, *args):
raise NotImplementedError("should not be called while not jitted")
- def op_jit_conditional_call_value(self, *args):
- raise NotImplementedError("should not be called while not jitted")
-
def op_get_exception_addr(self, *args):
raise NotImplementedError
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -451,8 +451,7 @@
'jit_force_quasi_immutable': LLOp(canrun=True),
'jit_record_known_class' : LLOp(canrun=True),
'jit_ffi_save_result': LLOp(canrun=True),
- 'jit_conditional_call': LLOp(),
- 'jit_conditional_call_value': LLOp(),
+ 'jit_conditional_call': LLOp(),
'get_exception_addr': LLOp(),
'get_exc_value_addr': LLOp(),
'do_malloc_fixedsize':LLOp(canmallocgc=True),
diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -374,7 +374,8 @@
if not s:
return 0
x = s.hash
- return jit.conditional_call_value(x == 0, LLHelpers._ll_strhash, x, s)
+ return jit.conditional_call(x == 0, LLHelpers._ll_strhash, s,
+ default=x)
@staticmethod
def ll_length(s):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit