Author: Carl Friedrich Bolz-Tereick <cfb...@gmx.de> Branch: Changeset: r94164:ac140c11bea3 Date: 2018-03-28 14:54 +0200 http://bitbucket.org/pypy/pypy/changeset/ac140c11bea3/
Log: merge fix-sre-problems: - stop switching to the blackhole interpreter in random places, which leads to arbitrary misbehaviour - disable greenfields, because their interaction with virtualizables is broken diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst --- a/pypy/doc/whatsnew-head.rst +++ b/pypy/doc/whatsnew-head.rst @@ -68,3 +68,14 @@ Optimize `Py*_Check` for `Bool`, `Float`, `Set`. Also refactor and simplify `W_PyCWrapperObject` which is used to call slots from the C-API, greatly improving microbenchmarks in https://github.com/antocuni/cpyext-benchmarks + + +.. branch: fix-sre-problems + +Fix two (unrelated) JIT bugs manifesting in the re module: + +- green fields are broken and were thus disabled, plus their usage removed from + the _sre implementation + +- in rare "trace is too long" situations, the JIT could break behaviour + arbitrarily. diff --git a/pypy/module/_cffi_backend/ccallback.py b/pypy/module/_cffi_backend/ccallback.py --- a/pypy/module/_cffi_backend/ccallback.py +++ b/pypy/module/_cffi_backend/ccallback.py @@ -232,7 +232,9 @@ "different from the 'ffi.h' file seen at compile-time)") def py_invoke(self, ll_res, ll_args): + key_pycode = self.key_pycode jitdriver1.jit_merge_point(callback=self, + key_pycode=key_pycode, ll_res=ll_res, ll_args=ll_args) self.do_invoke(ll_res, ll_args) @@ -294,7 +296,7 @@ return 'cffi_callback ' + key_pycode.get_repr() jitdriver1 = jit.JitDriver(name='cffi_callback', - greens=['callback.key_pycode'], + greens=['key_pycode'], reds=['ll_res', 'll_args', 'callback'], get_printable_location=get_printable_location1) diff --git a/pypy/module/_sre/interp_sre.py b/pypy/module/_sre/interp_sre.py --- a/pypy/module/_sre/interp_sre.py +++ b/pypy/module/_sre/interp_sre.py @@ -77,15 +77,15 @@ w_import = space.getattr(w_builtin, space.newtext("__import__")) return space.call_function(w_import, space.newtext("re")) -def matchcontext(space, ctx): +def matchcontext(space, ctx, pattern): try: - return rsre_core.match_context(ctx) + return rsre_core.match_context(ctx, pattern) except rsre_core.Error as e: raise OperationError(space.w_RuntimeError, space.newtext(e.msg)) -def searchcontext(space, ctx): +def searchcontext(space, ctx, pattern): try: - return rsre_core.search_context(ctx) + return rsre_core.search_context(ctx, pattern) except rsre_core.Error as e: raise OperationError(space.w_RuntimeError, space.newtext(e.msg)) @@ -114,7 +114,7 @@ pos = len(unicodestr) if endpos > len(unicodestr): endpos = len(unicodestr) - return rsre_core.UnicodeMatchContext(self.code, unicodestr, + return rsre_core.UnicodeMatchContext(unicodestr, pos, endpos, self.flags) elif space.isinstance_w(w_string, space.w_bytes): str = space.bytes_w(w_string) @@ -122,7 +122,7 @@ pos = len(str) if endpos > len(str): endpos = len(str) - return rsre_core.StrMatchContext(self.code, str, + return rsre_core.StrMatchContext(str, pos, endpos, self.flags) else: buf = space.readbuf_w(w_string) @@ -132,7 +132,7 @@ pos = size if endpos > size: endpos = size - return rsre_core.BufMatchContext(self.code, buf, + return rsre_core.BufMatchContext(buf, pos, endpos, self.flags) def getmatch(self, ctx, found): @@ -144,12 +144,12 @@ @unwrap_spec(pos=int, endpos=int) def match_w(self, w_string, pos=0, endpos=sys.maxint): ctx = self.make_ctx(w_string, pos, endpos) - return self.getmatch(ctx, matchcontext(self.space, ctx)) + return self.getmatch(ctx, matchcontext(self.space, ctx, self.code)) @unwrap_spec(pos=int, endpos=int) def search_w(self, w_string, pos=0, endpos=sys.maxint): ctx = self.make_ctx(w_string, pos, endpos) - return self.getmatch(ctx, searchcontext(self.space, ctx)) + return self.getmatch(ctx, searchcontext(self.space, ctx, self.code)) @unwrap_spec(pos=int, endpos=int) def findall_w(self, w_string, pos=0, endpos=sys.maxint): @@ -157,7 +157,7 @@ matchlist_w = [] ctx = self.make_ctx(w_string, pos, endpos) while ctx.match_start <= ctx.end: - if not searchcontext(space, ctx): + if not searchcontext(space, ctx, self.code): break num_groups = self.num_groups w_emptystr = space.newtext("") @@ -182,7 +182,7 @@ # this also works as the implementation of the undocumented # scanner() method. ctx = self.make_ctx(w_string, pos, endpos) - scanner = W_SRE_Scanner(self, ctx) + scanner = W_SRE_Scanner(self, ctx, self.code) return scanner @unwrap_spec(maxsplit=int) @@ -193,7 +193,7 @@ last = 0 ctx = self.make_ctx(w_string) while not maxsplit or n < maxsplit: - if not searchcontext(space, ctx): + if not searchcontext(space, ctx, self.code): break if ctx.match_start == ctx.match_end: # zero-width match if ctx.match_start == ctx.end: # or end of string @@ -274,8 +274,8 @@ else: sublist_w = [] n = last_pos = 0 + pattern = self.code while not count or n < count: - pattern = ctx.pattern sub_jitdriver.jit_merge_point( self=self, use_builder=use_builder, @@ -292,7 +292,7 @@ n=n, last_pos=last_pos, sublist_w=sublist_w ) space = self.space - if not searchcontext(space, ctx): + if not searchcontext(space, ctx, pattern): break if last_pos < ctx.match_start: _sub_append_slice( @@ -388,7 +388,11 @@ srepat.space = space srepat.w_pattern = w_pattern # the original uncompiled pattern srepat.flags = flags - srepat.code = code + # note: we assume that the app-level is caching SRE_Pattern objects, + # so that we don't need to do it here. Creating new SRE_Pattern + # objects all the time would be bad for the JIT, which relies on the + # identity of the CompiledPattern() object. + srepat.code = rsre_core.CompiledPattern(code) srepat.num_groups = groups srepat.w_groupindex = w_groupindex srepat.w_indexgroup = w_indexgroup @@ -611,10 +615,11 @@ # Our version is also directly iterable, to make finditer() easier. class W_SRE_Scanner(W_Root): - def __init__(self, pattern, ctx): + def __init__(self, pattern, ctx, code): self.space = pattern.space self.srepat = pattern self.ctx = ctx + self.code = code # 'self.ctx' is always a fresh context in which no searching # or matching succeeded so far. @@ -624,19 +629,19 @@ def next_w(self): if self.ctx.match_start > self.ctx.end: raise OperationError(self.space.w_StopIteration, self.space.w_None) - if not searchcontext(self.space, self.ctx): + if not searchcontext(self.space, self.ctx, self.code): raise OperationError(self.space.w_StopIteration, self.space.w_None) return self.getmatch(True) def match_w(self): if self.ctx.match_start > self.ctx.end: return self.space.w_None - return self.getmatch(matchcontext(self.space, self.ctx)) + return self.getmatch(matchcontext(self.space, self.ctx, self.code)) def search_w(self): if self.ctx.match_start > self.ctx.end: return self.space.w_None - return self.getmatch(searchcontext(self.space, self.ctx)) + return self.getmatch(searchcontext(self.space, self.ctx, self.code)) def getmatch(self, found): if found: diff --git a/rpython/jit/metainterp/history.py b/rpython/jit/metainterp/history.py --- a/rpython/jit/metainterp/history.py +++ b/rpython/jit/metainterp/history.py @@ -701,6 +701,9 @@ def length(self): return self.trace._count - len(self.trace.inputargs) + def trace_tag_overflow(self): + return self.trace.tag_overflow + def get_trace_position(self): return self.trace.cut_point() diff --git a/rpython/jit/metainterp/opencoder.py b/rpython/jit/metainterp/opencoder.py --- a/rpython/jit/metainterp/opencoder.py +++ b/rpython/jit/metainterp/opencoder.py @@ -49,13 +49,6 @@ way up to lltype.Signed for indexes everywhere """ -def frontend_tag_overflow(): - # Minor abstraction leak: raise directly the right exception - # expected by the rest of the machinery - from rpython.jit.metainterp import history - from rpython.rlib.jit import Counters - raise history.SwitchToBlackhole(Counters.ABORT_TOO_LONG) - class BaseTrace(object): pass @@ -293,6 +286,7 @@ self._start = len(inputargs) self._pos = self._start self.inputargs = inputargs + self.tag_overflow = False def append(self, v): model = get_model(self) @@ -300,12 +294,14 @@ # grow by 2X self._ops = self._ops + [rffi.cast(model.STORAGE_TP, 0)] * len(self._ops) if not model.MIN_VALUE <= v <= model.MAX_VALUE: - raise frontend_tag_overflow() + v = 0 # broken value, but that's fine, tracing will stop soon + self.tag_overflow = True self._ops[self._pos] = rffi.cast(model.STORAGE_TP, v) self._pos += 1 - def done(self): + def tracing_done(self): from rpython.rlib.debug import debug_start, debug_stop, debug_print + assert not self.tag_overflow self._bigints_dict = {} self._refs_dict = llhelper.new_ref_dict_3() @@ -317,8 +313,6 @@ debug_print(" ref consts: " + str(self._consts_ptr) + " " + str(len(self._refs))) debug_print(" descrs: " + str(len(self._descrs))) debug_stop("jit-trace-done") - return 0 # completely different than TraceIter.done, but we have to - # share the base class def length(self): return self._pos @@ -379,6 +373,7 @@ def record_op(self, opnum, argboxes, descr=None): pos = self._index + old_pos = self._pos self.append(opnum) expected_arity = oparity[opnum] if expected_arity == -1: @@ -397,6 +392,10 @@ self._count += 1 if opclasses[opnum].type != 'v': self._index += 1 + if self.tag_overflow: + # potentially a broken op is left behind + # clean it up + self._pos = old_pos return pos def _encode_descr(self, descr): @@ -424,10 +423,11 @@ vref_array = self._list_of_boxes(vref_boxes) s = TopSnapshot(combine_uint(jitcode.index, pc), array, vable_array, vref_array) - assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 # guards have no descr self._snapshots.append(s) - self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) + if not self.tag_overflow: # otherwise we're broken anyway + assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 + self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) return s def create_empty_top_snapshot(self, vable_boxes, vref_boxes): @@ -436,10 +436,11 @@ vref_array = self._list_of_boxes(vref_boxes) s = TopSnapshot(combine_uint(2**16 - 1, 0), [], vable_array, vref_array) - assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 # guards have no descr self._snapshots.append(s) - self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) + if not self.tag_overflow: # otherwise we're broken anyway + assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0 + self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1) return s def create_snapshot(self, jitcode, pc, frame, flag): diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py --- a/rpython/jit/metainterp/pyjitpl.py +++ b/rpython/jit/metainterp/pyjitpl.py @@ -2384,9 +2384,9 @@ def blackhole_if_trace_too_long(self): warmrunnerstate = self.jitdriver_sd.warmstate - if self.history.length() > warmrunnerstate.trace_limit: + if (self.history.length() > warmrunnerstate.trace_limit or + self.history.trace_tag_overflow()): jd_sd, greenkey_of_huge_function = self.find_biggest_function() - self.history.trace.done() self.staticdata.stats.record_aborted(greenkey_of_huge_function) self.portal_trace_positions = None if greenkey_of_huge_function is not None: @@ -2689,7 +2689,9 @@ try_disabling_unroll=False, exported_state=None): num_green_args = self.jitdriver_sd.num_green_args greenkey = original_boxes[:num_green_args] - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() if not self.partial_trace: ptoken = self.get_procedure_token(greenkey) if ptoken is not None and ptoken.target_tokens is not None: @@ -2742,7 +2744,9 @@ self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None, descr=target_jitcell_token) self.history.ends_with_jump = True - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() try: target_token = compile.compile_trace(self, self.resumekey, live_arg_boxes[num_green_args:]) @@ -2776,7 +2780,9 @@ assert False # FIXME: can we call compile_trace? self.history.record(rop.FINISH, exits, None, descr=token) - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() target_token = compile.compile_trace(self, self.resumekey, exits) if target_token is not token: compile.giveup() @@ -2802,7 +2808,9 @@ sd = self.staticdata token = sd.exit_frame_with_exception_descr_ref self.history.record(rop.FINISH, [valuebox], None, descr=token) - self.history.trace.done() + if self.history.trace_tag_overflow(): + raise SwitchToBlackhole(Counters.ABORT_TOO_LONG) + self.history.trace.tracing_done() target_token = compile.compile_trace(self, self.resumekey, [valuebox]) if target_token is not token: compile.giveup() diff --git a/rpython/jit/metainterp/test/test_ajit.py b/rpython/jit/metainterp/test/test_ajit.py --- a/rpython/jit/metainterp/test/test_ajit.py +++ b/rpython/jit/metainterp/test/test_ajit.py @@ -4661,3 +4661,36 @@ f() # finishes self.meta_interp(f, []) + + def test_trace_too_long_bug(self): + driver = JitDriver(greens=[], reds=['i']) + @unroll_safe + def match(s): + l = len(s) + p = 0 + for i in range(2500): # produces too long trace + c = s[p] + if c != 'a': + return False + p += 1 + if p >= l: + return True + c = s[p] + if c != '\n': + p += 1 + if p >= l: + return True + else: + return False + return True + + def f(i): + while i > 0: + driver.jit_merge_point(i=i) + match('a' * (500 * i)) + i -= 1 + return i + + res = self.meta_interp(f, [10]) + assert res == f(10) + diff --git a/rpython/jit/metainterp/test/test_greenfield.py b/rpython/jit/metainterp/test/test_greenfield.py --- a/rpython/jit/metainterp/test/test_greenfield.py +++ b/rpython/jit/metainterp/test/test_greenfield.py @@ -1,6 +1,17 @@ +import pytest from rpython.jit.metainterp.test.support import LLJitMixin from rpython.rlib.jit import JitDriver, assert_green +pytest.skip("this feature is disabled at the moment!") + +# note why it is disabled: before d721da4573ad +# there was a failing assert when inlining python -> sre -> python: +# https://bitbucket.org/pypy/pypy/issues/2775/ +# this shows, that the interaction of greenfields and virtualizables is broken, +# because greenfields use MetaInterp.virtualizable_boxes, which confuses +# MetaInterp._nonstandard_virtualizable somehow (and makes no sense +# conceptually anyway). to fix greenfields, the two mechanisms would have to be +# disentangled. class GreenFieldsTests: diff --git a/rpython/jit/metainterp/test/test_opencoder.py b/rpython/jit/metainterp/test/test_opencoder.py --- a/rpython/jit/metainterp/test/test_opencoder.py +++ b/rpython/jit/metainterp/test/test_opencoder.py @@ -209,5 +209,8 @@ def test_tag_overflow(self): t = Trace([], metainterp_sd) i0 = FakeOp(100000) - py.test.raises(SwitchToBlackhole, t.record_op, rop.FINISH, [i0]) - assert t.unpack() == ([], []) + # if we overflow, we can keep recording + for i in range(10): + t.record_op(rop.FINISH, [i0]) + assert t.unpack() == ([], []) + assert t.tag_overflow diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py --- a/rpython/rlib/jit.py +++ b/rpython/rlib/jit.py @@ -653,6 +653,9 @@ self._make_extregistryentries() assert get_jitcell_at is None, "get_jitcell_at no longer used" assert set_jitcell_at is None, "set_jitcell_at no longer used" + for green in self.greens: + if "." in green: + raise ValueError("green fields are buggy! if you need them fixed, please talk to us") self.get_printable_location = get_printable_location self.get_location = get_location self.has_unique_id = (get_unique_id is not None) diff --git a/rpython/rlib/rsre/rpy/_sre.py b/rpython/rlib/rsre/rpy/_sre.py --- a/rpython/rlib/rsre/rpy/_sre.py +++ b/rpython/rlib/rsre/rpy/_sre.py @@ -1,4 +1,4 @@ -from rpython.rlib.rsre import rsre_char +from rpython.rlib.rsre import rsre_char, rsre_core from rpython.rlib.rarithmetic import intmask VERSION = "2.7.6" @@ -12,7 +12,7 @@ pass def compile(pattern, flags, code, *args): - raise GotIt([intmask(i) for i in code], flags, args) + raise GotIt(rsre_core.CompiledPattern([intmask(i) for i in code]), flags, args) def get_code(regexp, flags=0, allargs=False): diff --git a/rpython/rlib/rsre/rsre_char.py b/rpython/rlib/rsre/rsre_char.py --- a/rpython/rlib/rsre/rsre_char.py +++ b/rpython/rlib/rsre/rsre_char.py @@ -152,17 +152,16 @@ ##### Charset evaluation @jit.unroll_safe -def check_charset(ctx, ppos, char_code): +def check_charset(ctx, pattern, ppos, char_code): """Checks whether a character matches set of arbitrary length. The set starts at pattern[ppos].""" negated = False result = False - pattern = ctx.pattern while True: - opcode = pattern[ppos] + opcode = pattern.pattern[ppos] for i, function in set_dispatch_unroll: if opcode == i: - newresult, ppos = function(ctx, ppos, char_code) + newresult, ppos = function(ctx, pattern, ppos, char_code) result |= newresult break else: @@ -177,50 +176,44 @@ return not result return result -def set_literal(ctx, index, char_code): +def set_literal(ctx, pattern, index, char_code): # <LITERAL> <code> - pat = ctx.pattern - match = pat[index+1] == char_code + match = pattern.pattern[index+1] == char_code return match, index + 2 -def set_category(ctx, index, char_code): +def set_category(ctx, pattern, index, char_code): # <CATEGORY> <code> - pat = ctx.pattern - match = category_dispatch(pat[index+1], char_code) + match = category_dispatch(pattern.pattern[index+1], char_code) return match, index + 2 -def set_charset(ctx, index, char_code): +def set_charset(ctx, pattern, index, char_code): # <CHARSET> <bitmap> (16 bits per code word) - pat = ctx.pattern if CODESIZE == 2: match = char_code < 256 and \ - (pat[index+1+(char_code >> 4)] & (1 << (char_code & 15))) + (pattern.pattern[index+1+(char_code >> 4)] & (1 << (char_code & 15))) return match, index + 17 # skip bitmap else: match = char_code < 256 and \ - (pat[index+1+(char_code >> 5)] & (1 << (char_code & 31))) + (pattern.pattern[index+1+(char_code >> 5)] & (1 << (char_code & 31))) return match, index + 9 # skip bitmap -def set_range(ctx, index, char_code): +def set_range(ctx, pattern, index, char_code): # <RANGE> <lower> <upper> - pat = ctx.pattern - match = int_between(pat[index+1], char_code, pat[index+2] + 1) + match = int_between(pattern.pattern[index+1], char_code, pattern.pattern[index+2] + 1) return match, index + 3 -def set_range_ignore(ctx, index, char_code): +def set_range_ignore(ctx, pattern, index, char_code): # <RANGE_IGNORE> <lower> <upper> # the char_code is already lower cased - pat = ctx.pattern - lower = pat[index + 1] - upper = pat[index + 2] + lower = pattern.pattern[index + 1] + upper = pattern.pattern[index + 2] match1 = int_between(lower, char_code, upper + 1) match2 = int_between(lower, getupper(char_code, ctx.flags), upper + 1) return match1 | match2, index + 3 -def set_bigcharset(ctx, index, char_code): +def set_bigcharset(ctx, pattern, index, char_code): # <BIGCHARSET> <blockcount> <256 blockindices> <blocks> - pat = ctx.pattern - count = pat[index+1] + count = pattern.pattern[index+1] index += 2 if CODESIZE == 2: @@ -238,7 +231,7 @@ return False, index shift = 5 - block = pat[index + (char_code >> (shift + 5))] + block = pattern.pattern[index + (char_code >> (shift + 5))] block_shift = char_code >> 5 if BIG_ENDIAN: @@ -247,23 +240,22 @@ block = (block >> block_shift) & 0xFF index += 256 / CODESIZE - block_value = pat[index+(block * (32 / CODESIZE) + block_value = pattern.pattern[index+(block * (32 / CODESIZE) + ((char_code & 255) >> shift))] match = (block_value & (1 << (char_code & ((8 * CODESIZE) - 1)))) index += count * (32 / CODESIZE) # skip blocks return match, index -def set_unicode_general_category(ctx, index, char_code): +def set_unicode_general_category(ctx, pattern, index, char_code): # Unicode "General category property code" (not used by Python). - # A general category is two letters. 'pat[index+1]' contains both + # A general category is two letters. 'pattern.pattern[index+1]' contains both # the first character, and the second character shifted by 8. # http://en.wikipedia.org/wiki/Unicode_character_property#General_Category # Also supports single-character categories, if the second character is 0. # Negative matches are triggered by bit number 7. assert unicodedb is not None cat = unicodedb.category(char_code) - pat = ctx.pattern - category_code = pat[index + 1] + category_code = pattern.pattern[index + 1] first_character = category_code & 0x7F second_character = (category_code >> 8) & 0x7F negative_match = category_code & 0x80 diff --git a/rpython/rlib/rsre/rsre_core.py b/rpython/rlib/rsre/rsre_core.py --- a/rpython/rlib/rsre/rsre_core.py +++ b/rpython/rlib/rsre/rsre_core.py @@ -83,35 +83,19 @@ def __init__(self, msg): self.msg = msg -class AbstractMatchContext(object): - """Abstract base class""" - _immutable_fields_ = ['pattern[*]', 'flags', 'end'] - match_start = 0 - match_end = 0 - match_marks = None - match_marks_flat = None - fullmatch_only = False - def __init__(self, pattern, match_start, end, flags): - # 'match_start' and 'end' must be known to be non-negative - # and they must not be more than len(string). - check_nonneg(match_start) - check_nonneg(end) +class CompiledPattern(object): + _immutable_fields_ = ['pattern[*]'] + + def __init__(self, pattern): self.pattern = pattern - self.match_start = match_start - self.end = end - self.flags = flags # check we don't get the old value of MAXREPEAT # during the untranslated tests if not we_are_translated(): assert 65535 not in pattern - def reset(self, start): - self.match_start = start - self.match_marks = None - self.match_marks_flat = None - def pat(self, index): + jit.promote(self) check_nonneg(index) result = self.pattern[index] # Check that we only return non-negative integers from this helper. @@ -121,6 +105,29 @@ assert result >= 0 return result +class AbstractMatchContext(object): + """Abstract base class""" + _immutable_fields_ = ['flags', 'end'] + match_start = 0 + match_end = 0 + match_marks = None + match_marks_flat = None + fullmatch_only = False + + def __init__(self, match_start, end, flags): + # 'match_start' and 'end' must be known to be non-negative + # and they must not be more than len(string). + check_nonneg(match_start) + check_nonneg(end) + self.match_start = match_start + self.end = end + self.flags = flags + + def reset(self, start): + self.match_start = start + self.match_marks = None + self.match_marks_flat = None + @not_rpython def str(self, index): """Must be overridden in a concrete subclass. @@ -183,8 +190,8 @@ _immutable_fields_ = ["_buffer"] - def __init__(self, pattern, buf, match_start, end, flags): - AbstractMatchContext.__init__(self, pattern, match_start, end, flags) + def __init__(self, buf, match_start, end, flags): + AbstractMatchContext.__init__(self, match_start, end, flags) self._buffer = buf def str(self, index): @@ -196,7 +203,7 @@ return rsre_char.getlower(c, self.flags) def fresh_copy(self, start): - return BufMatchContext(self.pattern, self._buffer, start, + return BufMatchContext(self._buffer, start, self.end, self.flags) class StrMatchContext(AbstractMatchContext): @@ -204,8 +211,8 @@ _immutable_fields_ = ["_string"] - def __init__(self, pattern, string, match_start, end, flags): - AbstractMatchContext.__init__(self, pattern, match_start, end, flags) + def __init__(self, string, match_start, end, flags): + AbstractMatchContext.__init__(self, match_start, end, flags) self._string = string if not we_are_translated() and isinstance(string, unicode): self.flags |= rsre_char.SRE_FLAG_UNICODE # for rsre_re.py @@ -219,7 +226,7 @@ return rsre_char.getlower(c, self.flags) def fresh_copy(self, start): - return StrMatchContext(self.pattern, self._string, start, + return StrMatchContext(self._string, start, self.end, self.flags) class UnicodeMatchContext(AbstractMatchContext): @@ -227,8 +234,8 @@ _immutable_fields_ = ["_unicodestr"] - def __init__(self, pattern, unicodestr, match_start, end, flags): - AbstractMatchContext.__init__(self, pattern, match_start, end, flags) + def __init__(self, unicodestr, match_start, end, flags): + AbstractMatchContext.__init__(self, match_start, end, flags) self._unicodestr = unicodestr def str(self, index): @@ -240,7 +247,7 @@ return rsre_char.getlower(c, self.flags) def fresh_copy(self, start): - return UnicodeMatchContext(self.pattern, self._unicodestr, start, + return UnicodeMatchContext(self._unicodestr, start, self.end, self.flags) # ____________________________________________________________ @@ -265,16 +272,16 @@ class MatchResult(object): subresult = None - def move_to_next_result(self, ctx): + def move_to_next_result(self, ctx, pattern): # returns either 'self' or None result = self.subresult if result is None: return - if result.move_to_next_result(ctx): + if result.move_to_next_result(ctx, pattern): return self - return self.find_next_result(ctx) + return self.find_next_result(ctx, pattern) - def find_next_result(self, ctx): + def find_next_result(self, ctx, pattern): raise NotImplementedError MATCHED_OK = MatchResult() @@ -287,11 +294,11 @@ self.start_marks = marks @jit.unroll_safe - def find_first_result(self, ctx): + def find_first_result(self, ctx, pattern): ppos = jit.hint(self.ppos, promote=True) - while ctx.pat(ppos): - result = sre_match(ctx, ppos + 1, self.start_ptr, self.start_marks) - ppos += ctx.pat(ppos) + while pattern.pat(ppos): + result = sre_match(ctx, pattern, ppos + 1, self.start_ptr, self.start_marks) + ppos += pattern.pat(ppos) if result is not None: self.subresult = result self.ppos = ppos @@ -300,7 +307,7 @@ class RepeatOneMatchResult(MatchResult): install_jitdriver('RepeatOne', - greens=['nextppos', 'ctx.pattern'], + greens=['nextppos', 'pattern'], reds=['ptr', 'self', 'ctx'], debugprint=(1, 0)) # indices in 'greens' @@ -310,13 +317,14 @@ self.start_ptr = ptr self.start_marks = marks - def find_first_result(self, ctx): + def find_first_result(self, ctx, pattern): ptr = self.start_ptr nextppos = self.nextppos while ptr >= self.minptr: ctx.jitdriver_RepeatOne.jit_merge_point( - self=self, ptr=ptr, ctx=ctx, nextppos=nextppos) - result = sre_match(ctx, nextppos, ptr, self.start_marks) + self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, + pattern=pattern) + result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks) ptr -= 1 if result is not None: self.subresult = result @@ -327,7 +335,7 @@ class MinRepeatOneMatchResult(MatchResult): install_jitdriver('MinRepeatOne', - greens=['nextppos', 'ppos3', 'ctx.pattern'], + greens=['nextppos', 'ppos3', 'pattern'], reds=['ptr', 'self', 'ctx'], debugprint=(2, 0)) # indices in 'greens' @@ -338,39 +346,40 @@ self.start_ptr = ptr self.start_marks = marks - def find_first_result(self, ctx): + def find_first_result(self, ctx, pattern): ptr = self.start_ptr nextppos = self.nextppos ppos3 = self.ppos3 while ptr <= self.maxptr: ctx.jitdriver_MinRepeatOne.jit_merge_point( - self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3) - result = sre_match(ctx, nextppos, ptr, self.start_marks) + self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3, + pattern=pattern) + result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks) if result is not None: self.subresult = result self.start_ptr = ptr return self - if not self.next_char_ok(ctx, ptr, ppos3): + if not self.next_char_ok(ctx, pattern, ptr, ppos3): break ptr += 1 - def find_next_result(self, ctx): + def find_next_result(self, ctx, pattern): ptr = self.start_ptr - if not self.next_char_ok(ctx, ptr, self.ppos3): + if not self.next_char_ok(ctx, pattern, ptr, self.ppos3): return self.start_ptr = ptr + 1 - return self.find_first_result(ctx) + return self.find_first_result(ctx, pattern) - def next_char_ok(self, ctx, ptr, ppos): + def next_char_ok(self, ctx, pattern, ptr, ppos): if ptr == ctx.end: return False - op = ctx.pat(ppos) + op = pattern.pat(ppos) for op1, checkerfn in unroll_char_checker: if op1 == op: - return checkerfn(ctx, ptr, ppos) + return checkerfn(ctx, pattern, ptr, ppos) # obscure case: it should be a single char pattern, but isn't # one of the opcodes in unroll_char_checker (see test_ext_opcode) - return sre_match(ctx, ppos, ptr, self.start_marks) is not None + return sre_match(ctx, pattern, ppos, ptr, self.start_marks) is not None class AbstractUntilMatchResult(MatchResult): @@ -391,17 +400,17 @@ class MaxUntilMatchResult(AbstractUntilMatchResult): install_jitdriver('MaxUntil', - greens=['ppos', 'tailppos', 'match_more', 'ctx.pattern'], + greens=['ppos', 'tailppos', 'match_more', 'pattern'], reds=['ptr', 'marks', 'self', 'ctx'], debugprint=(3, 0, 2)) - def find_first_result(self, ctx): - return self.search_next(ctx, match_more=True) + def find_first_result(self, ctx, pattern): + return self.search_next(ctx, pattern, match_more=True) - def find_next_result(self, ctx): - return self.search_next(ctx, match_more=False) + def find_next_result(self, ctx, pattern): + return self.search_next(ctx, pattern, match_more=False) - def search_next(self, ctx, match_more): + def search_next(self, ctx, pattern, match_more): ppos = self.ppos tailppos = self.tailppos ptr = self.cur_ptr @@ -409,12 +418,13 @@ while True: ctx.jitdriver_MaxUntil.jit_merge_point( ppos=ppos, tailppos=tailppos, match_more=match_more, - ptr=ptr, marks=marks, self=self, ctx=ctx) + ptr=ptr, marks=marks, self=self, ctx=ctx, + pattern=pattern) if match_more: - max = ctx.pat(ppos+2) + max = pattern.pat(ppos+2) if max == rsre_char.MAXREPEAT or self.num_pending < max: # try to match one more 'item' - enum = sre_match(ctx, ppos + 3, ptr, marks) + enum = sre_match(ctx, pattern, ppos + 3, ptr, marks) else: enum = None # 'max' reached, no more matches else: @@ -425,9 +435,9 @@ self.num_pending -= 1 ptr = p.ptr marks = p.marks - enum = p.enum.move_to_next_result(ctx) + enum = p.enum.move_to_next_result(ctx, pattern) # - min = ctx.pat(ppos+1) + min = pattern.pat(ppos+1) if enum is not None: # matched one more 'item'. record it and continue. last_match_length = ctx.match_end - ptr @@ -447,7 +457,7 @@ # 'item' no longer matches. if self.num_pending >= min: # try to match 'tail' if we have enough 'item' - result = sre_match(ctx, tailppos, ptr, marks) + result = sre_match(ctx, pattern, tailppos, ptr, marks) if result is not None: self.subresult = result self.cur_ptr = ptr @@ -457,23 +467,23 @@ class MinUntilMatchResult(AbstractUntilMatchResult): - def find_first_result(self, ctx): - return self.search_next(ctx, resume=False) + def find_first_result(self, ctx, pattern): + return self.search_next(ctx, pattern, resume=False) - def find_next_result(self, ctx): - return self.search_next(ctx, resume=True) + def find_next_result(self, ctx, pattern): + return self.search_next(ctx, pattern, resume=True) - def search_next(self, ctx, resume): + def search_next(self, ctx, pattern, resume): # XXX missing jit support here ppos = self.ppos - min = ctx.pat(ppos+1) - max = ctx.pat(ppos+2) + min = pattern.pat(ppos+1) + max = pattern.pat(ppos+2) ptr = self.cur_ptr marks = self.cur_marks while True: # try to match 'tail' if we have enough 'item' if not resume and self.num_pending >= min: - result = sre_match(ctx, self.tailppos, ptr, marks) + result = sre_match(ctx, pattern, self.tailppos, ptr, marks) if result is not None: self.subresult = result self.cur_ptr = ptr @@ -483,12 +493,12 @@ if max == rsre_char.MAXREPEAT or self.num_pending < max: # try to match one more 'item' - enum = sre_match(ctx, ppos + 3, ptr, marks) + enum = sre_match(ctx, pattern, ppos + 3, ptr, marks) # # zero-width match protection if self.num_pending >= min: while enum is not None and ptr == ctx.match_end: - enum = enum.move_to_next_result(ctx) + enum = enum.move_to_next_result(ctx, pattern) else: enum = None # 'max' reached, no more matches @@ -502,7 +512,7 @@ self.num_pending -= 1 ptr = p.ptr marks = p.marks - enum = p.enum.move_to_next_result(ctx) + enum = p.enum.move_to_next_result(ctx, pattern) # matched one more 'item'. record it and continue self.pending = Pending(ptr, marks, enum, self.pending) @@ -514,13 +524,13 @@ @specializectx @jit.unroll_safe -def sre_match(ctx, ppos, ptr, marks): +def sre_match(ctx, pattern, ppos, ptr, marks): """Returns either None or a MatchResult object. Usually we only need the first result, but there is the case of REPEAT...UNTIL where we need all results; in that case we use the method move_to_next_result() of the MatchResult.""" while True: - op = ctx.pat(ppos) + op = pattern.pat(ppos) ppos += 1 #jit.jit_debug("sre_match", op, ppos, ptr) @@ -563,33 +573,33 @@ elif op == OPCODE_ASSERT: # assert subpattern # <ASSERT> <0=skip> <1=back> <pattern> - ptr1 = ptr - ctx.pat(ppos+1) + ptr1 = ptr - pattern.pat(ppos+1) saved = ctx.fullmatch_only ctx.fullmatch_only = False - stop = ptr1 < 0 or sre_match(ctx, ppos + 2, ptr1, marks) is None + stop = ptr1 < 0 or sre_match(ctx, pattern, ppos + 2, ptr1, marks) is None ctx.fullmatch_only = saved if stop: return marks = ctx.match_marks - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_ASSERT_NOT: # assert not subpattern # <ASSERT_NOT> <0=skip> <1=back> <pattern> - ptr1 = ptr - ctx.pat(ppos+1) + ptr1 = ptr - pattern.pat(ppos+1) saved = ctx.fullmatch_only ctx.fullmatch_only = False - stop = (ptr1 >= 0 and sre_match(ctx, ppos + 2, ptr1, marks) + stop = (ptr1 >= 0 and sre_match(ctx, pattern, ppos + 2, ptr1, marks) is not None) ctx.fullmatch_only = saved if stop: return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_AT: # match at given position (e.g. at beginning, at boundary, etc.) # <AT> <code> - if not sre_at(ctx, ctx.pat(ppos), ptr): + if not sre_at(ctx, pattern.pat(ppos), ptr): return ppos += 1 @@ -597,14 +607,14 @@ # alternation # <BRANCH> <0=skip> code <JUMP> ... <NULL> result = BranchMatchResult(ppos, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) elif op == OPCODE_CATEGORY: # seems to be never produced, but used by some tests from # pypy/module/_sre/test # <CATEGORY> <category> if (ptr == ctx.end or - not rsre_char.category_dispatch(ctx.pat(ppos), ctx.str(ptr))): + not rsre_char.category_dispatch(pattern.pat(ppos), ctx.str(ptr))): return ptr += 1 ppos += 1 @@ -612,7 +622,7 @@ elif op == OPCODE_GROUPREF: # match backreference # <GROUPREF> <groupnum> - startptr, length = get_group_ref(marks, ctx.pat(ppos)) + startptr, length = get_group_ref(marks, pattern.pat(ppos)) if length < 0: return # group was not previously defined if not match_repeated(ctx, ptr, startptr, length): @@ -623,7 +633,7 @@ elif op == OPCODE_GROUPREF_IGNORE: # match backreference # <GROUPREF> <groupnum> - startptr, length = get_group_ref(marks, ctx.pat(ppos)) + startptr, length = get_group_ref(marks, pattern.pat(ppos)) if length < 0: return # group was not previously defined if not match_repeated_ignore(ctx, ptr, startptr, length): @@ -634,44 +644,44 @@ elif op == OPCODE_GROUPREF_EXISTS: # conditional match depending on the existence of a group # <GROUPREF_EXISTS> <group> <skip> codeyes <JUMP> codeno ... - _, length = get_group_ref(marks, ctx.pat(ppos)) + _, length = get_group_ref(marks, pattern.pat(ppos)) if length >= 0: ppos += 2 # jump to 'codeyes' else: - ppos += ctx.pat(ppos+1) # jump to 'codeno' + ppos += pattern.pat(ppos+1) # jump to 'codeno' elif op == OPCODE_IN: # match set member (or non_member) # <IN> <skip> <set> - if ptr >= ctx.end or not rsre_char.check_charset(ctx, ppos+1, + if ptr >= ctx.end or not rsre_char.check_charset(ctx, pattern, ppos+1, ctx.str(ptr)): return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) ptr += 1 elif op == OPCODE_IN_IGNORE: # match set member (or non_member), ignoring case # <IN> <skip> <set> - if ptr >= ctx.end or not rsre_char.check_charset(ctx, ppos+1, + if ptr >= ctx.end or not rsre_char.check_charset(ctx, pattern, ppos+1, ctx.lowstr(ptr)): return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) ptr += 1 elif op == OPCODE_INFO: # optimization info block # <INFO> <0=skip> <1=flags> <2=min> ... - if (ctx.end - ptr) < ctx.pat(ppos+2): + if (ctx.end - ptr) < pattern.pat(ppos+2): return - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_JUMP: - ppos += ctx.pat(ppos) + ppos += pattern.pat(ppos) elif op == OPCODE_LITERAL: # match literal string # <LITERAL> <code> - if ptr >= ctx.end or ctx.str(ptr) != ctx.pat(ppos): + if ptr >= ctx.end or ctx.str(ptr) != pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -679,7 +689,7 @@ elif op == OPCODE_LITERAL_IGNORE: # match literal string, ignoring case # <LITERAL_IGNORE> <code> - if ptr >= ctx.end or ctx.lowstr(ptr) != ctx.pat(ppos): + if ptr >= ctx.end or ctx.lowstr(ptr) != pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -687,14 +697,14 @@ elif op == OPCODE_MARK: # set mark # <MARK> <gid> - gid = ctx.pat(ppos) + gid = pattern.pat(ppos) marks = Mark(gid, ptr, marks) ppos += 1 elif op == OPCODE_NOT_LITERAL: # match if it's not a literal string # <NOT_LITERAL> <code> - if ptr >= ctx.end or ctx.str(ptr) == ctx.pat(ppos): + if ptr >= ctx.end or ctx.str(ptr) == pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -702,7 +712,7 @@ elif op == OPCODE_NOT_LITERAL_IGNORE: # match if it's not a literal string, ignoring case # <NOT_LITERAL> <code> - if ptr >= ctx.end or ctx.lowstr(ptr) == ctx.pat(ppos): + if ptr >= ctx.end or ctx.lowstr(ptr) == pattern.pat(ppos): return ppos += 1 ptr += 1 @@ -715,22 +725,22 @@ # decode the later UNTIL operator to see if it is actually # a MAX_UNTIL or MIN_UNTIL - untilppos = ppos + ctx.pat(ppos) + untilppos = ppos + pattern.pat(ppos) tailppos = untilppos + 1 - op = ctx.pat(untilppos) + op = pattern.pat(untilppos) if op == OPCODE_MAX_UNTIL: # the hard case: we have to match as many repetitions as # possible, followed by the 'tail'. we do this by # remembering each state for each possible number of # 'item' matching. result = MaxUntilMatchResult(ppos, tailppos, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) elif op == OPCODE_MIN_UNTIL: # first try to match the 'tail', and if it fails, try # to match one more 'item' and try again result = MinUntilMatchResult(ppos, tailppos, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) else: raise Error("missing UNTIL after REPEAT") @@ -743,17 +753,18 @@ # use the MAX_REPEAT operator. # <REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail start = ptr - minptr = start + ctx.pat(ppos+1) + minptr = start + pattern.pat(ppos+1) if minptr > ctx.end: return # cannot match - ptr = find_repetition_end(ctx, ppos+3, start, ctx.pat(ppos+2), + ptr = find_repetition_end(ctx, pattern, ppos+3, start, + pattern.pat(ppos+2), marks) # when we arrive here, ptr points to the tail of the target # string. check if the rest of the pattern matches, # and backtrack if not. - nextppos = ppos + ctx.pat(ppos) + nextppos = ppos + pattern.pat(ppos) result = RepeatOneMatchResult(nextppos, minptr, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) elif op == OPCODE_MIN_REPEAT_ONE: # match repeated sequence (minimizing regexp). @@ -763,26 +774,26 @@ # use the MIN_REPEAT operator. # <MIN_REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail start = ptr - min = ctx.pat(ppos+1) + min = pattern.pat(ppos+1) if min > 0: minptr = ptr + min if minptr > ctx.end: return # cannot match # count using pattern min as the maximum - ptr = find_repetition_end(ctx, ppos+3, ptr, min, marks) + ptr = find_repetition_end(ctx, pattern, ppos+3, ptr, min, marks) if ptr < minptr: return # did not match minimum number of times maxptr = ctx.end - max = ctx.pat(ppos+2) + max = pattern.pat(ppos+2) if max != rsre_char.MAXREPEAT: maxptr1 = start + max if maxptr1 <= maxptr: maxptr = maxptr1 - nextppos = ppos + ctx.pat(ppos) + nextppos = ppos + pattern.pat(ppos) result = MinRepeatOneMatchResult(nextppos, ppos+3, maxptr, ptr, marks) - return result.find_first_result(ctx) + return result.find_first_result(ctx, pattern) else: raise Error("bad pattern code %d" % op) @@ -816,7 +827,7 @@ return True @specializectx -def find_repetition_end(ctx, ppos, ptr, maxcount, marks): +def find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks): end = ctx.end ptrp1 = ptr + 1 # First get rid of the cases where we don't have room for any match. @@ -826,16 +837,16 @@ # The idea is to be fast for cases like re.search("b+"), where we expect # the common case to be a non-match. It's much faster with the JIT to # have the non-match inlined here rather than detect it in the fre() call. - op = ctx.pat(ppos) + op = pattern.pat(ppos) for op1, checkerfn in unroll_char_checker: if op1 == op: - if checkerfn(ctx, ptr, ppos): + if checkerfn(ctx, pattern, ptr, ppos): break return ptr else: # obscure case: it should be a single char pattern, but isn't # one of the opcodes in unroll_char_checker (see test_ext_opcode) - return general_find_repetition_end(ctx, ppos, ptr, maxcount, marks) + return general_find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks) # It matches at least once. If maxcount == 1 (relatively common), # then we are done. if maxcount == 1: @@ -846,14 +857,14 @@ end1 = ptr + maxcount if end1 <= end: end = end1 - op = ctx.pat(ppos) + op = pattern.pat(ppos) for op1, fre in unroll_fre_checker: if op1 == op: - return fre(ctx, ptrp1, end, ppos) + return fre(ctx, pattern, ptrp1, end, ppos) raise Error("rsre.find_repetition_end[%d]" % op) @specializectx -def general_find_repetition_end(ctx, ppos, ptr, maxcount, marks): +def general_find_repetition_end(ctx, patern, ppos, ptr, maxcount, marks): # moved into its own JIT-opaque function end = ctx.end if maxcount != rsre_char.MAXREPEAT: @@ -861,63 +872,65 @@ end1 = ptr + maxcount if end1 <= end: end = end1 - while ptr < end and sre_match(ctx, ppos, ptr, marks) is not None: + while ptr < end and sre_match(ctx, patern, ppos, ptr, marks) is not None: ptr += 1 return ptr @specializectx -def match_ANY(ctx, ptr, ppos): # dot wildcard. +def match_ANY(ctx, pattern, ptr, ppos): # dot wildcard. return not rsre_char.is_linebreak(ctx.str(ptr)) -def match_ANY_ALL(ctx, ptr, ppos): +def match_ANY_ALL(ctx, pattern, ptr, ppos): return True # match anything (including a newline) @specializectx -def match_IN(ctx, ptr, ppos): - return rsre_char.check_charset(ctx, ppos+2, ctx.str(ptr)) +def match_IN(ctx, pattern, ptr, ppos): + return rsre_char.check_charset(ctx, pattern, ppos+2, ctx.str(ptr)) @specializectx -def match_IN_IGNORE(ctx, ptr, ppos): - return rsre_char.check_charset(ctx, ppos+2, ctx.lowstr(ptr)) +def match_IN_IGNORE(ctx, pattern, ptr, ppos): + return rsre_char.check_charset(ctx, pattern, ppos+2, ctx.lowstr(ptr)) @specializectx -def match_LITERAL(ctx, ptr, ppos): - return ctx.str(ptr) == ctx.pat(ppos+1) +def match_LITERAL(ctx, pattern, ptr, ppos): + return ctx.str(ptr) == pattern.pat(ppos+1) @specializectx -def match_LITERAL_IGNORE(ctx, ptr, ppos): - return ctx.lowstr(ptr) == ctx.pat(ppos+1) +def match_LITERAL_IGNORE(ctx, pattern, ptr, ppos): + return ctx.lowstr(ptr) == pattern.pat(ppos+1) @specializectx -def match_NOT_LITERAL(ctx, ptr, ppos): - return ctx.str(ptr) != ctx.pat(ppos+1) +def match_NOT_LITERAL(ctx, pattern, ptr, ppos): + return ctx.str(ptr) != pattern.pat(ppos+1) @specializectx -def match_NOT_LITERAL_IGNORE(ctx, ptr, ppos): - return ctx.lowstr(ptr) != ctx.pat(ppos+1) +def match_NOT_LITERAL_IGNORE(ctx, pattern, ptr, ppos): + return ctx.lowstr(ptr) != pattern.pat(ppos+1) def _make_fre(checkerfn): if checkerfn == match_ANY_ALL: - def fre(ctx, ptr, end, ppos): + def fre(ctx, pattern, ptr, end, ppos): return end elif checkerfn == match_IN: install_jitdriver_spec('MatchIn', - greens=['ppos', 'ctx.pattern'], + greens=['ppos', 'pattern'], reds=['ptr', 'end', 'ctx'], debugprint=(1, 0)) @specializectx - def fre(ctx, ptr, end, ppos): + def fre(ctx, pattern, ptr, end, ppos): while True: ctx.jitdriver_MatchIn.jit_merge_point(ctx=ctx, ptr=ptr, - end=end, ppos=ppos) - if ptr < end and checkerfn(ctx, ptr, ppos): + end=end, ppos=ppos, + pattern=pattern) + if ptr < end and checkerfn(ctx, pattern, ptr, ppos): ptr += 1 else: return ptr elif checkerfn == match_IN_IGNORE: install_jitdriver_spec('MatchInIgnore', - greens=['ppos', 'ctx.pattern'], + greens=['ppos', 'pattern'], reds=['ptr', 'end', 'ctx'], debugprint=(1, 0)) @specializectx - def fre(ctx, ptr, end, ppos): + def fre(ctx, pattern, ptr, end, ppos): while True: ctx.jitdriver_MatchInIgnore.jit_merge_point(ctx=ctx, ptr=ptr, - end=end, ppos=ppos) - if ptr < end and checkerfn(ctx, ptr, ppos): + end=end, ppos=ppos, + pattern=pattern) + if ptr < end and checkerfn(ctx, pattern, ptr, ppos): ptr += 1 else: return ptr @@ -925,8 +938,8 @@ # in the other cases, the fre() function is not JITted at all # and is present as a residual call. @specializectx - def fre(ctx, ptr, end, ppos): - while ptr < end and checkerfn(ctx, ptr, ppos): + def fre(ctx, pattern, ptr, end, ppos): + while ptr < end and checkerfn(ctx, pattern, ptr, ppos): ptr += 1 return ptr fre = func_with_new_name(fre, 'fre_' + checkerfn.__name__) @@ -1037,10 +1050,11 @@ return start, end def match(pattern, string, start=0, end=sys.maxint, flags=0, fullmatch=False): + assert isinstance(pattern, CompiledPattern) start, end = _adjust(start, end, len(string)) - ctx = StrMatchContext(pattern, string, start, end, flags) + ctx = StrMatchContext(string, start, end, flags) ctx.fullmatch_only = fullmatch - if match_context(ctx): + if match_context(ctx, pattern): return ctx else: return None @@ -1049,105 +1063,106 @@ return match(pattern, string, start, end, flags, fullmatch=True) def search(pattern, string, start=0, end=sys.maxint, flags=0): + assert isinstance(pattern, CompiledPattern) start, end = _adjust(start, end, len(string)) - ctx = StrMatchContext(pattern, string, start, end, flags) - if search_context(ctx): + ctx = StrMatchContext(string, start, end, flags) + if search_context(ctx, pattern): return ctx else: return None install_jitdriver('Match', - greens=['ctx.pattern'], reds=['ctx'], + greens=['pattern'], reds=['ctx'], debugprint=(0,)) -def match_context(ctx): +def match_context(ctx, pattern): ctx.original_pos = ctx.match_start if ctx.end < ctx.match_start: return False - ctx.jitdriver_Match.jit_merge_point(ctx=ctx) - return sre_match(ctx, 0, ctx.match_start, None) is not None + ctx.jitdriver_Match.jit_merge_point(ctx=ctx, pattern=pattern) + return sre_match(ctx, pattern, 0, ctx.match_start, None) is not None -def search_context(ctx): +def search_context(ctx, pattern): ctx.original_pos = ctx.match_start if ctx.end < ctx.match_start: return False base = 0 charset = False - if ctx.pat(base) == OPCODE_INFO: - flags = ctx.pat(2) + if pattern.pat(base) == OPCODE_INFO: + flags = pattern.pat(2) if flags & rsre_char.SRE_INFO_PREFIX: - if ctx.pat(5) > 1: - return fast_search(ctx) + if pattern.pat(5) > 1: + return fast_search(ctx, pattern) else: charset = (flags & rsre_char.SRE_INFO_CHARSET) - base += 1 + ctx.pat(1) - if ctx.pat(base) == OPCODE_LITERAL: - return literal_search(ctx, base) + base += 1 + pattern.pat(1) + if pattern.pat(base) == OPCODE_LITERAL: + return literal_search(ctx, pattern, base) if charset: - return charset_search(ctx, base) - return regular_search(ctx, base) + return charset_search(ctx, pattern, base) + return regular_search(ctx, pattern, base) install_jitdriver('RegularSearch', - greens=['base', 'ctx.pattern'], + greens=['base', 'pattern'], reds=['start', 'ctx'], debugprint=(1, 0)) -def regular_search(ctx, base): +def regular_search(ctx, pattern, base): start = ctx.match_start while start <= ctx.end: ctx.jitdriver_RegularSearch.jit_merge_point(ctx=ctx, start=start, - base=base) - if sre_match(ctx, base, start, None) is not None: + base=base, pattern=pattern) + if sre_match(ctx, pattern, base, start, None) is not None: ctx.match_start = start return True start += 1 return False install_jitdriver_spec("LiteralSearch", - greens=['base', 'character', 'ctx.pattern'], + greens=['base', 'character', 'pattern'], reds=['start', 'ctx'], debugprint=(2, 0, 1)) @specializectx -def literal_search(ctx, base): +def literal_search(ctx, pattern, base): # pattern starts with a literal character. this is used # for short prefixes, and if fast search is disabled - character = ctx.pat(base + 1) + character = pattern.pat(base + 1) base += 2 start = ctx.match_start while start < ctx.end: ctx.jitdriver_LiteralSearch.jit_merge_point(ctx=ctx, start=start, - base=base, character=character) + base=base, character=character, pattern=pattern) if ctx.str(start) == character: - if sre_match(ctx, base, start + 1, None) is not None: + if sre_match(ctx, pattern, base, start + 1, None) is not None: ctx.match_start = start return True start += 1 return False install_jitdriver_spec("CharsetSearch", - greens=['base', 'ctx.pattern'], + greens=['base', 'pattern'], reds=['start', 'ctx'], debugprint=(1, 0)) @specializectx -def charset_search(ctx, base): +def charset_search(ctx, pattern, base): # pattern starts with a character from a known set start = ctx.match_start while start < ctx.end: ctx.jitdriver_CharsetSearch.jit_merge_point(ctx=ctx, start=start, - base=base) - if rsre_char.check_charset(ctx, 5, ctx.str(start)): - if sre_match(ctx, base, start, None) is not None: + base=base, pattern=pattern) + if rsre_char.check_charset(ctx, pattern, 5, ctx.str(start)): + if sre_match(ctx, pattern, base, start, None) is not None: ctx.match_start = start return True start += 1 return False install_jitdriver_spec('FastSearch', - greens=['i', 'prefix_len', 'ctx.pattern'], + greens=['i', 'prefix_len', 'pattern'], reds=['string_position', 'ctx'], debugprint=(2, 0)) @specializectx -def fast_search(ctx): +def fast_search(ctx, pattern): # skips forward in a string as fast as possible using information from # an optimization info block # <INFO> <1=skip> <2=flags> <3=min> <4=...> @@ -1155,17 +1170,18 @@ string_position = ctx.match_start if string_position >= ctx.end: return False - prefix_len = ctx.pat(5) + prefix_len = pattern.pat(5) assert prefix_len >= 0 i = 0 while True: ctx.jitdriver_FastSearch.jit_merge_point(ctx=ctx, - string_position=string_position, i=i, prefix_len=prefix_len) + string_position=string_position, i=i, prefix_len=prefix_len, + pattern=pattern) char_ord = ctx.str(string_position) - if char_ord != ctx.pat(7 + i): + if char_ord != pattern.pat(7 + i): if i > 0: overlap_offset = prefix_len + (7 - 1) - i = ctx.pat(overlap_offset + i) + i = pattern.pat(overlap_offset + i) continue else: i += 1 @@ -1173,22 +1189,22 @@ # found a potential match start = string_position + 1 - prefix_len assert start >= 0 - prefix_skip = ctx.pat(6) + prefix_skip = pattern.pat(6) ptr = start + prefix_skip - #flags = ctx.pat(2) + #flags = pattern.pat(2) #if flags & rsre_char.SRE_INFO_LITERAL: # # matched all of pure literal pattern # ctx.match_start = start # ctx.match_end = ptr # ctx.match_marks = None # return True - pattern_offset = ctx.pat(1) + 1 + pattern_offset = pattern.pat(1) + 1 ppos_start = pattern_offset + 2 * prefix_skip - if sre_match(ctx, ppos_start, ptr, None) is not None: + if sre_match(ctx, pattern, ppos_start, ptr, None) is not None: ctx.match_start = start return True overlap_offset = prefix_len + (7 - 1) - i = ctx.pat(overlap_offset + i) + i = pattern.pat(overlap_offset + i) string_position += 1 if string_position >= ctx.end: return False diff --git a/rpython/rlib/rsre/test/test_char.py b/rpython/rlib/rsre/test/test_char.py --- a/rpython/rlib/rsre/test/test_char.py +++ b/rpython/rlib/rsre/test/test_char.py @@ -1,10 +1,16 @@ -from rpython.rlib.rsre import rsre_char +from rpython.rlib.rsre import rsre_char, rsre_core from rpython.rlib.rsre.rsre_char import SRE_FLAG_LOCALE, SRE_FLAG_UNICODE def setup_module(mod): from rpython.rlib.unicodedata import unicodedb rsre_char.set_unicode_db(unicodedb) + +def check_charset(pattern, idx, char): + p = rsre_core.CompiledPattern(pattern) + return rsre_char.check_charset(Ctx(p), p, idx, char) + + UPPER_PI = 0x3a0 LOWER_PI = 0x3c0 INDIAN_DIGIT = 0x966 @@ -157,12 +163,12 @@ pat_neg = [70, ord(cat) | 0x80, 0] for c in positive: assert unicodedb.category(ord(c)).startswith(cat) - assert rsre_char.check_charset(Ctx(pat_pos), 0, ord(c)) - assert not rsre_char.check_charset(Ctx(pat_neg), 0, ord(c)) + assert check_charset(pat_pos, 0, ord(c)) + assert not check_charset(pat_neg, 0, ord(c)) for c in negative: assert not unicodedb.category(ord(c)).startswith(cat) - assert not rsre_char.check_charset(Ctx(pat_pos), 0, ord(c)) - assert rsre_char.check_charset(Ctx(pat_neg), 0, ord(c)) + assert not check_charset(pat_pos, 0, ord(c)) + assert check_charset(pat_neg, 0, ord(c)) def cat2num(cat): return ord(cat[0]) | (ord(cat[1]) << 8) @@ -173,17 +179,16 @@ pat_neg = [70, cat2num(cat) | 0x80, 0] for c in positive: assert unicodedb.category(ord(c)) == cat - assert rsre_char.check_charset(Ctx(pat_pos), 0, ord(c)) - assert not rsre_char.check_charset(Ctx(pat_neg), 0, ord(c)) + assert check_charset(pat_pos, 0, ord(c)) + assert not check_charset(pat_neg, 0, ord(c)) for c in negative: assert unicodedb.category(ord(c)) != cat - assert not rsre_char.check_charset(Ctx(pat_pos), 0, ord(c)) - assert rsre_char.check_charset(Ctx(pat_neg), 0, ord(c)) + assert not check_charset(pat_pos, 0, ord(c)) + assert check_charset(pat_neg, 0, ord(c)) # test for how the common 'L&' pattern might be compiled pat = [70, cat2num('Lu'), 70, cat2num('Ll'), 70, cat2num('Lt'), 0] - assert rsre_char.check_charset(Ctx(pat), 0, 65) # Lu - assert rsre_char.check_charset(Ctx(pat), 0, 99) # Ll - assert rsre_char.check_charset(Ctx(pat), 0, 453) # Lt - assert not rsre_char.check_charset(Ctx(pat), 0, 688) # Lm - assert not rsre_char.check_charset(Ctx(pat), 0, 5870) # Nl + assert check_charset(pat, 0, 65) # Lu + assert check_charset(pat, 0, 99) # Lcheck_charset(pat, 0, 453) # Lt + assert not check_charset(pat, 0, 688) # Lm + assert not check_charset(pat, 0, 5870) # Nl diff --git a/rpython/rlib/rsre/test/test_ext_opcode.py b/rpython/rlib/rsre/test/test_ext_opcode.py --- a/rpython/rlib/rsre/test/test_ext_opcode.py +++ b/rpython/rlib/rsre/test/test_ext_opcode.py @@ -17,10 +17,10 @@ # it's a valid optimization because \1 is always one character long r = [MARK, 0, ANY, MARK, 1, REPEAT_ONE, 6, 0, MAXREPEAT, GROUPREF, 0, SUCCESS, SUCCESS] - assert rsre_core.match(r, "aaa").match_end == 3 + assert rsre_core.match(rsre_core.CompiledPattern(r), "aaa").match_end == 3 def test_min_repeat_one_with_backref(): # Python 3.5 compiles "(.)\1*?b" using MIN_REPEAT_ONE r = [MARK, 0, ANY, MARK, 1, MIN_REPEAT_ONE, 6, 0, MAXREPEAT, GROUPREF, 0, SUCCESS, LITERAL, 98, SUCCESS] - assert rsre_core.match(r, "aaab").match_end == 4 + assert rsre_core.match(rsre_core.CompiledPattern(r), "aaab").match_end == 4 diff --git a/rpython/rlib/rsre/test/test_match.py b/rpython/rlib/rsre/test/test_match.py --- a/rpython/rlib/rsre/test/test_match.py +++ b/rpython/rlib/rsre/test/test_match.py @@ -9,7 +9,7 @@ def test_get_code_repetition(): c1 = get_code(r"a+") c2 = get_code(r"a+") - assert c1 == c2 + assert c1.pattern == c2.pattern class TestMatch: @@ -305,6 +305,6 @@ rsre_char.set_unicode_db(unicodedb) # r = get_code(u"[\U00010428-\U0001044f]", re.I) - assert r.count(27) == 1 # OPCODE_RANGE - r[r.index(27)] = 32 # => OPCODE_RANGE_IGNORE + assert r.pattern.count(27) == 1 # OPCODE_RANGE + r.pattern[r.pattern.index(27)] = 32 # => OPCODE_RANGE_IGNORE assert rsre_core.match(r, u"\U00010428") diff --git a/rpython/rlib/rsre/test/test_re.py b/rpython/rlib/rsre/test/test_re.py --- a/rpython/rlib/rsre/test/test_re.py +++ b/rpython/rlib/rsre/test/test_re.py @@ -426,31 +426,6 @@ assert pat.match(p) is not None assert pat.match(p).span() == (0,256) - def test_pickling(self): - import pickle - self.pickle_test(pickle) - import cPickle - self.pickle_test(cPickle) - # old pickles expect the _compile() reconstructor in sre module - import warnings - original_filters = warnings.filters[:] - try: - warnings.filterwarnings("ignore", "The sre module is deprecated", - DeprecationWarning) - from sre import _compile - finally: - warnings.filters = original_filters - - def pickle_test(self, pickle): - oldpat = re.compile('a(?:b|(c|e){1,2}?|d)+?(.)') - s = pickle.dumps(oldpat) - newpat = pickle.loads(s) - # Not using object identity for _sre.py, since some Python builds do - # not seem to preserve that in all cases (observed on an UCS-4 build - # of 2.4.1). - #self.assertEqual(oldpat, newpat) - assert oldpat.__dict__ == newpat.__dict__ - def test_constants(self): assert re.I == re.IGNORECASE assert re.L == re.LOCALE diff --git a/rpython/rlib/rsre/test/test_zinterp.py b/rpython/rlib/rsre/test/test_zinterp.py --- a/rpython/rlib/rsre/test/test_zinterp.py +++ b/rpython/rlib/rsre/test/test_zinterp.py @@ -11,6 +11,7 @@ rsre_core.search(pattern, string) # unicodestr = unichr(n) * n + pattern = rsre_core.CompiledPattern(pattern) ctx = rsre_core.UnicodeMatchContext(pattern, unicodestr, 0, len(unicodestr), 0) rsre_core.search_context(ctx) diff --git a/rpython/rlib/rsre/test/test_zjit.py b/rpython/rlib/rsre/test/test_zjit.py --- a/rpython/rlib/rsre/test/test_zjit.py +++ b/rpython/rlib/rsre/test/test_zjit.py @@ -6,18 +6,20 @@ from rpython.rtyper.annlowlevel import llstr, hlstr def entrypoint1(r, string, repeat): - r = array2list(r) + r = rsre_core.CompiledPattern(array2list(r)) string = hlstr(string) match = None for i in range(repeat): match = rsre_core.match(r, string) + if match is None: + return -1 if match is None: return -1 else: return match.match_end def entrypoint2(r, string, repeat): - r = array2list(r) + r = rsre_core.CompiledPattern(array2list(r)) string = hlstr(string) match = None for i in range(repeat): @@ -48,13 +50,13 @@ def meta_interp_match(self, pattern, string, repeat=1): r = get_code(pattern) - return self.meta_interp(entrypoint1, [list2array(r), llstr(string), + return self.meta_interp(entrypoint1, [list2array(r.pattern), llstr(string), repeat], listcomp=True, backendopt=True) def meta_interp_search(self, pattern, string, repeat=1): r = get_code(pattern) - return self.meta_interp(entrypoint2, [list2array(r), llstr(string), + return self.meta_interp(entrypoint2, [list2array(r.pattern), llstr(string), repeat], listcomp=True, backendopt=True) @@ -166,3 +168,9 @@ res = self.meta_interp_search(r"b+", "a"*30 + "b") assert res == 30 self.check_resops(call=0) + + def test_match_jit_bug(self): + pattern = ".a" * 2500 + text = "a" * 6000 + res = self.meta_interp_match(pattern, text, repeat=10) + assert res != -1 diff --git a/rpython/rlib/test/test_jit.py b/rpython/rlib/test/test_jit.py --- a/rpython/rlib/test/test_jit.py +++ b/rpython/rlib/test/test_jit.py @@ -225,8 +225,10 @@ def test_green_field(self): def get_printable_location(xfoo): return str(ord(xfoo)) # xfoo must be annotated as a character - myjitdriver = JitDriver(greens=['x.foo'], reds=['n', 'x'], + # green fields are disabled! + pytest.raises(ValueError, JitDriver, greens=['x.foo'], reds=['n', 'x'], get_printable_location=get_printable_location) + return class A(object): _immutable_fields_ = ['foo'] def fn(n): _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit