Author: Antonio Cuni <[email protected]> Branch: gc-hooks Changeset: r94191:9edf064fc152 Date: 2018-03-30 18:31 +0200 http://bitbucket.org/pypy/pypy/changeset/9edf064fc152/
Log: hg merge default diff too long, truncating to 2000 out of 2274 lines diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst --- a/pypy/doc/whatsnew-head.rst +++ b/pypy/doc/whatsnew-head.rst @@ -68,3 +68,14 @@ Optimize `Py*_Check` for `Bool`, `Float`, `Set`. Also refactor and simplify `W_PyCWrapperObject` which is used to call slots from the C-API, greatly improving microbenchmarks in https://github.com/antocuni/cpyext-benchmarks + + +.. branch: fix-sre-problems + +Fix two (unrelated) JIT bugs manifesting in the re module: + +- green fields are broken and were thus disabled, plus their usage removed from + the _sre implementation + +- in rare "trace is too long" situations, the JIT could break behaviour + arbitrarily. diff --git a/pypy/module/_cffi_backend/ccallback.py b/pypy/module/_cffi_backend/ccallback.py --- a/pypy/module/_cffi_backend/ccallback.py +++ b/pypy/module/_cffi_backend/ccallback.py @@ -232,7 +232,9 @@ "different from the 'ffi.h' file seen at compile-time)") def py_invoke(self, ll_res, ll_args): + key_pycode = self.key_pycode jitdriver1.jit_merge_point(callback=self, + key_pycode=key_pycode, ll_res=ll_res, ll_args=ll_args) self.do_invoke(ll_res, ll_args) @@ -294,7 +296,7 @@ return 'cffi_callback ' + key_pycode.get_repr() jitdriver1 = jit.JitDriver(name='cffi_callback', - greens=['callback.key_pycode'], + greens=['key_pycode'], reds=['ll_res', 'll_args', 'callback'], get_printable_location=get_printable_location1) diff --git a/pypy/module/_io/test/test_interp_textio.py b/pypy/module/_io/test/test_interp_textio.py --- a/pypy/module/_io/test/test_interp_textio.py +++ b/pypy/module/_io/test/test_interp_textio.py @@ -7,6 +7,11 @@ from pypy.module._io.interp_bytesio import W_BytesIO from pypy.module._io.interp_textio import W_TextIOWrapper, DecodeBuffer +# workaround suggestion for slowness by David McIver: +# force hypothesis to initialize some lazy stuff +# (which takes a lot of time, which trips the timer otherwise) +st.text().example() + def translate_newlines(text): text = text.replace(u'\r\n', u'\n') text = text.replace(u'\r', u'\n') @@ -29,7 +34,7 @@ @given(data=st_readline(), mode=st.sampled_from(['\r', '\n', '\r\n', ''])) -@settings(deadline=None) +@settings(deadline=None, database=None) def test_readline(space, data, mode): txt, limits = data w_stream = W_BytesIO(space) diff --git a/pypy/module/_sre/interp_sre.py b/pypy/module/_sre/interp_sre.py --- a/pypy/module/_sre/interp_sre.py +++ b/pypy/module/_sre/interp_sre.py @@ -77,15 +77,15 @@ w_import = space.getattr(w_builtin, space.newtext("__import__")) return space.call_function(w_import, space.newtext("re")) -def matchcontext(space, ctx): +def matchcontext(space, ctx, pattern): try: - return rsre_core.match_context(ctx) + return rsre_core.match_context(ctx, pattern) except rsre_core.Error as e: raise OperationError(space.w_RuntimeError, space.newtext(e.msg)) -def searchcontext(space, ctx): +def searchcontext(space, ctx, pattern): try: - return rsre_core.search_context(ctx) + return rsre_core.search_context(ctx, pattern) except rsre_core.Error as e: raise OperationError(space.w_RuntimeError, space.newtext(e.msg)) @@ -114,7 +114,7 @@ pos = len(unicodestr) if endpos > len(unicodestr): endpos = len(unicodestr) - return rsre_core.UnicodeMatchContext(self.code, unicodestr, + return rsre_core.UnicodeMatchContext(unicodestr, pos, endpos, self.flags) elif space.isinstance_w(w_string, space.w_bytes): str = space.bytes_w(w_string) @@ -122,7 +122,7 @@ pos = len(str) if endpos > len(str): endpos = len(str) - return rsre_core.StrMatchContext(self.code, str, + return rsre_core.StrMatchContext(str, pos, endpos, self.flags) else: buf = space.readbuf_w(w_string) @@ -132,7 +132,7 @@ pos = size if endpos > size: endpos = size - return rsre_core.BufMatchContext(self.code, buf, + return rsre_core.BufMatchContext(buf, pos, endpos, self.flags) def getmatch(self, ctx, found): @@ -144,12 +144,12 @@ @unwrap_spec(pos=int, endpos=int) def match_w(self, w_string, pos=0, endpos=sys.maxint): ctx = self.make_ctx(w_string, pos, endpos) - return self.getmatch(ctx, matchcontext(self.space, ctx)) + return self.getmatch(ctx, matchcontext(self.space, ctx, self.code)) @unwrap_spec(pos=int, endpos=int) def search_w(self, w_string, pos=0, endpos=sys.maxint): ctx = self.make_ctx(w_string, pos, endpos) - return self.getmatch(ctx, searchcontext(self.space, ctx)) + return self.getmatch(ctx, searchcontext(self.space, ctx, self.code)) @unwrap_spec(pos=int, endpos=int) def findall_w(self, w_string, pos=0, endpos=sys.maxint): @@ -157,7 +157,7 @@ matchlist_w = [] ctx = self.make_ctx(w_string, pos, endpos) while ctx.match_start <= ctx.end: - if not searchcontext(space, ctx): + if not searchcontext(space, ctx, self.code): break num_groups = self.num_groups w_emptystr = space.newtext("") @@ -182,7 +182,7 @@ # this also works as the implementation of the undocumented # scanner() method. ctx = self.make_ctx(w_string, pos, endpos) - scanner = W_SRE_Scanner(self, ctx) + scanner = W_SRE_Scanner(self, ctx, self.code) return scanner @unwrap_spec(maxsplit=int) @@ -193,7 +193,7 @@ last = 0 ctx = self.make_ctx(w_string) while not maxsplit or n < maxsplit: - if not searchcontext(space, ctx): + if not searchcontext(space, ctx, self.code): break if ctx.match_start == ctx.match_end: # zero-width match if ctx.match_start == ctx.end: # or end of string @@ -274,8 +274,8 @@ else: sublist_w = [] n = last_pos = 0 + pattern = self.code while not count or n < count: - pattern = ctx.pattern sub_jitdriver.jit_merge_point( self=self, use_builder=use_builder, @@ -292,7 +292,7 @@ n=n, last_pos=last_pos, sublist_w=sublist_w ) space = self.space - if not searchcontext(space, ctx): + if not searchcontext(space, ctx, pattern): break if last_pos < ctx.match_start: _sub_append_slice( @@ -388,7 +388,11 @@ srepat.space = space srepat.w_pattern = w_pattern # the original uncompiled pattern srepat.flags = flags - srepat.code = code + # note: we assume that the app-level is caching SRE_Pattern objects, + # so that we don't need to do it here. Creating new SRE_Pattern + # objects all the time would be bad for the JIT, which relies on the + # identity of the CompiledPattern() object. + srepat.code = rsre_core.CompiledPattern(code) srepat.num_groups = groups srepat.w_groupindex = w_groupindex srepat.w_indexgroup = w_indexgroup @@ -611,10 +615,11 @@ # Our version is also directly iterable, to make finditer() easier. class W_SRE_Scanner(W_Root): - def __init__(self, pattern, ctx): + def __init__(self, pattern, ctx, code): self.space = pattern.space self.srepat = pattern self.ctx = ctx + self.code = code # 'self.ctx' is always a fresh context in which no searching # or matching succeeded so far. @@ -624,19 +629,19 @@ def next_w(self): if self.ctx.match_start > self.ctx.end: raise OperationError(self.space.w_StopIteration, self.space.w_None) - if not searchcontext(self.space, self.ctx): + if not searchcontext(self.space, self.ctx, self.code): raise OperationError(self.space.w_StopIteration, self.space.w_None) return self.getmatch(True) def match_w(self): if self.ctx.match_start > self.ctx.end: return self.space.w_None - return self.getmatch(matchcontext(self.space, self.ctx)) + return self.getmatch(matchcontext(self.space, self.ctx, self.code)) def search_w(self): if self.ctx.match_start > self.ctx.end: return self.space.w_None - return self.getmatch(searchcontext(self.space, self.ctx)) + return self.getmatch(searchcontext(self.space, self.ctx, self.code)) def getmatch(self, found): if found: diff --git a/pypy/module/pypyjit/hooks.py b/pypy/module/pypyjit/hooks.py --- a/pypy/module/pypyjit/hooks.py +++ b/pypy/module/pypyjit/hooks.py @@ -7,12 +7,20 @@ WrappedOp, W_JitLoopInfo, wrap_oplist) class PyPyJitIface(JitHookInterface): + def are_hooks_enabled(self): + space = self.space + cache = space.fromcache(Cache) + return (cache.w_compile_hook is not None or + cache.w_abort_hook is not None or + cache.w_trace_too_long_hook is not None) + + def on_abort(self, reason, jitdriver, greenkey, greenkey_repr, logops, operations): space = self.space cache = space.fromcache(Cache) if cache.in_recursion: return - if space.is_true(cache.w_abort_hook): + if cache.w_abort_hook is not None: cache.in_recursion = True oplist_w = wrap_oplist(space, logops, operations) try: @@ -33,7 +41,7 @@ cache = space.fromcache(Cache) if cache.in_recursion: return - if space.is_true(cache.w_trace_too_long_hook): + if cache.w_trace_too_long_hook is not None: cache.in_recursion = True try: try: @@ -62,7 +70,7 @@ cache = space.fromcache(Cache) if cache.in_recursion: return - if space.is_true(cache.w_compile_hook): + if cache.w_compile_hook is not None: w_debug_info = W_JitLoopInfo(space, debug_info, is_bridge, cache.compile_hook_with_ops) cache.in_recursion = True diff --git a/pypy/module/pypyjit/interp_resop.py b/pypy/module/pypyjit/interp_resop.py --- a/pypy/module/pypyjit/interp_resop.py +++ b/pypy/module/pypyjit/interp_resop.py @@ -21,9 +21,10 @@ no = 0 def __init__(self, space): - self.w_compile_hook = space.w_None - self.w_abort_hook = space.w_None - self.w_trace_too_long_hook = space.w_None + self.w_compile_hook = None + self.w_abort_hook = None + self.w_trace_too_long_hook = None + self.compile_hook_with_ops = False def getno(self): self.no += 1 @@ -58,7 +59,8 @@ jit hook won't be called for that. """ cache = space.fromcache(Cache) - assert w_hook is not None + if space.is_w(w_hook, space.w_None): + w_hook = None cache.w_compile_hook = w_hook cache.compile_hook_with_ops = operations cache.in_recursion = NonConstant(False) @@ -77,7 +79,8 @@ as attributes on JitLoopInfo object. """ cache = space.fromcache(Cache) - assert w_hook is not None + if space.is_w(w_hook, space.w_None): + w_hook = None cache.w_abort_hook = w_hook cache.in_recursion = NonConstant(False) @@ -92,14 +95,15 @@ hook(jitdriver_name, greenkey) """ cache = space.fromcache(Cache) - assert w_hook is not None + if space.is_w(w_hook, space.w_None): + w_hook = None cache.w_trace_too_long_hook = w_hook cache.in_recursion = NonConstant(False) def wrap_oplist(space, logops, operations, ops_offset=None): # this function is called from the JIT from rpython.jit.metainterp.resoperation import rop - + l_w = [] jitdrivers_sd = logops.metainterp_sd.jitdrivers_sd for op in operations: @@ -109,22 +113,27 @@ ofs = ops_offset.get(op, 0) num = op.getopnum() name = op.getopname() + repr = logops.repr_of_resop(op) if num == rop.DEBUG_MERGE_POINT: jd_sd = jitdrivers_sd[op.getarg(0).getint()] greenkey = op.getarglist()[3:] repr = jd_sd.warmstate.get_location_str(greenkey) w_greenkey = wrap_greenkey(space, jd_sd.jitdriver, greenkey, repr) l_w.append(DebugMergePoint(space, name, - logops.repr_of_resop(op), + repr, jd_sd.jitdriver.name, op.getarg(1).getint(), op.getarg(2).getint(), w_greenkey)) elif op.is_guard(): - l_w.append(GuardOp(name, ofs, logops.repr_of_resop(op), - op.getdescr().get_jitcounter_hash())) + descr = op.getdescr() + if descr is not None: # can be none in on_abort! + hash = op.getdescr().get_jitcounter_hash() + else: + hash = -1 + l_w.append(GuardOp(name, ofs, repr, hash)) else: - l_w.append(WrappedOp(name, ofs, logops.repr_of_resop(op))) + l_w.append(WrappedOp(name, ofs, repr)) return l_w @unwrap_spec(offset=int, repr='text', name='text') diff --git a/pypy/module/pypyjit/test/test_jit_hook.py b/pypy/module/pypyjit/test/test_jit_hook.py --- a/pypy/module/pypyjit/test/test_jit_hook.py +++ b/pypy/module/pypyjit/test/test_jit_hook.py @@ -65,6 +65,17 @@ if i != 1: offset[op] = i + oplist_no_descrs = parse(""" + [i1, i2, p2] + i3 = int_add(i1, i2) + debug_merge_point(0, 0, 0, 0, 0, ConstPtr(ptr0)) + guard_nonnull(p2) [] + guard_true(i3) [] + """, namespace={'ptr0': code_gcref}).operations + for op in oplist_no_descrs: + if op.is_guard(): + op.setdescr(None) + class FailDescr(BasicFailDescr): def get_jitcounter_hash(self): from rpython.rlib.rarithmetic import r_uint @@ -86,18 +97,23 @@ def interp_on_compile(): di_loop.oplist = cls.oplist - pypy_hooks.after_compile(di_loop) + if pypy_hooks.are_hooks_enabled(): + pypy_hooks.after_compile(di_loop) def interp_on_compile_bridge(): - pypy_hooks.after_compile_bridge(di_bridge) + if pypy_hooks.are_hooks_enabled(): + pypy_hooks.after_compile_bridge(di_bridge) def interp_on_optimize(): - di_loop_optimize.oplist = cls.oplist - pypy_hooks.before_compile(di_loop_optimize) + if pypy_hooks.are_hooks_enabled(): + di_loop_optimize.oplist = cls.oplist + pypy_hooks.before_compile(di_loop_optimize) def interp_on_abort(): - pypy_hooks.on_abort(Counters.ABORT_TOO_LONG, pypyjitdriver, - greenkey, 'blah', Logger(MockSD), []) + if pypy_hooks.are_hooks_enabled(): + pypy_hooks.on_abort(Counters.ABORT_TOO_LONG, pypyjitdriver, + greenkey, 'blah', Logger(MockSD), + cls.oplist_no_descrs) space = cls.space cls.w_on_compile = space.wrap(interp2app(interp_on_compile)) @@ -107,10 +123,12 @@ cls.w_dmp_num = space.wrap(rop.DEBUG_MERGE_POINT) cls.w_on_optimize = space.wrap(interp2app(interp_on_optimize)) cls.orig_oplist = oplist + cls.orig_oplist_no_descrs = oplist_no_descrs cls.w_sorted_keys = space.wrap(sorted(Counters.counter_names)) def setup_method(self, meth): self.__class__.oplist = self.orig_oplist[:] + self.__class__.oplist_no_descrs = self.orig_oplist_no_descrs[:] def test_on_compile(self): import pypyjit @@ -219,7 +237,11 @@ pypyjit.set_abort_hook(hook) self.on_abort() - assert l == [('pypyjit', 'ABORT_TOO_LONG', [])] + assert len(l) == 1 + name, reason, ops = l[0] + assert name == 'pypyjit' + assert reason == 'ABORT_TOO_LONG' + assert len(ops) == 4 def test_creation(self): from pypyjit import ResOperation diff --git a/rpython/jit/codewriter/policy.py b/rpython/jit/codewriter/policy.py --- a/rpython/jit/codewriter/policy.py +++ b/rpython/jit/codewriter/policy.py @@ -11,9 +11,6 @@ self.supports_floats = False self.supports_longlong = False self.supports_singlefloats = False - if jithookiface is None: - from rpython.rlib.jit import JitHookInterface - jithookiface = JitHookInterface() self.jithookiface = jithookiface def set_supports_floats(self, flag): diff --git a/rpython/jit/metainterp/compile.py b/rpython/jit/metainterp/compile.py --- a/rpython/jit/metainterp/compile.py +++ b/rpython/jit/metainterp/compile.py @@ -545,15 +545,17 @@ show_procedures(metainterp_sd, loop) loop.check_consistency() + debug_info = None + hooks = None if metainterp_sd.warmrunnerdesc is not None: hooks = metainterp_sd.warmrunnerdesc.hooks - debug_info = JitDebugInfo(jitdriver_sd, metainterp_sd.logger_ops, - original_jitcell_token, loop.operations, - type, greenkey) - hooks.before_compile(debug_info) - else: - debug_info = None - hooks = None + if hooks.are_hooks_enabled(): + debug_info = JitDebugInfo(jitdriver_sd, metainterp_sd.logger_ops, + original_jitcell_token, loop.operations, + type, greenkey) + hooks.before_compile(debug_info) + else: + hooks = None operations = get_deep_immutable_oplist(loop.operations) metainterp_sd.profiler.start_backend() debug_start("jit-backend") @@ -597,15 +599,17 @@ show_procedures(metainterp_sd) seen = dict.fromkeys(inputargs) TreeLoop.check_consistency_of_branch(operations, seen) + debug_info = None + hooks = None if metainterp_sd.warmrunnerdesc is not None: hooks = metainterp_sd.warmrunnerdesc.hooks - debug_info = JitDebugInfo(jitdriver_sd, metainterp_sd.logger_ops, - original_loop_token, operations, 'bridge', - fail_descr=faildescr) - hooks.before_compile_bridge(debug_info) - else: - hooks = None - debug_info = None + if hooks.are_hooks_enabled(): + debug_info = JitDebugInfo(jitdriver_sd, metainterp_sd.logger_ops, + original_loop_token, operations, 'bridge', + fail_descr=faildescr) + hooks.before_compile_bridge(debug_info) + else: + hooks = None operations = get_deep_immutable_oplist(operations) metainterp_sd.profiler.start_backend() debug_start("jit-backend") diff --git a/rpython/jit/metainterp/history.py b/rpython/jit/metainterp/history.py --- a/rpython/jit/metainterp/history.py +++ b/rpython/jit/metainterp/history.py @@ -701,6 +701,9 @@ def length(self): return self.trace._count - len(self.trace.inputargs) + def trace_tag_overflow(self): + return self.trace.tag_overflow + def get_trace_position(self): return self.trace.cut_point() diff --git a/rpython/jit/metainterp/opencoder.py b/rpython/jit/metainterp/opencoder.py --- a/rpython/jit/metainterp/opencoder.py +++ b/rpython/jit/metainterp/opencoder.py @@ -49,13 +49,6 @@ way up to lltype.Signed for indexes everywhere """ -def frontend_tag_overflow(): - # Minor abstraction leak: raise directly the right exception - # expected by the rest of the machinery - from rpython.jit.metainterp import history - from rpython.rlib.jit import Counters - raise history.SwitchToBlackhole(Counters.ABORT_TOO_LONG) - class BaseTrace(object): pass @@ -293,6 +286,7 @@ self._start = len(inputargs) self._pos = self._start self.inputargs = inputargs + self.tag_overflow = False def append(self, v): model = get_model(self) @@ -300,12 +294,14 @@ # grow by 2X self._ops = self._ops + [rffi.cast(model.STORAGE_TP, 0)] * len(self._ops) if not model.MIN_VALUE <= v <= model.MAX_VALUE: - raise frontend_tag_overflow() + v = 0 # broken value, but that's fine, tracing will stop soon + self.tag_overflow = True self._ops[self._pos] = rffi.cast(model.STORAGE_TP, v) self._pos += 1 - def done(self): + def tracing_done(self): from rpython.rlib.debug import debug_start, debug_stop, debug_print + assert not self.tag_overflow self._bigints_dict = {} self._refs_dict = llhelper.new_ref_dict_3() @@ -317,8 +313,6 @@ debug_print(" ref consts: " + str(self._consts_ptr) + " " + str(len(self._refs))) debug_print(" descrs: " + str(len(self._descrs))) debug_stop("jit-trace-done") - return 0 # completely different than TraceIter.done, but we have to - # share the base class def length(self): return self._pos @@ -379,6 +373,7 @@ def record_op(self, opnum, argboxes, descr=None): pos = self._index + old_pos = self._pos self.append(opnum) expected_arity = oparity[opnum] if expected_arity == -1: @@ -397,6 +392,10 @@ self._count += 1 if opclasses[opnum].type != 'v': self._index += 1 + if self.tag_overflow: + # potentially a broken op is left behind + # clean it up + self._pos = old_pos return pos def _encode_descr(self, descr): @@ -424,10 +423,11 @@ vref_array = self._list_of_boxes(vref_boxes) s = TopSnapshot(combine_uint(jitcode.index, pc), array, vable_array, vref_array) - assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 # guards have no descr self._snapshots.append(s) - self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) + if not self.tag_overflow: # otherwise we're broken anyway + assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 + self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) return s def create_empty_top_snapshot(self, vable_boxes, vref_boxes): @@ -436,10 +436,11 @@ vref_array = self._list_of_boxes(vref_boxes) s = TopSnapshot(combine_uint(2**16 - 1, 0), [], vable_array, vref_array) - assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 # guards have no descr self._snapshots.append(s) - self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) + if not self.tag_overflow: # otherwise we're broken anyway + assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 + self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) return s def create_snapshot(self, jitcode, pc, frame, flag): diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py --- a/rpython/jit/metainterp/pyjitpl.py +++ b/rpython/jit/metainterp/pyjitpl.py @@ -2365,7 +2365,9 @@ greenkey = None # we're in the bridge else: greenkey = self.current_merge_points[0][0][:jd_sd.num_green_args] - self.staticdata.warmrunnerdesc.hooks.on_abort(reason, + hooks = self.staticdata.warmrunnerdesc.hooks + if hooks.are_hooks_enabled(): + hooks.on_abort(reason, jd_sd.jitdriver, greenkey, jd_sd.warmstate.get_location_str(greenkey), self.staticdata.logger_ops._make_log_operations( @@ -2374,9 +2376,10 @@ if self.aborted_tracing_jitdriver is not None: jd_sd = self.aborted_tracing_jitdriver greenkey = self.aborted_tracing_greenkey - self.staticdata.warmrunnerdesc.hooks.on_trace_too_long( - jd_sd.jitdriver, greenkey, - jd_sd.warmstate.get_location_str(greenkey)) + if hooks.are_hooks_enabled(): + hooks.on_trace_too_long( + jd_sd.jitdriver, greenkey, + jd_sd.warmstate.get_location_str(greenkey)) # no ops for now self.aborted_tracing_jitdriver = None self.aborted_tracing_greenkey = None @@ -2384,9 +2387,9 @@ def blackhole_if_trace_too_long(self): warmrunnerstate = self.jitdriver_sd.warmstate - if self.history.length() > warmrunnerstate.trace_limit: + if (self.history.length() > warmrunnerstate.trace_limit or + self.history.trace_tag_overflow()): jd_sd, greenkey_of_huge_function = self.find_biggest_function() - self.history.trace.done() self.staticdata.stats.record_aborted(greenkey_of_huge_function) self.portal_trace_positions = None if greenkey_of_huge_function is not None: @@ -2689,7 +2692,9 @@ try_disabling_unroll=False, exported_state=None): num_green_args = self.jitdriver_sd.num_green_args greenkey = original_boxes[:num_green_args] - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() if not self.partial_trace: ptoken = self.get_procedure_token(greenkey) if ptoken is not None and ptoken.target_tokens is not None: @@ -2742,7 +2747,9 @@ self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None, descr=target_jitcell_token) self.history.ends_with_jump = True - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() try: target_token = compile.compile_trace(self, self.resumekey, live_arg_boxes[num_green_args:]) @@ -2776,7 +2783,9 @@ assert False # FIXME: can we call compile_trace? self.history.record(rop.FINISH, exits, None, descr=token) - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() target_token = compile.compile_trace(self, self.resumekey, exits) if target_token is not token: compile.giveup() @@ -2802,7 +2811,9 @@ sd = self.staticdata token = sd.exit_frame_with_exception_descr_ref self.history.record(rop.FINISH, [valuebox], None, descr=token) - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() target_token = compile.compile_trace(self, self.resumekey, [valuebox]) if target_token is not token: compile.giveup() diff --git a/rpython/jit/metainterp/test/test_ajit.py b/rpython/jit/metainterp/test/test_ajit.py --- a/rpython/jit/metainterp/test/test_ajit.py +++ b/rpython/jit/metainterp/test/test_ajit.py @@ -4661,3 +4661,36 @@ f() # finishes self.meta_interp(f, []) + + def test_trace_too_long_bug(self): + driver = JitDriver(greens=[], reds=['i']) + @unroll_safe + def match(s): + l = len(s) + p = 0 + for i in range(2500): # produces too long trace + c = s[p] + if c != 'a': + return False + p += 1 + if p >= l: + return True + c = s[p] + if c != '\n': + p += 1 + if p >= l: + return True + else: + return False + return True + + def f(i): + while i > 0: + driver.jit_merge_point(i=i) + match('a' * (500 * i)) + i -= 1 + return i + + res = self.meta_interp(f, [10]) + assert res == f(10) + diff --git a/rpython/jit/metainterp/test/test_greenfield.py b/rpython/jit/metainterp/test/test_greenfield.py --- a/rpython/jit/metainterp/test/test_greenfield.py +++ b/rpython/jit/metainterp/test/test_greenfield.py @@ -1,6 +1,17 @@ +import pytest from rpython.jit.metainterp.test.support import LLJitMixin from rpython.rlib.jit import JitDriver, assert_green +pytest.skip("this feature is disabled at the moment!") + +# note why it is disabled: before d721da4573ad +# there was a failing assert when inlining python -> sre -> python: +# https://bitbucket.org/pypy/pypy/issues/2775/ +# this shows, that the interaction of greenfields and virtualizables is broken, +# because greenfields use MetaInterp.virtualizable_boxes, which confuses +# MetaInterp._nonstandard_virtualizable somehow (and makes no sense +# conceptually anyway). to fix greenfields, the two mechanisms would have to be +# disentangled. class GreenFieldsTests: diff --git a/rpython/jit/metainterp/test/test_jitiface.py b/rpython/jit/metainterp/test/test_jitiface.py --- a/rpython/jit/metainterp/test/test_jitiface.py +++ b/rpython/jit/metainterp/test/test_jitiface.py @@ -238,7 +238,7 @@ hashes = Hashes() - class Hooks(object): + class Hooks(JitHookInterface): def before_compile(self, debug_info): pass @@ -279,6 +279,44 @@ self.meta_interp(main, [1, 1], policy=JitPolicy(hooks)) assert len(hashes.t) == 1 + + def test_are_hooks_enabled(self): + reasons = [] + + class MyJitIface(JitHookInterface): + def are_hooks_enabled(self): + return False + + def on_abort(self, reason, jitdriver, greenkey, greenkey_repr, logops, ops): + reasons.append(reason) + + iface = MyJitIface() + + myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'], + get_printable_location=lambda *args: 'blah') + + class Foo: + _immutable_fields_ = ['a?'] + + def __init__(self, a): + self.a = a + + def f(a, x): + foo = Foo(a) + total = 0 + while x > 0: + myjitdriver.jit_merge_point(foo=foo, x=x, total=total) + total += foo.a + foo.a += 1 + x -= 1 + return total + # + assert f(100, 7) == 721 + res = self.meta_interp(f, [100, 7], policy=JitPolicy(iface)) + assert res == 721 + assert reasons == [] + + class LLJitHookInterfaceTests(JitHookInterfaceTests): # use this for any backend, instead of the super class @@ -320,7 +358,6 @@ # this so far does not work because of the way setup_once is done, # but fine, it's only about untranslated version anyway #self.meta_interp(main, [False], ProfilerClass=Profiler) - class TestJitHookInterface(JitHookInterfaceTests, LLJitMixin): pass diff --git a/rpython/jit/metainterp/test/test_opencoder.py b/rpython/jit/metainterp/test/test_opencoder.py --- a/rpython/jit/metainterp/test/test_opencoder.py +++ b/rpython/jit/metainterp/test/test_opencoder.py @@ -209,5 +209,8 @@ def test_tag_overflow(self): t = Trace([], metainterp_sd) i0 = FakeOp(100000) - py.test.raises(SwitchToBlackhole, t.record_op, rop.FINISH, [i0]) - assert t.unpack() == ([], []) + # if we overflow, we can keep recording + for i in range(10): + t.record_op(rop.FINISH, [i0]) + assert t.unpack() == ([], []) + assert t.tag_overflow diff --git a/rpython/jit/metainterp/warmspot.py b/rpython/jit/metainterp/warmspot.py --- a/rpython/jit/metainterp/warmspot.py +++ b/rpython/jit/metainterp/warmspot.py @@ -220,6 +220,15 @@ stats.check_consistency() # ____________________________________________________________ +# always disabled hooks interface + +from rpython.rlib.jit import JitHookInterface + +class NoHooksInterface(JitHookInterface): + def are_hooks_enabled(self): + return False + +# ____________________________________________________________ class WarmRunnerDesc(object): @@ -259,7 +268,7 @@ else: self.jitcounter = counter.DeterministicJitCounter() # - self.hooks = policy.jithookiface + self.make_hooks(policy.jithookiface) self.make_virtualizable_infos() self.make_driverhook_graphs() self.make_enter_functions() @@ -498,6 +507,12 @@ self.metainterp_sd.opencoder_model = Model self.stats.metainterp_sd = self.metainterp_sd + def make_hooks(self, hooks): + if hooks is None: + # interface not overridden, use a special one that is never enabled + hooks = NoHooksInterface() + self.hooks = hooks + def make_virtualizable_infos(self): vinfos = {} for jd in self.jitdrivers_sd: diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py --- a/rpython/rlib/jit.py +++ b/rpython/rlib/jit.py @@ -653,6 +653,9 @@ self._make_extregistryentries() assert get_jitcell_at is None, "get_jitcell_at no longer used" assert set_jitcell_at is None, "set_jitcell_at no longer used" + for green in self.greens: + if "." in green: + raise ValueError("green fields are buggy! if you need them fixed, please talk to us") self.get_printable_location = get_printable_location self.get_location = get_location self.has_unique_id = (get_unique_id is not None) @@ -1084,7 +1087,8 @@ """ This is the main connector between the JIT and the interpreter. Several methods on this class will be invoked at various stages of JIT running like JIT loops compiled, aborts etc. - An instance of this class will be available as policy.jithookiface. + An instance of this class has to be passed into the JitPolicy constructor + (and will then be available as policy.jithookiface). """ # WARNING: You should make a single prebuilt instance of a subclass # of this class. You can, before translation, initialize some @@ -1094,6 +1098,13 @@ # of the program! A line like ``pypy_hooks.foo = ...`` must not # appear inside your interpreter's RPython code. + def are_hooks_enabled(self): + """ A hook that is called to check whether the interpreter's hooks are + enabled at all. Only if this function returns True, are the other hooks + called. Otherwise, nothing happens. This is done because constructing + some of the hooks' arguments is expensive, so we'd rather not do it.""" + return True + def on_abort(self, reason, jitdriver, greenkey, greenkey_repr, logops, operations): """ A hook called each time a loop is aborted with jitdriver and greenkey where it started, reason is a string why it got aborted diff --git a/rpython/rlib/rsre/rpy/_sre.py b/rpython/rlib/rsre/rpy/_sre.py --- a/rpython/rlib/rsre/rpy/_sre.py +++ b/rpython/rlib/rsre/rpy/_sre.py @@ -1,4 +1,4 @@ -from rpython.rlib.rsre import rsre_char +from rpython.rlib.rsre import rsre_char, rsre_core from rpython.rlib.rarithmetic import intmask VERSION = "2.7.6" @@ -12,7 +12,7 @@ pass def compile(pattern, flags, code, *args): - raise GotIt([intmask(i) for i in code], flags, args) + raise GotIt(rsre_core.CompiledPattern([intmask(i) for i in code]), flags, args) def get_code(regexp, flags=0, allargs=False): diff --git a/rpython/rlib/rsre/rsre_char.py b/rpython/rlib/rsre/rsre_char.py --- a/rpython/rlib/rsre/rsre_char.py +++ b/rpython/rlib/rsre/rsre_char.py @@ -152,17 +152,16 @@ ##### Charset evaluation @jit.unroll_safe -def check_charset(ctx, ppos, char_code): +def check_charset(ctx, pattern, ppos, char_code): """Checks whether a character matches set of arbitrary length. The set starts at pattern[ppos].""" negated = False result = False - pattern = ctx.pattern while True: - opcode = pattern[ppos] + opcode = pattern.pattern[ppos] for i, function in set_dispatch_unroll: if opcode == i: - newresult, ppos = function(ctx, ppos, char_code) + newresult, ppos = function(ctx, pattern, ppos, char_code) result |= newresult break else: @@ -177,50 +176,44 @@ return not result return result -def set_literal(ctx, index, char_code): +def set_literal(ctx, pattern, index, char_code): # <LITERAL> <code> - pat = ctx.pattern - match = pat[index+1] == char_code + match = pattern.pattern[index+1] == char_code return match, index + 2 -def set_category(ctx, index, char_code): +def set_category(ctx, pattern, index, char_code): # <CATEGORY> <code> - pat = ctx.pattern - match = category_dispatch(pat[index+1], char_code) + match = category_dispatch(pattern.pattern[index+1], char_code) return match, index + 2 -def set_charset(ctx, index, char_code): +def set_charset(ctx, pattern, index, char_code): # <CHARSET> <bitmap> (16 bits per code word) - pat = ctx.pattern if CODESIZE == 2: match = char_code < 256 and \ - (pat[index+1+(char_code >> 4)] & (1 << (char_code & 15))) + (pattern.pattern[index+1+(char_code >> 4)] & (1 << (char_code & 15))) return match, index + 17 # skip bitmap else: match = char_code < 256 and \ - (pat[index+1+(char_code >> 5)] & (1 << (char_code & 31))) + (pattern.pattern[index+1+(char_code >> 5)] & (1 << (char_code & 31))) return match, index + 9 # skip bitmap -def set_range(ctx, index, char_code): +def set_range(ctx, pattern, index, char_code): # <RANGE> <lower> <upper> - pat = ctx.pattern - match = int_between(pat[index+1], char_code, pat[index+2] + 1) + match = int_between(pattern.pattern[index+1], char_code, pattern.pattern[index+2] + 1) return match, index + 3 -def set_range_ignore(ctx, index, char_code): +def set_range_ignore(ctx, pattern, index, char_code): # <RANGE_IGNORE> <lower> <upper> # the char_code is already lower cased - pat = ctx.pattern - lower = pat[index + 1] - upper = pat[index + 2] + lower = pattern.pattern[index + 1] + upper = pattern.pattern[index + 2] match1 = int_between(lower, char_code, upper + 1) match2 = int_between(lower, getupper(char_code, ctx.flags), upper + 1) return match1 | match2, index + 3 -def set_bigcharset(ctx, index, char_code): +def set_bigcharset(ctx, pattern, index, char_code): # <BIGCHARSET> <blockcount> <256 blockindices> <blocks> - pat = ctx.pattern - count = pat[index+1] + count = pattern.pattern[index+1] index += 2 if CODESIZE == 2: @@ -238,7 +231,7 @@ return False, index shift = 5 - block = pat[index + (char_code >> (shift + 5))] + block = pattern.pattern[index + (char_code >> (shift + 5))] block_shift = char_code >> 5 if BIG_ENDIAN: @@ -247,23 +240,22 @@ block = (block >> block_shift) & 0xFF index += 256 / CODESIZE - block_value = pat[index+(block * (32 / CODESIZE) + block_value = pattern.pattern[index+(block * (32 / CODESIZE) + ((char_code & 255) >> shift))] match = (block_value & (1 << (char_code & ((8 * CODESIZE) - 1)))) index += count * (32 / CODESIZE) # skip blocks return match, index -def set_unicode_general_category(ctx, index, char_code): +def set_unicode_general_category(ctx, pattern, index, char_code): # Unicode "General category property code" (not used by Python). - # A general category is two letters. 'pat[index+1]' contains both + # A general category is two letters. 'pattern.pattern[index+1]' contains both # the first character, and the second character shifted by 8. # http://en.wikipedia.org/wiki/Unicode_character_property#General_Category # Also supports single-character categories, if the second character is 0. # Negative matches are triggered by bit number 7. assert unicodedb is not None cat = unicodedb.category(char_code) - pat = ctx.pattern - category_code = pat[index + 1] + category_code = pattern.pattern[index + 1] first_character = category_code & 0x7F second_character = (category_code >> 8) & 0x7F negative_match = category_code & 0x80 diff --git a/rpython/rlib/rsre/rsre_core.py b/rpython/rlib/rsre/rsre_core.py --- a/rpython/rlib/rsre/rsre_core.py +++ b/rpython/rlib/rsre/rsre_core.py @@ -83,35 +83,19 @@ def __init__(self, msg): self.msg = msg -class AbstractMatchContext(object): - """Abstract base class""" - _immutable_fields_ = ['pattern[*]', 'flags', 'end'] - match_start = 0 - match_end = 0 - match_marks = None - match_marks_flat = None - fullmatch_only = False - def __init__(self, pattern, match_start, end, flags): - # 'match_start' and 'end' must be known to be non-negative - # and they must not be more than len(string). - check_nonneg(match_start) - check_nonneg(end) +class CompiledPattern(object): + _immutable_fields_ = ['pattern[*]'] + + def __init__(self, pattern): self.pattern = pattern - self.match_start = match_start - self.end = end - self.flags = flags # check we don't get the old value of MAXREPEAT # during the untranslated tests if not we_are_translated(): assert 65535 not in pattern - def reset(self, start): - self.match_start = start - self.match_marks = None - self.match_marks_flat = None - def pat(self, index): + jit.promote(self) check_nonneg(index) result = self.pattern[index] # Check that we only return non-negative integers from this helper. @@ -121,6 +105,29 @@ assert result >= 0 return result +class AbstractMatchContext(object): + """Abstract base class""" + _immutable_fields_ = ['flags', 'end'] + match_start = 0 + match_end = 0 + match_marks = None + match_marks_flat = None + fullmatch_only = False + + def __init__(self, match_start, end, flags): + # 'match_start' and 'end' must be known to be non-negative + # and they must not be more than len(string). + check_nonneg(match_start) + check_nonneg(end) + self.match_start = match_start + self.end = end + self.flags = flags + + def reset(self, start): + self.match_start = start + self.match_marks = None + self.match_marks_flat = None + @not_rpython def str(self, index): """Must be overridden in a concrete subclass. @@ -183,8 +190,8 @@ _immutable_fields_ = ["_buffer"] - def __init__(self, pattern, buf, match_start, end, flags): - AbstractMatchContext.__init__(self, pattern, match_start, end, flags) + def __init__(self, buf, match_start, end, flags): + AbstractMatchContext.__init__(self, match_start, end, flags) self._buffer = buf def str(self, index): @@ -196,7 +203,7 @@ return rsre_char.getlower(c, self.flags) def fresh_copy(self, start): - return BufMatchContext(self.pattern, self._buffer, start, + return BufMatchContext(self._buffer, start, self.end, self.flags) class StrMatchContext(AbstractMatchContext): @@ -204,8 +211,8 @@ _immutable_fields_ = ["_string"] - def __init__(self, pattern, string, match_start, end, flags): - AbstractMatchContext.__init__(self, pattern, match_start, end, flags) + def __init__(self, string, match_start, end, flags): + AbstractMatchContext.__init__(self, match_start, end, flags) self._string = string if not we_are_translated() and isinstance(string, unicode): self.flags |= rsre_char.SRE_FLAG_UNICODE # for rsre_re.py @@ -219,7 +226,7 @@ return rsre_char.getlower(c, self.flags) def fresh_copy(self, start): - return StrMatchContext(self.pattern, self._string, start, + return StrMatchContext(self._string, start, self.end, self.flags) class UnicodeMatchContext(AbstractMatchContext): @@ -227,8 +234,8 @@ _immutable_fields_ = ["_unicodestr"] - def __init__(self, pattern, unicodestr, match_start, end, flags): - AbstractMatchContext.__init__(self, pattern, match_start, end, flags) + def __init__(self, unicodestr, match_start, end, flags): + AbstractMatchContext.__init__(self, match_start, end, flags) self._unicodestr = unicodestr def str(self, index): @@ -240,7 +247,7 @@ return rsre_char.getlower(c, self.flags) def fresh_copy(self, start): - return UnicodeMatchContext(self.pattern, self._unicodestr, start, + return UnicodeMatchContext(self._unicodestr, start, self.end, self.flags) # ____________________________________________________________ @@ -265,16 +272,16 @@ class MatchResult(object): subresult = None - def move_to_next_result(self, ctx): + def move_to_next_result(self, ctx, pattern): # returns either 'self' or None result = self.subresult if result is None: return - if result.move_to_next_result(ctx): + if result.move_to_next_result(ctx, pattern): return self - return self.find_next_result(ctx) + return self.find_next_result(ctx, pattern) - def find_next_result(self, ctx): + def find_next_result(self, ctx, pattern): raise NotImplementedError MATCHED_OK = MatchResult() @@ -287,11 +294,11 @@ self.start_marks = marks @jit.unroll_safe - def find_first_result(self, ctx): + def find_first_result(self, ctx, pattern): ppos = jit.hint(self.ppos, promote=True) - while ctx.pat(ppos): - result = sre_match(ctx, ppos + 1, self.start_ptr, self.start_marks) - ppos += ctx.pat(ppos) + while pattern.pat(ppos): + result = sre_match(ctx, pattern, ppos + 1, self.start_ptr, self.start_marks) + ppos += pattern.pat(ppos) if result is not None: self.subresult = result self.ppos = ppos @@ -300,7 +307,7 @@ class RepeatOneMatchResult(MatchResult): install_jitdriver('RepeatOne', - greens=['nextppos', 'ctx.pattern'], + greens=['nextppos', 'pattern'], reds=['ptr', 'self', 'ctx'], debugprint=(1, 0)) # indices in 'greens' @@ -310,13 +317,14 @@ self.start_ptr = ptr self.start_marks = marks - def find_first_result(self, ctx): + def find_first_result(self, ctx, pattern): ptr = self.start_ptr nextppos = self.nextppos while ptr >= self.minptr: ctx.jitdriver_RepeatOne.jit_merge_point( - self=self, ptr=ptr, ctx=ctx, nextppos=nextppos) - result = sre_match(ctx, nextppos, ptr, self.start_marks) + self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, + pattern=pattern) + result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks) ptr -= 1 if result is not None: self.subresult = result @@ -327,7 +335,7 @@ class MinRepeatOneMatchResult(MatchResult): install_jitdriver('MinRepeatOne', - greens=['nextppos', 'ppos3', 'ctx.pattern'], + greens=['nextppos', 'ppos3', 'pattern'], reds=['ptr', 'self', 'ctx'], debugprint=(2, 0)) # indices in 'greens' @@ -338,39 +346,40 @@ self.start_ptr = ptr self.start_marks = marks - def find_first_result(self, ctx): + def find_first_result(self, ctx, pattern): ptr = self.start_ptr nextppos = self.nextppos ppos3 = self.ppos3 while ptr <= self.maxptr: ctx.jitdriver_MinRepeatOne.jit_merge_point( - self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3) - result = sre_match(ctx, nextppos, ptr, self.start_marks) + self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3, + pattern=pattern) + result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks) if result is not None: self.subresult = result self.start_ptr = ptr return self - if not self.next_char_ok(ctx, ptr, ppos3): + if not self.next_char_ok(ctx, pattern, ptr, ppos3): break ptr += 1 - def find_next_result(self, ctx): + def find_next_result(self, ctx, pattern): ptr = self.start_ptr - if not self.next_char_ok(ctx, ptr, self.ppos3): + if not self.next_char_ok(ctx, pattern, ptr, self.ppos3): return self.start_ptr = ptr + 1 - return self.find_first_result(ctx) + return self.find_first_result(ctx, pattern) - def next_char_ok(self, ctx, ptr, ppos): + def next_char_ok(self, ctx, pattern, ptr, ppos): if ptr == ctx.end: return False - op = ctx.pat(ppos) + op = pattern.pat(ppos) for op1, checkerfn in unroll_char_checker: if op1 == op: - return checkerfn(ctx, ptr, ppos) + return checkerfn(ctx, pattern, ptr, ppos) # obscure case: it should be a single char pattern, but isn't # one of the opcodes in unroll_char_checker (see test_ext_opcode) - return sre_match(ctx, ppos, ptr, self.start_marks) is not None + return sre_match(ctx, pattern, ppos, ptr, self.start_marks) is not None class AbstractUntilMatchResult(MatchResult): @@ -391,17 +400,17 @@ class MaxUntilMatchResult(AbstractUntilMatchResult): install_jitdriver('MaxUntil', - greens=['ppos', 'tailppos', 'match_more', 'ctx.pattern'], + greens=['ppos', 'tailppos', 'match_more', 'pattern'], reds=['ptr', 'marks', 'self', 'ctx'], debugprint=(3, 0, 2)) - def find_first_result(self, ctx): - return self.search_next(ctx, match_more=True) + def find_first_result(self, ctx, pattern): + return self.search_next(ctx, pattern, match_more=True) - def find_next_result(self, ctx): - return self.search_next(ctx, match_more=False) + def find_next_result(self, ctx, pattern): + return self.search_next(ctx, pattern, match_more=False) - def search_next(self, ctx, match_more): + def search_next(self, ctx, pattern, match_more): ppos = self.ppos tailppos = self.tailppos ptr = self.cur_ptr @@ -409,12 +418,13 @@ while True: ctx.jitdriver_MaxUntil.jit_merge_point( ppos=ppos, tailppos=tailppos, match_more=match_more, - ptr=ptr, marks=marks, self=self, ctx=ctx) + ptr=ptr, marks=marks, self=self, ctx=ctx, + pattern=pattern) if match_more: - max = ctx.pat(ppos+2) + max = pattern.pat(ppos+2) if max == rsre_char.MAXREPEAT or self.num_pending < max: # try to match one more 'item' - enum = sre_match(ctx, ppos + 3, ptr, marks) + enum = sre_match(ctx, pattern, ppos + 3, ptr, marks) else: enum = None # 'max' reached, no more matches else: @@ -425,9 +435,9 @@ self.num_pending -= 1 ptr = p.ptr marks = p.marks - enum = p.enum.move_to_next_result(ctx) + enum = p.enum.move_to_next_result(ctx, pattern) # - min = ctx.pat(ppos+1) + min = pattern.pat(ppos+1) if enum is not None: # matched one more 'item'. record it and continue. last_match_length = ctx.match_end - ptr @@ -447,7 +457,7 @@ # 'item' no longer matches. if self.num_pending >= min: # try to match 'tail' if we have enough 'item' - result = sre_match(ctx, tailppos, ptr, marks) + result = sre_match(ctx, pattern, tailppos, ptr, marks) if result is not None: self.subresult = result self.cur_ptr = ptr @@ -457,23 +467,23 @@ class MinUntilMatchResult(AbstractUntilMatchResult): - def find_first_result(self, ctx): - return self.search_next(ctx, resume=False) + def find_first_result(self, ctx, pattern): + return self.search_next(ctx, pattern, resume=False) - def find_next_result(self, ctx): - return self.search_next(ctx, resume=True) + def find_next_result(self, ctx, pattern): + return self.search_next(ctx, pattern, resume=True) - def search_next(self, ctx, resume): + def search_next(self, ctx, pattern, resume): # XXX missing jit support here ppos = self.ppos - min = ctx.pat(ppos+1) - max = ctx.pat(ppos+2) + min = pattern.pat(ppos+1) + max = pattern.pat(ppos+2) ptr = self.cur_ptr marks = self.cur_marks while True: # try to match 'tail' if we have enough 'item' if not resume and self.num_pending >= min: - result = sre_match(ctx, self.tailppos, ptr, marks) + result = sre_match(ctx, pattern, self.tailppos, ptr, marks) if result is not None: self.subresult = result self.cur_ptr = ptr @@ -483,12 +493,12 @@ if max == rsre_char.MAXREPEAT or self.num_pending < max: # try to match one more 'item' - enum = sre_match(ctx, ppos + 3, ptr, marks) + enum = sre_match(ctx, pattern, ppos + 3, ptr, marks) # # zero-width match protection if self.num_pending >= min: while enum is not None and ptr == ctx.match_end: - enum = enum.move_to_next_result(ctx) + enum = enum.move_to_next_result(ctx, pattern) else: enum = None # 'max' reached, no more matches @@ -502,7 +512,7 @@ self.num_pending -= 1 ptr = p.ptr marks = p.marks - enum = p.enum.move_to_next_result(ctx) + enum = p.enum.move_to_next_result(ctx, pattern) # matched one more 'item'. record it and continue self.pending = Pending(ptr, marks, enum, self.pending) @@ -514,13 +524,13 @@ @specializectx @jit.unroll_safe -def sre_match(ctx, ppos, ptr, marks): +def sre_match(ctx, pattern, ppos, ptr, marks): """Returns either None or a MatchResult object. Usually we only need the first result, but there is the case of REPEAT...UNTIL where we need all results; in that case we use the method move_to_next_result() of the MatchResult.""" while True: - op = ctx.pat(ppos) + op = pattern.pat(ppos) ppos += 1 #jit.jit_debug("sre_match", op, ppos, ptr) @@ -563,33 +573,33 @@ elif op == OPCODE_ASSERT: # assert subpattern # <ASSERT> <0=skip> <1=back> <pattern> - ptr1 = ptr - ctx.pat(ppos+1) + ptr1 = ptr - pattern.pat(ppos+1) saved = ctx.fullmatch_only ctx.fullmatch_only = False - stop = ptr1 < 0 or sre_match(ctx, ppos + 2, ptr1, marks) is None + stop = ptr1 < 0 or sre_match(ctx, pattern, ppos + 2, ptr1, marks) is None ctx.fullmatch_only = saved if stop: return marks = ctx.match_marks - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_ASSERT_NOT: # assert not subpattern # <ASSERT_NOT> <0=skip> <1=back> <pattern> - ptr1 = ptr - ctx.pat(ppos+1) + ptr1 = ptr - pattern.pat(ppos+1) saved = ctx.fullmatch_only ctx.fullmatch_only = False - stop = (ptr1 >= 0 and sre_match(ctx, ppos + 2, ptr1, marks) + stop = (ptr1 >= 0 and sre_match(ctx, pattern, ppos + 2, ptr1, marks) is not None) ctx.fullmatch_only = saved if stop: return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_AT: # match at given position (e.g. at beginning, at boundary, etc.) # <AT> <code> - if not sre_at(ctx, ctx.pat(ppos), ptr): + if not sre_at(ctx, pattern.pat(ppos), ptr): return ppos += 1 @@ -597,14 +607,14 @@ # alternation # <BRANCH> <0=skip> code <JUMP> ... <NULL> result = BranchMatchResult(ppos, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) elif op == OPCODE_CATEGORY: # seems to be never produced, but used by some tests from # pypy/module/_sre/test # <CATEGORY> <category> if (ptr == ctx.end or - not rsre_char.category_dispatch(ctx.pat(ppos), ctx.str(ptr))): + not rsre_char.category_dispatch(pattern.pat(ppos), ctx.str(ptr))): return ptr += 1 ppos += 1 @@ -612,7 +622,7 @@ elif op == OPCODE_GROUPREF: # match backreference # <GROUPREF> <groupnum> - startptr, length = get_group_ref(marks, ctx.pat(ppos)) + startptr, length = get_group_ref(marks, pattern.pat(ppos)) if length < 0: return # group was not previously defined if not match_repeated(ctx, ptr, startptr, length): @@ -623,7 +633,7 @@ elif op == OPCODE_GROUPREF_IGNORE: # match backreference # <GROUPREF> <groupnum> - startptr, length = get_group_ref(marks, ctx.pat(ppos)) + startptr, length = get_group_ref(marks, pattern.pat(ppos)) if length < 0: return # group was not previously defined if not match_repeated_ignore(ctx, ptr, startptr, length): @@ -634,44 +644,44 @@ elif op == OPCODE_GROUPREF_EXISTS: # conditional match depending on the existence of a group # <GROUPREF_EXISTS> <group> <skip> codeyes <JUMP> codeno ... - _, length = get_group_ref(marks, ctx.pat(ppos)) + _, length = get_group_ref(marks, pattern.pat(ppos)) if length >= 0: ppos += 2 # jump to 'codeyes' else: - ppos += ctx.pat(ppos+1) # jump to 'codeno' + ppos += pattern.pat(ppos+1) # jump to 'codeno' elif op == OPCODE_IN: # match set member (or non_member) # <IN> <skip> <set> - if ptr >= ctx.end or not rsre_char.check_charset(ctx, ppos+1, + if ptr >= ctx.end or not rsre_char.check_charset(ctx, pattern, ppos+1, ctx.str(ptr)): return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) ptr += 1 elif op == OPCODE_IN_IGNORE: # match set member (or non_member), ignoring case # <IN> <skip> <set> - if ptr >= ctx.end or not rsre_char.check_charset(ctx, ppos+1, + if ptr >= ctx.end or not rsre_char.check_charset(ctx, pattern, ppos+1, ctx.lowstr(ptr)): return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) ptr += 1 elif op == OPCODE_INFO: # optimization info block # <INFO> <0=skip> <1=flags> <2=min> ... - if (ctx.end - ptr) < ctx.pat(ppos+2): + if (ctx.end - ptr) < pattern.pat(ppos+2): return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_JUMP: - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_LITERAL: # match literal string # <LITERAL> <code> - if ptr >= ctx.end or ctx.str(ptr) != ctx.pat(ppos): + if ptr >= ctx.end or ctx.str(ptr) != pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -679,7 +689,7 @@ elif op == OPCODE_LITERAL_IGNORE: # match literal string, ignoring case # <LITERAL_IGNORE> <code> - if ptr >= ctx.end or ctx.lowstr(ptr) != ctx.pat(ppos): + if ptr >= ctx.end or ctx.lowstr(ptr) != pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -687,14 +697,14 @@ elif op == OPCODE_MARK: # set mark # <MARK> <gid> - gid = ctx.pat(ppos) + gid = pattern.pat(ppos) marks = Mark(gid, ptr, marks) ppos += 1 elif op == OPCODE_NOT_LITERAL: # match if it's not a literal string # <NOT_LITERAL> <code> - if ptr >= ctx.end or ctx.str(ptr) == ctx.pat(ppos): + if ptr >= ctx.end or ctx.str(ptr) == pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -702,7 +712,7 @@ elif op == OPCODE_NOT_LITERAL_IGNORE: # match if it's not a literal string, ignoring case # <NOT_LITERAL> <code> - if ptr >= ctx.end or ctx.lowstr(ptr) == ctx.pat(ppos): + if ptr >= ctx.end or ctx.lowstr(ptr) == pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -715,22 +725,22 @@ # decode the later UNTIL operator to see if it is actually # a MAX_UNTIL or MIN_UNTIL - untilppos = ppos + ctx.pat(ppos) + untilppos = ppos + pattern.pat(ppos) tailppos = untilppos + 1 - op = ctx.pat(untilppos) + op = pattern.pat(untilppos) if op == OPCODE_MAX_UNTIL: # the hard case: we have to match as many repetitions as # possible, followed by the 'tail'. we do this by # remembering each state for each possible number of # 'item' matching. result = MaxUntilMatchResult(ppos, tailppos, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) elif op == OPCODE_MIN_UNTIL: # first try to match the 'tail', and if it fails, try # to match one more 'item' and try again result = MinUntilMatchResult(ppos, tailppos, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) else: raise Error("missing UNTIL after REPEAT") @@ -743,17 +753,18 @@ # use the MAX_REPEAT operator. # <REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail start = ptr - minptr = start + ctx.pat(ppos+1) + minptr = start + pattern.pat(ppos+1) if minptr > ctx.end: return # cannot match - ptr = find_repetition_end(ctx, ppos+3, start, ctx.pat(ppos+2), + ptr = find_repetition_end(ctx, pattern, ppos+3, start, + pattern.pat(ppos+2), marks) # when we arrive here, ptr points to the tail of the target # string. check if the rest of the pattern matches, # and backtrack if not. - nextppos = ppos + ctx.pat(ppos) + nextppos = ppos + pattern.pat(ppos) result = RepeatOneMatchResult(nextppos, minptr, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) elif op == OPCODE_MIN_REPEAT_ONE: # match repeated sequence (minimizing regexp). @@ -763,26 +774,26 @@ # use the MIN_REPEAT operator. # <MIN_REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail start = ptr - min = ctx.pat(ppos+1) + min = pattern.pat(ppos+1) if min > 0: minptr = ptr + min if minptr > ctx.end: return # cannot match # count using pattern min as the maximum - ptr = find_repetition_end(ctx, ppos+3, ptr, min, marks) + ptr = find_repetition_end(ctx, pattern, ppos+3, ptr, min, marks) if ptr < minptr: return # did not match minimum number of times maxptr = ctx.end - max = ctx.pat(ppos+2) + max = pattern.pat(ppos+2) if max != rsre_char.MAXREPEAT: maxptr1 = start + max if maxptr1 <= maxptr: maxptr = maxptr1 - nextppos = ppos + ctx.pat(ppos) + nextppos = ppos + pattern.pat(ppos) result = MinRepeatOneMatchResult(nextppos, ppos+3, maxptr, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) else: raise Error("bad pattern code %d" % op) @@ -816,7 +827,7 @@ return True @specializectx -def find_repetition_end(ctx, ppos, ptr, maxcount, marks): +def find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks): end = ctx.end ptrp1 = ptr + 1 # First get rid of the cases where we don't have room for any match. @@ -826,16 +837,16 @@ # The idea is to be fast for cases like re.search("b+"), where we expect # the common case to be a non-match. It's much faster with the JIT to # have the non-match inlined here rather than detect it in the fre() call. - op = ctx.pat(ppos) + op = pattern.pat(ppos) for op1, checkerfn in unroll_char_checker: if op1 == op: - if checkerfn(ctx, ptr, ppos): + if checkerfn(ctx, pattern, ptr, ppos): break return ptr else: # obscure case: it should be a single char pattern, but isn't # one of the opcodes in unroll_char_checker (see test_ext_opcode) - return general_find_repetition_end(ctx, ppos, ptr, maxcount, marks) + return general_find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks) # It matches at least once. If maxcount == 1 (relatively common), # then we are done. if maxcount == 1: @@ -846,14 +857,14 @@ end1 = ptr + maxcount if end1 <= end: end = end1 - op = ctx.pat(ppos) + op = pattern.pat(ppos) for op1, fre in unroll_fre_checker: if op1 == op: - return fre(ctx, ptrp1, end, ppos) + return fre(ctx, pattern, ptrp1, end, ppos) raise Error("rsre.find_repetition_end[%d]" % op) @specializectx -def general_find_repetition_end(ctx, ppos, ptr, maxcount, marks): +def general_find_repetition_end(ctx, patern, ppos, ptr, maxcount, marks): # moved into its own JIT-opaque function end = ctx.end if maxcount != rsre_char.MAXREPEAT: @@ -861,63 +872,65 @@ end1 = ptr + maxcount if end1 <= end: end = end1 - while ptr < end and sre_match(ctx, ppos, ptr, marks) is not None: + while ptr < end and sre_match(ctx, patern, ppos, ptr, marks) is not None: ptr += 1 return ptr @specializectx -def match_ANY(ctx, ptr, ppos): # dot wildcard. +def match_ANY(ctx, pattern, ptr, ppos): # dot wildcard. return not rsre_char.is_linebreak(ctx.str(ptr)) -def match_ANY_ALL(ctx, ptr, ppos): +def match_ANY_ALL(ctx, pattern, ptr, ppos): return True # match anything (including a newline) @specializectx -def match_IN(ctx, ptr, ppos): - return rsre_char.check_charset(ctx, ppos+2, ctx.str(ptr)) +def match_IN(ctx, pattern, ptr, ppos): + return rsre_char.check_charset(ctx, pattern, ppos+2, ctx.str(ptr)) @specializectx -def match_IN_IGNORE(ctx, ptr, ppos): - return rsre_char.check_charset(ctx, ppos+2, ctx.lowstr(ptr)) +def match_IN_IGNORE(ctx, pattern, ptr, ppos): + return rsre_char.check_charset(ctx, pattern, ppos+2, ctx.lowstr(ptr)) @specializectx -def match_LITERAL(ctx, ptr, ppos): - return ctx.str(ptr) == ctx.pat(ppos+1) +def match_LITERAL(ctx, pattern, ptr, ppos): + return ctx.str(ptr) == pattern.pat(ppos+1) @specializectx -def match_LITERAL_IGNORE(ctx, ptr, ppos): - return ctx.lowstr(ptr) == ctx.pat(ppos+1) +def match_LITERAL_IGNORE(ctx, pattern, ptr, ppos): + return ctx.lowstr(ptr) == pattern.pat(ppos+1) @specializectx -def match_NOT_LITERAL(ctx, ptr, ppos): - return ctx.str(ptr) != ctx.pat(ppos+1) +def match_NOT_LITERAL(ctx, pattern, ptr, ppos): + return ctx.str(ptr) != pattern.pat(ppos+1) @specializectx -def match_NOT_LITERAL_IGNORE(ctx, ptr, ppos): - return ctx.lowstr(ptr) != ctx.pat(ppos+1) +def match_NOT_LITERAL_IGNORE(ctx, pattern, ptr, ppos): + return ctx.lowstr(ptr) != pattern.pat(ppos+1) def _make_fre(checkerfn): if checkerfn == match_ANY_ALL: - def fre(ctx, ptr, end, ppos): + def fre(ctx, pattern, ptr, end, ppos): return end elif checkerfn == match_IN: install_jitdriver_spec('MatchIn', - greens=['ppos', 'ctx.pattern'], + greens=['ppos', 'pattern'], reds=['ptr', 'end', 'ctx'], debugprint=(1, 0)) @specializectx - def fre(ctx, ptr, end, ppos): + def fre(ctx, pattern, ptr, end, ppos): while True: ctx.jitdriver_MatchIn.jit_merge_point(ctx=ctx, ptr=ptr, - end=end, ppos=ppos) - if ptr < end and checkerfn(ctx, ptr, ppos): + end=end, ppos=ppos, + pattern=pattern) + if ptr < end and checkerfn(ctx, pattern, ptr, ppos): ptr += 1 else: return ptr elif checkerfn == match_IN_IGNORE: install_jitdriver_spec('MatchInIgnore', - greens=['ppos', 'ctx.pattern'], + greens=['ppos', 'pattern'], reds=['ptr', 'end', 'ctx'], debugprint=(1, 0)) @specializectx - def fre(ctx, ptr, end, ppos): + def fre(ctx, pattern, ptr, end, ppos): while True: ctx.jitdriver_MatchInIgnore.jit_merge_point(ctx=ctx, ptr=ptr, - end=end, ppos=ppos) - if ptr < end and checkerfn(ctx, ptr, ppos): + end=end, ppos=ppos, + pattern=pattern) + if ptr < end and checkerfn(ctx, pattern, ptr, ppos): ptr += 1 else: return ptr @@ -925,8 +938,8 @@ # in the other cases, the fre() function is not JITted at all # and is present as a residual call. @specializectx - def fre(ctx, ptr, end, ppos): - while ptr < end and checkerfn(ctx, ptr, ppos): + def fre(ctx, pattern, ptr, end, ppos): + while ptr < end and checkerfn(ctx, pattern, ptr, ppos): ptr += 1 return ptr fre = func_with_new_name(fre, 'fre_' + checkerfn.__name__) @@ -1037,10 +1050,11 @@ return start, end def match(pattern, string, start=0, end=sys.maxint, flags=0, fullmatch=False): + assert isinstance(pattern, CompiledPattern) start, end = _adjust(start, end, len(string)) - ctx = StrMatchContext(pattern, string, start, end, flags) + ctx = StrMatchContext(string, start, end, flags) ctx.fullmatch_only = fullmatch - if match_context(ctx): + if match_context(ctx, pattern): return ctx else: return None @@ -1049,105 +1063,106 @@ return match(pattern, string, start, end, flags, fullmatch=True) def search(pattern, string, start=0, end=sys.maxint, flags=0): + assert isinstance(pattern, CompiledPattern) start, end = _adjust(start, end, len(string)) - ctx = StrMatchContext(pattern, string, start, end, flags) - if search_context(ctx): + ctx = StrMatchContext(string, start, end, flags) + if search_context(ctx, pattern): return ctx else: return None install_jitdriver('Match', - greens=['ctx.pattern'], reds=['ctx'], + greens=['pattern'], reds=['ctx'], debugprint=(0,)) -def match_context(ctx): +def match_context(ctx, pattern): ctx.original_pos = ctx.match_start if ctx.end < ctx.match_start: return False - ctx.jitdriver_Match.jit_merge_point(ctx=ctx) - return sre_match(ctx, 0, ctx.match_start, None) is not None + ctx.jitdriver_Match.jit_merge_point(ctx=ctx, pattern=pattern) + return sre_match(ctx, pattern, 0, ctx.match_start, None) is not None -def search_context(ctx): +def search_context(ctx, pattern): ctx.original_pos = ctx.match_start if ctx.end < ctx.match_start: return False base = 0 charset = False - if ctx.pat(base) == OPCODE_INFO: - flags = ctx.pat(2) + if pattern.pat(base) == OPCODE_INFO: + flags = pattern.pat(2) if flags & rsre_char.SRE_INFO_PREFIX: - if ctx.pat(5) > 1: - return fast_search(ctx) + if pattern.pat(5) > 1: + return fast_search(ctx, pattern) else: charset = (flags & rsre_char.SRE_INFO_CHARSET) - base += 1 + ctx.pat(1) - if ctx.pat(base) == OPCODE_LITERAL: - return literal_search(ctx, base) + base += 1 + pattern.pat(1) + if pattern.pat(base) == OPCODE_LITERAL: + return literal_search(ctx, pattern, base) if charset: - return charset_search(ctx, base) - return regular_search(ctx, base) + return charset_search(ctx, pattern, base) + return regular_search(ctx, pattern, base) install_jitdriver('RegularSearch', - greens=['base', 'ctx.pattern'], + greens=['base', 'pattern'], reds=['start', 'ctx'], debugprint=(1, 0)) -def regular_search(ctx, base): +def regular_search(ctx, pattern, base): start = ctx.match_start while start <= ctx.end: ctx.jitdriver_RegularSearch.jit_merge_point(ctx=ctx, start=start, - base=base) - if sre_match(ctx, base, start, None) is not None: + base=base, pattern=pattern) + if sre_match(ctx, pattern, base, start, None) is not None: ctx.match_start = start return True start += 1 return False install_jitdriver_spec("LiteralSearch", - greens=['base', 'character', 'ctx.pattern'], + greens=['base', 'character', 'pattern'], reds=['start', 'ctx'], debugprint=(2, 0, 1)) @specializectx -def literal_search(ctx, base): +def literal_search(ctx, pattern, base): # pattern starts with a literal character. this is used # for short prefixes, and if fast search is disabled - character = ctx.pat(base + 1) + character = pattern.pat(base + 1) base += 2 start = ctx.match_start while start < ctx.end: ctx.jitdriver_LiteralSearch.jit_merge_point(ctx=ctx, start=start, - base=base, character=character) + base=base, character=character, pattern=pattern) if ctx.str(start) == character: - if sre_match(ctx, base, start + 1, None) is not None: + if sre_match(ctx, pattern, base, start + 1, None) is not None: ctx.match_start = start return True start += 1 return False install_jitdriver_spec("CharsetSearch", - greens=['base', 'ctx.pattern'], + greens=['base', 'pattern'], reds=['start', 'ctx'], debugprint=(1, 0)) @specializectx -def charset_search(ctx, base): +def charset_search(ctx, pattern, base): # pattern starts with a character from a known set start = ctx.match_start while start < ctx.end: ctx.jitdriver_CharsetSearch.jit_merge_point(ctx=ctx, start=start, - base=base) - if rsre_char.check_charset(ctx, 5, ctx.str(start)): - if sre_match(ctx, base, start, None) is not None: + base=base, pattern=pattern) + if rsre_char.check_charset(ctx, pattern, 5, ctx.str(start)): + if sre_match(ctx, pattern, base, start, None) is not None: ctx.match_start = start return True start += 1 return False install_jitdriver_spec('FastSearch', - greens=['i', 'prefix_len', 'ctx.pattern'], + greens=['i', 'prefix_len', 'pattern'], reds=['string_position', 'ctx'], debugprint=(2, 0)) @specializectx -def fast_search(ctx): +def fast_search(ctx, pattern): # skips forward in a string as fast as possible using information from # an optimization info block # <INFO> <1=skip> <2=flags> <3=min> <4=...> @@ -1155,17 +1170,18 @@ string_position = ctx.match_start if string_position >= ctx.end: return False _______________________________________________ pypy-commit mailing list [email protected] https://mail.python.org/mailman/listinfo/pypy-commit
