Author: Armin Rigo <ar...@tunes.org>
Branch: unicode-utf8-re
Changeset: r93336:170afb57631b
Date: 2017-12-09 20:11 +0100
http://bitbucket.org/pypy/pypy/changeset/170afb57631b/

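Log:    Getting there (work in progress): continue porting _sre to the UTF-8 match context — remove XXX placeholders, replace the separate unicode builder with a single StringBuilder, and convert byte positions to character indexes in Match.span()/regs/pos/endpos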

diff --git a/pypy/module/_sre/interp_sre.py b/pypy/module/_sre/interp_sre.py
--- a/pypy/module/_sre/interp_sre.py
+++ b/pypy/module/_sre/interp_sre.py
@@ -44,9 +44,8 @@
             end = ctx._real_pos(end)
             return space.newbytes(ctx._string[start:end])
         elif isinstance(ctx, rsre_utf8.Utf8MatchContext):
-            XXXXXXX
-            s = ctx._unicodestr[start:end]
-            lgt = rutf8.check_utf8(s, True)
+            s = ctx._utf8[start:end]
+            lgt = rutf8.get_utf8_length(s)
             return space.newutf8(s, lgt)
         else:
             # unreachable
@@ -59,11 +58,11 @@
     # Returns a list of RPython-level integers.
     # Unlike the app-level groups() method, groups are numbered from 0
     # and the returned list does not start with the whole match range.
+    # The integers are byte positions, not character indexes (for utf8).
     if num_groups == 0:
         return None
     result = [-1] * (2 * num_groups)
     mark = ctx.match_marks
-    XXX
     while mark is not None:
         index = jit.promote(mark.gid)
         if result[index] == -1:
@@ -74,7 +73,6 @@
 
 @jit.look_inside_iff(lambda space, ctx, fmarks, num_groups, w_default: 
jit.isconstant(num_groups))
 def allgroups_w(space, ctx, fmarks, num_groups, w_default):
-    XXX
     grps = [slice_w(space, ctx, fmarks[i * 2], fmarks[i * 2 + 1], w_default)
             for i in range(num_groups)]
     return space.newtuple(grps)
@@ -117,12 +115,7 @@
         if endpos < pos:
             endpos = pos
         if space.isinstance_w(w_string, space.w_unicode):
-            # xxx fish for the _index_storage
-            w_string = space.convert_arg_to_w_unicode(w_string)
-            utf8str = w_string._utf8
-            length = w_string._len()
-            index_storage = w_string._get_index_storage()
-            #
+            utf8str, length = space.utf8_len_w(w_string)
             if pos <= 0:
                 bytepos = 0
             elif pos >= length:
@@ -135,8 +128,7 @@
                 endbytepos = rutf8.codepoint_at_index(utf8str, index_storage,
                                                       endpos)
             return rsre_utf8.Utf8MatchContext(
-                self.code, unicodestr, index_storage,
-                bytepos, endbytepos, self.flags)
+                self.code, utf8str, bytepos, endbytepos, self.flags)
         elif space.isinstance_w(w_string, space.w_bytes):
             str = space.bytes_w(w_string)
             if pos > len(str):
@@ -198,9 +190,10 @@
                     w_item = allgroups_w(space, ctx, fmarks, num_groups,
                                          w_emptystr)
             matchlist_w.append(w_item)
-            no_progress = (ctx.match_start == ctx.match_end)
-            ctx.reset(ctx.match_end + no_progress)
-            XXX #                   ^^^
+            reset_at = ctx.match_end
+            if ctx.match_start == ctx.match_end:
+                reset_at = ctx.next(reset_at)
+            ctx.reset(reset_at)
         return space.newlist(matchlist_w)
 
     @unwrap_spec(pos=int, endpos=int)
@@ -216,16 +209,15 @@
         space = self.space
         splitlist = []
         n = 0
-        last = 0
         ctx = self.make_ctx(w_string)
+        last = ctx.ZERO
         while not maxsplit or n < maxsplit:
             if not searchcontext(space, ctx):
                 break
             if ctx.match_start == ctx.match_end:     # zero-width match
                 if ctx.match_start == ctx.end:       # or end of string
                     break
-                ctx.reset(ctx.match_end + 1)
-                XXX   #                 ^^^
+                ctx.reset(ctx.next(ctx.match_end))
                 continue
             splitlist.append(slice_w(space, ctx, last, ctx.match_start,
                                      space.w_None))
@@ -254,20 +246,20 @@
 
     def subx(self, w_ptemplate, w_string, count):
         space = self.space
-        # use a (much faster) string/unicode builder if w_ptemplate and
+        # use a (much faster) string builder (possibly utf8) if w_ptemplate and
         # w_string are both string or both unicode objects, and if w_ptemplate
         # is a literal
-        use_builder = False
-        filter_as_unicode = filter_as_string = None
+        use_builder = '\x00'   # or 'S'tring or 'U'nicode/UTF8
+        filter_as_string = None
         if space.is_true(space.callable(w_ptemplate)):
             w_filter = w_ptemplate
             filter_is_callable = True
         else:
             if space.isinstance_w(w_ptemplate, space.w_unicode):
-                filter_as_unicode = space.utf8_w(w_ptemplate)
-                literal = '\\' not in filter_as_unicode
-                use_builder = (
-                    space.isinstance_w(w_string, space.w_unicode) and literal)
+                filter_as_string = space.utf8_w(w_ptemplate)
+                literal = '\\' not in filter_as_string
+                if space.isinstance_w(w_string, space.w_unicode) and literal:
+                    use_builder = 'U'
             else:
                 try:
                     filter_as_string = space.bytes_w(w_ptemplate)
@@ -277,8 +269,8 @@
                     literal = False
                 else:
                     literal = '\\' not in filter_as_string
-                    use_builder = (
-                        space.isinstance_w(w_string, space.w_bytes) and 
literal)
+                    if space.isinstance_w(w_string, space.w_bytes) and literal:
+                        use_builder = 'S'
             if literal:
                 w_filter = w_ptemplate
                 filter_is_callable = False
@@ -291,16 +283,14 @@
         #
         # XXX this is a bit of a mess, but it improves performance a lot
         ctx = self.make_ctx(w_string)
-        sublist_w = strbuilder = unicodebuilder = None
-        if use_builder:
-            if filter_as_unicode is not None:
-                unicodebuilder = XXX  #Utf8StringBuilder(ctx.end)
-            else:
-                assert filter_as_string is not None
-                strbuilder = StringBuilder(ctx.end)
+        sublist_w = strbuilder = None
+        if use_builder != '\x00':
+            assert filter_as_string is not None
+            strbuilder = StringBuilder(ctx.end)
         else:
             sublist_w = []
-        n = last_pos = 0
+        n = 0
+        last_pos = ctx.ZERO
         while not count or n < count:
             sub_jitdriver.jit_merge_point(
                 self=self,
@@ -310,9 +300,7 @@
                 ctx=ctx,
                 w_filter=w_filter,
                 strbuilder=strbuilder,
-                unicodebuilder=unicodebuilder,
                 filter_as_string=filter_as_string,
-                filter_as_unicode=filter_as_unicode,
                 count=count,
                 w_string=w_string,
                 n=n, last_pos=last_pos, sublist_w=sublist_w
@@ -323,10 +311,10 @@
             if last_pos < ctx.match_start:
                 _sub_append_slice(
                     ctx, space, use_builder, sublist_w,
-                    strbuilder, unicodebuilder, last_pos, ctx.match_start)
+                    strbuilder, last_pos, ctx.match_start)
             start = ctx.match_end
             if start == ctx.match_start:
-                start += 1
+                start = ctx.next(start)
             if not (last_pos == ctx.match_start
                              == ctx.match_end and n > 0):
                 # the above ignores empty matches on latest position
@@ -334,18 +322,13 @@
                     w_match = self.getmatch(ctx, True)
                     w_piece = space.call_function(w_filter, w_match)
                     if not space.is_w(w_piece, space.w_None):
-                        assert strbuilder is None and unicodebuilder is None
-                        assert not use_builder
+                        assert strbuilder is None
+                        assert use_builder == '\x00'
                         sublist_w.append(w_piece)
                 else:
-                    if use_builder:
-                        if strbuilder is not None:
-                            assert filter_as_string is not None
-                            strbuilder.append(filter_as_string)
-                        else:
-                            assert unicodebuilder is not None
-                            assert filter_as_unicode is not None
-                            unicodebuilder.append(filter_as_unicode)
+                    if use_builder != '\x00':
+                        assert filter_as_string is not None
+                        strbuilder.append(filter_as_string)
                     else:
                         sublist_w.append(w_filter)
                 last_pos = ctx.match_end
@@ -356,14 +339,16 @@
 
         if last_pos < ctx.end:
             _sub_append_slice(ctx, space, use_builder, sublist_w,
-                              strbuilder, unicodebuilder, last_pos, ctx.end)
-        if use_builder:
-            if strbuilder is not None:
-                return space.newbytes(strbuilder.build()), n
+                              strbuilder, last_pos, ctx.end)
+        if use_builder != '\x00':
+            result_bytes = strbuilder.build()
+            if use_builder == 'S':
+                return space.newbytes(result_bytes), n
+            elif use_builder == 'U':
+                return space.newutf8(result_bytes,
+                                     rutf8.get_utf8_length(result_bytes)), n
             else:
-                assert unicodebuilder is not None
-                return space.newutf8(unicodebuilder.build(),
-                                     unicodebuilder.get_length()), n
+                raise AssertionError(use_builder)
         else:
             if space.isinstance_w(w_string, space.w_unicode):
                 w_emptystr = space.newutf8('', 0)
@@ -376,27 +361,27 @@
 sub_jitdriver = jit.JitDriver(
     reds="""count n last_pos
             ctx w_filter
-            strbuilder unicodebuilder
+            strbuilder
             filter_as_string
-            filter_as_unicode
             w_string sublist_w
             self""".split(),
     greens=["filter_is_callable", "use_builder", "filter_type", "ctx.pattern"])
 
 
 def _sub_append_slice(ctx, space, use_builder, sublist_w,
-                      strbuilder, unicodebuilder, start, end):
-    if use_builder:
+                      strbuilder, start, end):
+    if use_builder != '\x00':
         if isinstance(ctx, rsre_core.BufMatchContext):
-            assert strbuilder is not None
+            assert use_builder == 'S'
             return strbuilder.append(ctx._buffer.getslice(start, end, 1, 
end-start))
         if isinstance(ctx, rsre_core.StrMatchContext):
-            assert strbuilder is not None
+            assert use_builder == 'S'
+            start = ctx._real_pos(start)
+            end = ctx._real_pos(end)
             return strbuilder.append_slice(ctx._string, start, end)
         elif isinstance(ctx, rsre_utf8.Utf8MatchContext):
-            XXXXXXX
-            assert unicodebuilder is not None
-            return unicodebuilder.append_slice(ctx._unicodestr, start, end)
+            assert use_builder == 'U'
+            return strbuilder.append_slice(ctx._utf8, start, end)
         assert 0, "unreachable"
     else:
         sublist_w.append(slice_w(space, ctx, start, end, space.w_None))
@@ -523,6 +508,9 @@
     @unwrap_spec(w_groupnum=WrappedDefault(0))
     def span_w(self, w_groupnum):
         start, end = self.do_span(w_groupnum)
+        return self.new_charindex_tuple(start, end)
+
+    def new_charindex_tuple(self, start, end):
         start = self.bytepos_to_charindex(start)
         end = self.bytepos_to_charindex(end)
         return self.space.newtuple([self.space.newint(start),
@@ -541,6 +529,8 @@
         return self.flatten_cache
 
     def do_span(self, w_arg):
+        # return a pair of integers, which are byte positions, not
+        # character indexes (for utf8)
         space = self.space
         try:
             groupnum = space.int_w(w_arg)
@@ -588,10 +578,10 @@
         return space.w_None
 
     def fget_pos(self, space):
-        return space.newint(self.ctx.original_pos)
+        return space.newint(self.bytepos_to_charindex(self.ctx.original_pos))
 
     def fget_endpos(self, space):
-        return space.newint(self.ctx.end)
+        return space.newint(self.bytepos_to_charindex(self.ctx.end))
 
     def fget_regs(self, space):
         space = self.space
@@ -599,11 +589,11 @@
         num_groups = self.srepat.num_groups
         result_w = [None] * (num_groups + 1)
         ctx = self.ctx
-        result_w[0] = space.newtuple([space.newint(ctx.match_start),
-                                      space.newint(ctx.match_end)])
+        result_w[0] = self.new_charindex_tuple(ctx.match_start,
+                                               ctx.match_end)
         for i in range(num_groups):
-            result_w[i + 1] = space.newtuple([space.newint(fmarks[i*2]),
-                                              space.newint(fmarks[i*2+1])])
+            result_w[i + 1] = self.new_charindex_tuple(fmarks[i*2],
+                                                       fmarks[i*2+1])
         return space.newtuple(result_w)
 
     def fget_string(self, space):
@@ -680,12 +670,14 @@
         if found:
             ctx = self.ctx
             nextstart = ctx.match_end
-            nextstart += (ctx.match_start == nextstart)
+            if ctx.match_start == nextstart:
+                nextstart = ctx.next(nextstart)
             self.ctx = ctx.fresh_copy(nextstart)
             match = W_SRE_Match(self.srepat, ctx)
             return match
         else:
-            self.ctx.match_start += 1     # obscure corner case
+            # obscure corner case
+            self.ctx.match_start = self.ctx.next(self.ctx.match_start)
             return None
 
 W_SRE_Scanner.typedef = TypeDef(
diff --git a/pypy/module/_sre/test/test_app_sre.py b/pypy/module/_sre/test/test_app_sre.py
--- a/pypy/module/_sre/test/test_app_sre.py
+++ b/pypy/module/_sre/test/test_app_sre.py
@@ -33,7 +33,9 @@
     return support.MatchContextForTests(self.code, str, start, end, self.flags)
 
 def _bytepos_to_charindex(self, bytepos):
-    return self.ctx._real_pos(bytepos)
+    if isinstance(self.ctx, support.MatchContextForTests):
+        return self.ctx._real_pos(bytepos)
+    return bytepos
 
 def setup_module(mod):
     mod._org_maker = (
diff --git a/rpython/rlib/rsre/rsre_core.py b/rpython/rlib/rsre/rsre_core.py
--- a/rpython/rlib/rsre/rsre_core.py
+++ b/rpython/rlib/rsre/rsre_core.py
@@ -165,14 +165,13 @@
     def maximum_distance(self, position_low, position_high):
         raise NotImplementedError
     @not_rpython
-    def bytes_difference(self, position1, position2):
-        raise NotImplementedError
-    @not_rpython
     def get_single_byte(self, base_position, index):
         raise NotImplementedError
-    @not_rpython
+
+    def bytes_difference(self, position1, position2):
+        return position1 - position2
     def go_forward_by_bytes(self, base_position, index):
-        raise NotImplementedError
+        return base_position + index
 
     def get_mark(self, gid):
         return find_mark(self.match_marks, gid)
@@ -243,12 +242,6 @@
     def maximum_distance(self, position_low, position_high):
         return position_high - position_low
 
-    def bytes_difference(self, position1, position2):
-        return position1 - position2
-
-    def go_forward_by_bytes(self, base_position, index):
-        return base_position + index
-
 
 class BufMatchContext(FixedMatchContext):
     """Concrete subclass for matching in a buffer."""
diff --git a/rpython/rlib/rsre/test/support.py b/rpython/rlib/rsre/test/support.py
--- a/rpython/rlib/rsre/test/support.py
+++ b/rpython/rlib/rsre/test/support.py
@@ -104,6 +104,10 @@
         assert isinstance(index, int)
         return Position(base_position._p + index)
 
+    def fresh_copy(self, start):
+        return MatchContextForTests(self.pattern, self._string, start,
+                                    self.end, self.flags)
+
 
 def match(pattern, string, start=0, end=sys.maxint, flags=0, fullmatch=False):
     start, end = _adjust(start, end, len(string))
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to