Author: Matti Picus <[email protected]>
Branch: unicode-utf8
Changeset: r94459:ab5ac9802e14
Date: 2018-04-29 23:28 +0300
http://bitbucket.org/pypy/pypy/changeset/ab5ac9802e14/

Log:    fix merge

diff --git a/pypy/module/_sre/interp_sre.py b/pypy/module/_sre/interp_sre.py
--- a/pypy/module/_sre/interp_sre.py
+++ b/pypy/module/_sre/interp_sre.py
@@ -133,7 +133,7 @@
                 endbytepos = rutf8.codepoint_position_at_index(utf8str,
                                 index_storage, endpos)
             ctx = rsre_utf8.Utf8MatchContext(
-                self.code, utf8str, bytepos, endbytepos, self.flags)
+                utf8str, bytepos, endbytepos, self.flags)
             # xxx we store the w_string on the ctx too, for
             # W_SRE_Match.bytepos_to_charindex()
             ctx.w_unicode_obj = w_unicode_obj
@@ -159,14 +159,14 @@
     def fresh_copy(self, ctx):
         if isinstance(ctx, rsre_utf8.Utf8MatchContext):
             result = rsre_utf8.Utf8MatchContext(
-                ctx.pattern, ctx._utf8, ctx.match_start, ctx.end, ctx.flags)
+                ctx._utf8, ctx.match_start, ctx.end, ctx.flags)
             result.w_unicode_obj = ctx.w_unicode_obj
         elif isinstance(ctx, rsre_core.StrMatchContext):
             result = self._make_str_match_context(
                 ctx._string, ctx.match_start, ctx.end)
         elif isinstance(ctx, rsre_core.BufMatchContext):
             result = rsre_core.BufMatchContext(
-                ctx.pattern, ctx._buffer, ctx.match_start, ctx.end, ctx.flags)
+                ctx._buffer, ctx.match_start, ctx.end, ctx.flags)
         else:
             raise AssertionError("bad ctx type")
         result.match_end = ctx.match_end
@@ -174,7 +174,7 @@
 
     def _make_str_match_context(self, str, pos, endpos):
         # for tests to override
-        return rsre_core.StrMatchContext(self.code, str,
+        return rsre_core.StrMatchContext(str,
                                          pos, endpos, self.flags)
 
     def getmatch(self, ctx, found):
@@ -319,7 +319,7 @@
         n = 0
         last_pos = ctx.ZERO
         while not count or n < count:
-            pattern = ctx.pattern
+            pattern = self.code
             sub_jitdriver.jit_merge_point(
                 self=self,
                 use_builder=use_builder,
diff --git a/pypy/module/_sre/test/test_app_sre.py b/pypy/module/_sre/test/test_app_sre.py
--- a/pypy/module/_sre/test/test_app_sre.py
+++ b/pypy/module/_sre/test/test_app_sre.py
@@ -32,7 +32,7 @@
         start = support.Position(start)
     if not isinstance(end, support.Position):
         end = support.Position(end)
-    return support.MatchContextForTests(self.code, str, start, end, self.flags)
+    return support.MatchContextForTests(str, start, end, self.flags)
 
 def _bytepos_to_charindex(self, bytepos):
     if isinstance(self.ctx, support.MatchContextForTests):
diff --git a/rpython/rlib/rsre/rsre_core.py b/rpython/rlib/rsre/rsre_core.py
--- a/rpython/rlib/rsre/rsre_core.py
+++ b/rpython/rlib/rsre/rsre_core.py
@@ -55,6 +55,8 @@
     specific subclass, calling 'func' is a direct call; if 'ctx' is only known
     to be of class AbstractMatchContext, calling 'func' is an indirect call.
     """
+    from rpython.rlib.rsre.rsre_utf8 import Utf8MatchContext
+
     assert func.func_code.co_varnames[0] == 'ctx'
     specname = '_spec_' + func.func_name
     while specname in _seen_specname:
@@ -65,7 +67,9 @@
     specialized_methods = []
     for prefix, concreteclass in [('buf', BufMatchContext),
                                   ('str', StrMatchContext),
-                                  ('uni', UnicodeMatchContext)]:
+                                  ('uni', UnicodeMatchContext),
+                                  ('utf8', Utf8MatchContext),
+                                  ]:
         newfunc = func_with_new_name(func, prefix + specname)
         assert not hasattr(concreteclass, specname)
         setattr(concreteclass, specname, newfunc)
@@ -83,6 +87,8 @@
     def __init__(self, msg):
         self.msg = msg
 
+class EndOfString(Exception):
+    pass
 
 class CompiledPattern(object):
     _immutable_fields_ = ['pattern[*]']
@@ -142,6 +148,46 @@
         """Similar to str()."""
         raise NotImplementedError
 
+    # The following methods are provided to be overriden in
+    # Utf8MatchContext.  The non-utf8 implementation is provided
+    # by the FixedMatchContext abstract subclass, in order to use
+    # the same @not_rpython safety trick as above.
+    ZERO = 0
+    @not_rpython
+    def next(self, position):
+        raise NotImplementedError
+    @not_rpython
+    def prev(self, position):
+        raise NotImplementedError
+    @not_rpython
+    def next_n(self, position, n):
+        raise NotImplementedError
+    @not_rpython
+    def prev_n(self, position, n, start_position):
+        raise NotImplementedError
+    @not_rpython
+    def debug_check_pos(self, position):
+        raise NotImplementedError
+    @not_rpython
+    def maximum_distance(self, position_low, position_high):
+        raise NotImplementedError
+    @not_rpython
+    def get_single_byte(self, base_position, index):
+        raise NotImplementedError
+
+    def bytes_difference(self, position1, position2):
+        return position1 - position2
+    def go_forward_by_bytes(self, base_position, index):
+        return base_position + index
+    def next_indirect(self, position):
+        assert position < self.end
+        return position + 1     # like next(), but can be called indirectly
+    def prev_indirect(self, position):
+        position -= 1           # like prev(), but can be called indirectly
+        if position < 0:
+            raise EndOfString
+        return position
+
     def get_mark(self, gid):
         return find_mark(self.match_marks, gid)
 
@@ -185,13 +231,40 @@
     def fresh_copy(self, start):
         raise NotImplementedError
 
-class BufMatchContext(AbstractMatchContext):
+class FixedMatchContext(AbstractMatchContext):
+    """Abstract subclass to introduce the default implementation for
+    these position methods.  The Utf8MatchContext subclass doesn't
+    inherit from here."""
+
+    next = AbstractMatchContext.next_indirect
+    prev = AbstractMatchContext.prev_indirect
+
+    def next_n(self, position, n, end_position):
+        position += n
+        if position > end_position:
+            raise EndOfString
+        return position
+
+    def prev_n(self, position, n, start_position):
+        position -= n
+        if position < start_position:
+            raise EndOfString
+        return position
+
+    def debug_check_pos(self, position):
+        pass
+
+    def maximum_distance(self, position_low, position_high):
+        return position_high - position_low
+
+
+class BufMatchContext(FixedMatchContext):
     """Concrete subclass for matching in a buffer."""
 
     _immutable_fields_ = ["_buffer"]
 
     def __init__(self, buf, match_start, end, flags):
-        AbstractMatchContext.__init__(self, match_start, end, flags)
+        FixedMatchContext.__init__(self, match_start, end, flags)
         self._buffer = buf
 
     def str(self, index):
@@ -206,13 +279,17 @@
         return BufMatchContext(self._buffer, start,
                                self.end, self.flags)
 
-class StrMatchContext(AbstractMatchContext):
+    def get_single_byte(self, base_position, index):
+        return self.str(base_position + index)
+
+
+class StrMatchContext(FixedMatchContext):
     """Concrete subclass for matching in a plain string."""
 
     _immutable_fields_ = ["_string"]
 
     def __init__(self, string, match_start, end, flags):
-        AbstractMatchContext.__init__(self, match_start, end, flags)
+        FixedMatchContext.__init__(self, match_start, end, flags)
         self._string = string
         if not we_are_translated() and isinstance(string, unicode):
             self.flags |= rsre_char.SRE_FLAG_UNICODE   # for rsre_re.py
@@ -229,13 +306,20 @@
         return StrMatchContext(self._string, start,
                                self.end, self.flags)
 
-class UnicodeMatchContext(AbstractMatchContext):
+    def get_single_byte(self, base_position, index):
+        return self.str(base_position + index)
+
+    def _real_pos(self, index):
+        return index     # overridden by tests
+
+
+class UnicodeMatchContext(FixedMatchContext):
     """Concrete subclass for matching in a unicode string."""
 
     _immutable_fields_ = ["_unicodestr"]
 
     def __init__(self, unicodestr, match_start, end, flags):
-        AbstractMatchContext.__init__(self, match_start, end, flags)
+        FixedMatchContext.__init__(self, match_start, end, flags)
         self._unicodestr = unicodestr
 
     def str(self, index):
@@ -250,6 +334,9 @@
         return UnicodeMatchContext(self._unicodestr, start,
                                    self.end, self.flags)
 
+    def get_single_byte(self, base_position, index):
+        return self.str(base_position + index)
+
 # ____________________________________________________________
 
 class Mark(object):
@@ -325,7 +412,10 @@
                 self=self, ptr=ptr, ctx=ctx, nextppos=nextppos,
                 pattern=pattern)
             result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks)
-            ptr -= 1
+            try:
+                ptr = ctx.prev_indirect(ptr)
+            except EndOfString:
+                ptr = -1
             if result is not None:
                 self.subresult = result
                 self.start_ptr = ptr
@@ -336,32 +426,35 @@
 class MinRepeatOneMatchResult(MatchResult):
     install_jitdriver('MinRepeatOne',
                       greens=['nextppos', 'ppos3', 'pattern'],
-                      reds=['ptr', 'self', 'ctx'],
+                      reds=['max_count', 'ptr', 'self', 'ctx'],
                       debugprint=(2, 0))   # indices in 'greens'
 
-    def __init__(self, nextppos, ppos3, maxptr, ptr, marks):
+    def __init__(self, nextppos, ppos3, max_count, ptr, marks):
         self.nextppos = nextppos
         self.ppos3 = ppos3
-        self.maxptr = maxptr
+        self.max_count = max_count
         self.start_ptr = ptr
         self.start_marks = marks
 
     def find_first_result(self, ctx, pattern):
         ptr = self.start_ptr
         nextppos = self.nextppos
+        max_count = self.max_count
         ppos3 = self.ppos3
-        while ptr <= self.maxptr:
+        while max_count >= 0:
             ctx.jitdriver_MinRepeatOne.jit_merge_point(
                 self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3,
-                pattern=pattern)
+                max_count=max_count, pattern=pattern)
             result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks)
             if result is not None:
                 self.subresult = result
                 self.start_ptr = ptr
+                self.max_count = max_count
                 return self
             if not self.next_char_ok(ctx, pattern, ptr, ppos3):
                 break
-            ptr += 1
+            ptr = ctx.next_indirect(ptr)
+            max_count -= 1
 
     def find_next_result(self, ctx, pattern):
         ptr = self.start_ptr
@@ -440,12 +533,12 @@
             min = pattern.pat(ppos+1)
             if enum is not None:
                 # matched one more 'item'.  record it and continue.
-                last_match_length = ctx.match_end - ptr
+                last_match_zero_length = (ctx.match_end == ptr)
                 self.pending = Pending(ptr, marks, enum, self.pending)
                 self.num_pending += 1
                 ptr = ctx.match_end
                 marks = ctx.match_marks
-                if last_match_length == 0 and self.num_pending >= min:
+                if last_match_zero_length and self.num_pending >= min:
                     # zero-width protection: after an empty match, if there
                     # are enough matches, don't try to match more.  Instead,
                     # fall through to trying to match 'tail'.
@@ -561,22 +654,25 @@
             # <ANY>
             if ptr >= ctx.end or rsre_char.is_linebreak(ctx.str(ptr)):
                 return
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_ANY_ALL:
             # match anything
             # <ANY_ALL>
             if ptr >= ctx.end:
                 return
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_ASSERT:
             # assert subpattern
             # <ASSERT> <0=skip> <1=back> <pattern>
-            ptr1 = ptr - pattern.pat(ppos+1)
+            try:
+                ptr1 = ctx.prev_n(ptr, pattern.pat(ppos+1), ctx.ZERO)
+            except EndOfString:
+                return
             saved = ctx.fullmatch_only
             ctx.fullmatch_only = False
-            stop = ptr1 < 0 or sre_match(ctx, pattern, ppos + 2, ptr1, marks) is None
+            stop = sre_match(ctx, pattern, ppos + 2, ptr1, marks) is None
             ctx.fullmatch_only = saved
             if stop:
                 return
@@ -586,14 +682,18 @@
         elif op == OPCODE_ASSERT_NOT:
             # assert not subpattern
             # <ASSERT_NOT> <0=skip> <1=back> <pattern>
-            ptr1 = ptr - pattern.pat(ppos+1)
-            saved = ctx.fullmatch_only
-            ctx.fullmatch_only = False
-            stop = (ptr1 >= 0 and sre_match(ctx, pattern, ppos + 2, ptr1, marks)
-                                      is not None)
-            ctx.fullmatch_only = saved
-            if stop:
-                return
+
+            try:
+                ptr1 = ctx.prev_n(ptr, pattern.pat(ppos+1), ctx.ZERO)
+            except EndOfString:
+                pass
+            else:
+                saved = ctx.fullmatch_only
+                ctx.fullmatch_only = False
+                stop = sre_match(ctx, pattern, ppos + 2, ptr1, marks) is not None
+                ctx.fullmatch_only = saved
+                if stop:
+                    return
             ppos += pattern.pat(ppos)
 
         elif op == OPCODE_AT:
@@ -616,36 +716,36 @@
             if (ptr == ctx.end or
                 not rsre_char.category_dispatch(pattern.pat(ppos), ctx.str(ptr))):
                 return
-            ptr += 1
+            ptr = ctx.next(ptr)
             ppos += 1
 
         elif op == OPCODE_GROUPREF:
             # match backreference
             # <GROUPREF> <groupnum>
-            startptr, length = get_group_ref(marks, pattern.pat(ppos))
-            if length < 0:
+            startptr, length_bytes = get_group_ref(ctx, marks, pattern.pat(ppos))
+            if length_bytes < 0:
                 return     # group was not previously defined
-            if not match_repeated(ctx, ptr, startptr, length):
+            if not match_repeated(ctx, ptr, startptr, length_bytes):
                 return     # no match
-            ptr += length
+            ptr = ctx.go_forward_by_bytes(ptr, length_bytes)
             ppos += 1
 
         elif op == OPCODE_GROUPREF_IGNORE:
             # match backreference
             # <GROUPREF> <groupnum>
-            startptr, length = get_group_ref(marks, pattern.pat(ppos))
-            if length < 0:
+            startptr, length_bytes = get_group_ref(ctx, marks, pattern.pat(ppos))
+            if length_bytes < 0:
                 return     # group was not previously defined
-            if not match_repeated_ignore(ctx, ptr, startptr, length):
+            if not match_repeated_ignore(ctx, ptr, startptr, length_bytes):
                 return     # no match
-            ptr += length
+            ptr = ctx.go_forward_by_bytes(ptr, length_bytes)
             ppos += 1
 
         elif op == OPCODE_GROUPREF_EXISTS:
             # conditional match depending on the existence of a group
             # <GROUPREF_EXISTS> <group> <skip> codeyes <JUMP> codeno ...
-            _, length = get_group_ref(marks, pattern.pat(ppos))
-            if length >= 0:
+            _, length_bytes = get_group_ref(ctx, marks, pattern.pat(ppos))
+            if length_bytes >= 0:
                 ppos += 2                  # jump to 'codeyes'
             else:
                 ppos += pattern.pat(ppos+1)    # jump to 'codeno'
@@ -657,7 +757,7 @@
                                                              ctx.str(ptr)):
                 return
             ppos += pattern.pat(ppos)
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_IN_IGNORE:
             # match set member (or non_member), ignoring case
@@ -666,12 +766,12 @@
                                                              ctx.lowstr(ptr)):
                 return
             ppos += pattern.pat(ppos)
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_INFO:
             # optimization info block
             # <INFO> <0=skip> <1=flags> <2=min> ...
-            if (ctx.end - ptr) < pattern.pat(ppos+2):
+            if ctx.maximum_distance(ptr, ctx.end) < pattern.pat(ppos+2):
                 return
             ppos += pattern.pat(ppos)
 
@@ -684,7 +784,7 @@
             if ptr >= ctx.end or ctx.str(ptr) != pattern.pat(ppos):
                 return
             ppos += 1
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_LITERAL_IGNORE:
             # match literal string, ignoring case
@@ -692,7 +792,7 @@
             if ptr >= ctx.end or ctx.lowstr(ptr) != pattern.pat(ppos):
                 return
             ppos += 1
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_MARK:
             # set mark
@@ -707,7 +807,7 @@
             if ptr >= ctx.end or ctx.str(ptr) == pattern.pat(ppos):
                 return
             ppos += 1
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_NOT_LITERAL_IGNORE:
             # match if it's not a literal string, ignoring case
@@ -715,7 +815,7 @@
             if ptr >= ctx.end or ctx.lowstr(ptr) == pattern.pat(ppos):
                 return
             ppos += 1
-            ptr += 1
+            ptr = ctx.next(ptr)
 
         elif op == OPCODE_REPEAT:
             # general repeat.  in this version of the re module, all the work
@@ -753,8 +853,10 @@
             # use the MAX_REPEAT operator.
             # <REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail
             start = ptr
-            minptr = start + pattern.pat(ppos+1)
-            if minptr > ctx.end:
+
+            try:
+                minptr = ctx.next_n(start, pattern.pat(ppos+1), ctx.end)
+            except EndOfString:
                 return    # cannot match
             ptr = find_repetition_end(ctx, pattern, ppos+3, start,
                                       pattern.pat(ppos+2),
@@ -776,22 +878,22 @@
             start = ptr
             min = pattern.pat(ppos+1)
             if min > 0:
-                minptr = ptr + min
-                if minptr > ctx.end:
-                    return   # cannot match
+                try:
+                    minptr = ctx.next_n(ptr, min, ctx.end)
+                except EndOfString:
+                    return    # cannot match
                 # count using pattern min as the maximum
                 ptr = find_repetition_end(ctx, pattern, ppos+3, ptr, min, marks)
                 if ptr < minptr:
                     return   # did not match minimum number of times
 
-            maxptr = ctx.end
+            max_count = sys.maxint
             max = pattern.pat(ppos+2)
             if max != rsre_char.MAXREPEAT:
-                maxptr1 = start + max
-                if maxptr1 <= maxptr:
-                    maxptr = maxptr1
+                max_count = max - min
+                assert max_count >= 0
             nextppos = ppos + pattern.pat(ppos)
-            result = MinRepeatOneMatchResult(nextppos, ppos+3, maxptr,
+            result = MinRepeatOneMatchResult(nextppos, ppos+3, max_count,
                                              ptr, marks)
             return result.find_first_result(ctx, pattern)
 
@@ -799,40 +901,43 @@
             raise Error("bad pattern code %d" % op)
 
 
-def get_group_ref(marks, groupnum):
+def get_group_ref(ctx, marks, groupnum):
     gid = groupnum * 2
     startptr = find_mark(marks, gid)
-    if startptr < 0:
+    if startptr < ctx.ZERO:
         return 0, -1
     endptr = find_mark(marks, gid + 1)
-    length = endptr - startptr     # < 0 if endptr < startptr (or if endptr=-1)
-    return startptr, length
+    length_bytes = ctx.bytes_difference(endptr, startptr)
+    return startptr, length_bytes
 
 @specializectx
-def match_repeated(ctx, ptr, oldptr, length):
-    if ptr + length > ctx.end:
+def match_repeated(ctx, ptr, oldptr, length_bytes):
+    if ctx.bytes_difference(ctx.end, ptr) < length_bytes:
         return False
-    for i in range(length):
-        if ctx.str(ptr + i) != ctx.str(oldptr + i):
+    for i in range(length_bytes):
+        if ctx.get_single_byte(ptr, i) != ctx.get_single_byte(oldptr, i):
             return False
     return True
 
 @specializectx
-def match_repeated_ignore(ctx, ptr, oldptr, length):
-    if ptr + length > ctx.end:
-        return False
-    for i in range(length):
-        if ctx.lowstr(ptr + i) != ctx.lowstr(oldptr + i):
-            return False
-    return True
+def match_repeated_ignore(ctx, ptr, oldptr, length_bytes):
+    oldend = ctx.go_forward_by_bytes(oldptr, length_bytes)
+    while oldptr < oldend:
+        if ptr >= ctx.end:
+            return -1
+        if ctx.lowstr(ptr) != ctx.lowstr(oldptr):
+            return -1
+        ptr = ctx.next(ptr)
+        oldptr = ctx.next(oldptr)
+    return ptr
 
 @specializectx
 def find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks):
     end = ctx.end
-    ptrp1 = ptr + 1
     # First get rid of the cases where we don't have room for any match.
-    if maxcount <= 0 or ptrp1 > end:
+    if maxcount <= 0 or ptr >= end:
         return ptr
+    ptrp1 = ctx.next(ptr)
     # Check the first character directly.  If it doesn't match, we are done.
     # The idea is to be fast for cases like re.search("b+"), where we expect
     # the common case to be a non-match.  It's much faster with the JIT to
@@ -854,9 +959,10 @@
     # Else we really need to count how many times it matches.
     if maxcount != rsre_char.MAXREPEAT:
         # adjust end
-        end1 = ptr + maxcount
-        if end1 <= end:
-            end = end1
+        try:
+            end = ctx.next_n(ptr, maxcount, end)
+        except EndOfString:
+            pass
     op = pattern.pat(ppos)
     for op1, fre in unroll_fre_checker:
         if op1 == op:
@@ -873,7 +979,7 @@
         if end1 <= end:
             end = end1
     while ptr < end and sre_match(ctx, patern, ppos, ptr, marks) is not None:
-        ptr += 1
+        ptr = ctx.next(ptr)
     return ptr
 
 @specializectx
@@ -916,7 +1022,7 @@
                                                       end=end, ppos=ppos,
                                                       pattern=pattern)
                 if ptr < end and checkerfn(ctx, pattern, ptr, ppos):
-                    ptr += 1
+                    ptr = ctx.next(ptr)
                 else:
                     return ptr
     elif checkerfn == match_IN_IGNORE:
@@ -931,7 +1037,7 @@
                                                             end=end, ppos=ppos,
                                                             pattern=pattern)
                 if ptr < end and checkerfn(ctx, pattern, ptr, ppos):
-                    ptr += 1
+                    ptr = ctx.next(ptr)
                 else:
                     return ptr
     else:
@@ -940,7 +1046,7 @@
         @specializectx
         def fre(ctx, pattern, ptr, end, ppos):
             while ptr < end and checkerfn(ctx, pattern, ptr, ppos):
-                ptr += 1
+                ptr = ctx.next(ptr)
             return ptr
     fre = func_with_new_name(fre, 'fre_' + checkerfn.__name__)
     return fre
@@ -980,11 +1086,14 @@
 def sre_at(ctx, atcode, ptr):
     if (atcode == AT_BEGINNING or
         atcode == AT_BEGINNING_STRING):
-        return ptr == 0
+        return ptr == ctx.ZERO
 
     elif atcode == AT_BEGINNING_LINE:
-        prevptr = ptr - 1
-        return prevptr < 0 or rsre_char.is_linebreak(ctx.str(prevptr))
+        try:
+            prevptr = ctx.prev(ptr)
+        except EndOfString:
+            return True
+        return rsre_char.is_linebreak(ctx.str(prevptr))
 
     elif atcode == AT_BOUNDARY:
         return at_boundary(ctx, ptr)
@@ -993,9 +1102,8 @@
         return at_non_boundary(ctx, ptr)
 
     elif atcode == AT_END:
-        remaining_chars = ctx.end - ptr
-        return remaining_chars <= 0 or (
-            remaining_chars == 1 and rsre_char.is_linebreak(ctx.str(ptr)))
+        return (ptr == ctx.end or
+            (ctx.next(ptr) == ctx.end and rsre_char.is_linebreak(ctx.str(ptr))))
 
     elif atcode == AT_END_LINE:
         return ptr == ctx.end or rsre_char.is_linebreak(ctx.str(ptr))
@@ -1020,18 +1128,26 @@
 def _make_boundary(word_checker):
     @specializectx
     def at_boundary(ctx, ptr):
-        if ctx.end == 0:
+        if ctx.end == ctx.ZERO:
             return False
-        prevptr = ptr - 1
-        that = prevptr >= 0 and word_checker(ctx.str(prevptr))
+        try:
+            prevptr = ctx.prev(ptr)
+        except EndOfString:
+            that = False
+        else:
+            that = word_checker(ctx.str(prevptr))
         this = ptr < ctx.end and word_checker(ctx.str(ptr))
         return this != that
     @specializectx
     def at_non_boundary(ctx, ptr):
-        if ctx.end == 0:
+        if ctx.end == ctx.ZERO:
             return False
-        prevptr = ptr - 1
-        that = prevptr >= 0 and word_checker(ctx.str(prevptr))
+        try:
+            prevptr = ctx.prev(ptr)
+        except EndOfString:
+            that = False
+        else:
+            that = word_checker(ctx.str(prevptr))
         this = ptr < ctx.end and word_checker(ctx.str(ptr))
         return this == that
     return at_boundary, at_non_boundary
@@ -1109,13 +1225,15 @@
 
 def regular_search(ctx, pattern, base):
     start = ctx.match_start
-    while start <= ctx.end:
-        ctx.jitdriver_RegularSearch.jit_merge_point(ctx=ctx, start=start,
-                                                    base=base, pattern=pattern)
+    while True:
+        ctx.jitdriver_RegularSearch.jit_merge_point(ctx=ctx, pattern=pattern,
+                                                    start=start, base=base)
         if sre_match(ctx, pattern, base, start, None) is not None:
             ctx.match_start = start
             return True
-        start += 1
+        if start >= ctx.end:
+            break
+        start = ctx.next_indirect(start)
     return False
 
 install_jitdriver_spec("LiteralSearch",
@@ -1132,11 +1250,12 @@
     while start < ctx.end:
         ctx.jitdriver_LiteralSearch.jit_merge_point(ctx=ctx, start=start,
                                           base=base, character=character, pattern=pattern)
+        start1 = ctx.next(start)
         if ctx.str(start) == character:
-            if sre_match(ctx, pattern, base, start + 1, None) is not None:
+            if sre_match(ctx, pattern, base, start1, None) is not None:
                 ctx.match_start = start
                 return True
-        start += 1
+        start = start1
     return False
 
 install_jitdriver_spec("CharsetSearch",
@@ -1154,7 +1273,7 @@
             if sre_match(ctx, pattern, base, start, None) is not None:
                 ctx.match_start = start
                 return True
-        start += 1
+        start = ctx.next(start)
     return False
 
 install_jitdriver_spec('FastSearch',
@@ -1186,11 +1305,14 @@
         else:
             i += 1
             if i == prefix_len:
-                # found a potential match
-                start = string_position + 1 - prefix_len
-                assert start >= 0
+                # start = string_position + 1 - prefix_len: computed later
+                ptr = string_position
                 prefix_skip = pattern.pat(6)
-                ptr = start + prefix_skip
+                if prefix_skip == prefix_len:
+                    ptr = ctx.next(ptr)
+                else:
+                    assert prefix_skip < prefix_len
+                    ptr = ctx.prev_n(ptr, prefix_len-1 - prefix_skip, ctx.ZERO)
                 #flags = pattern.pat(2)
                 #if flags & rsre_char.SRE_INFO_LITERAL:
                 #    # matched all of pure literal pattern
@@ -1201,10 +1323,11 @@
                 pattern_offset = pattern.pat(1) + 1
                 ppos_start = pattern_offset + 2 * prefix_skip
                 if sre_match(ctx, pattern, ppos_start, ptr, None) is not None:
+                    start = ctx.prev_n(ptr, prefix_skip, ctx.ZERO)
                     ctx.match_start = start
                     return True
                 overlap_offset = prefix_len + (7 - 1)
                 i = pattern.pat(overlap_offset + i)
-        string_position += 1
+        string_position = ctx.next(string_position)
         if string_position >= ctx.end:
             return False
diff --git a/rpython/rlib/rsre/rsre_utf8.py b/rpython/rlib/rsre/rsre_utf8.py
--- a/rpython/rlib/rsre/rsre_utf8.py
+++ b/rpython/rlib/rsre/rsre_utf8.py
@@ -12,8 +12,8 @@
     by this class are expressed in *bytes*, not in characters.
     """
 
-    def __init__(self, pattern, utf8string, match_start, end, flags):
-        AbstractMatchContext.__init__(self, pattern, match_start, end, flags)
+    def __init__(self, utf8string, match_start, end, flags):
+        AbstractMatchContext.__init__(self, match_start, end, flags)
         self._utf8 = utf8string
 
     def str(self, index):
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to