Author: Ronan Lamy <[email protected]>
Branch: unicode-utf8
Changeset: r93362:09186de461ba
Date: 2017-12-11 13:02 +0000
http://bitbucket.org/pypy/pypy/changeset/09186de461ba/

Log:    Use the same logic for all encoders

diff --git a/pypy/interpreter/unicodehelper.py b/pypy/interpreter/unicodehelper.py
--- a/pypy/interpreter/unicodehelper.py
+++ b/pypy/interpreter/unicodehelper.py
@@ -159,74 +159,67 @@
         return _utf8_encode_latin_1_slowpath(s, errors, errorhandler)
 
 def _utf8_encode_latin_1_slowpath(s, errors, errorhandler):
-    res = StringBuilder(len(s))
-    cur = 0
-    iter = rutf8.Utf8StringIterator(s)
-    while True:
-        try:
-            ch = iter.next()
-        except StopIteration:
-            break
+    size = len(s)
+    result = StringBuilder(size)
+    index = 0
+    pos = 0
+    while pos < size:
+        ch = rutf8.codepoint_at_pos(s, pos)
         if ch <= 0xFF:
-            res.append(chr(ch))
-            cur += 1
+            result.append(chr(ch))
+            index += 1
+            pos = rutf8.next_codepoint_pos(s, pos)
         else:
-            r, pos = errorhandler(errors, 'latin1',
-                                  'ordinal not in range(256)', s, cur,
-                                  cur + 1)
+            startindex = index
+            pos = rutf8.next_codepoint_pos(s, pos)
+            index += 1
+            while pos < size and rutf8.codepoint_at_pos(s, pos) > 0xFF:
+                pos = rutf8.next_codepoint_pos(s, pos)
+                index += 1
+            msg = "ordinal not in range(256)"
+            res_8, newindex = errorhandler(
+                errors, 'latin1', msg, s, startindex, index)
+            for cp in rutf8.Utf8StringIterator(res_8):
+                if cp > 0xFF:
+                    errorhandler("strict", 'latin1', msg, s, startindex, index)
+                result.append(chr(cp))
+            if index != newindex:  # Should be uncommon
+                index = newindex
+                pos = rutf8._pos_at_index(s, newindex)
+    return result.build()
 
-            for c in rutf8.Utf8StringIterator(r):
-                if c > 0xFF:
-                    errorhandler("strict", 'latin1',
-                                 'ordinal not in range(256)', s,
-                                 cur, cur + 1)
-                res.append(chr(c))
-
-            for j in range(pos - cur - 1):
-                iter.next()
-
-            cur = pos
-    r = res.build()
-    return r
-
-def utf8_encode_ascii(utf8, errors, errorhandler):
+def utf8_encode_ascii(s, errors, errorhandler):
     """ Don't be confused - this is a slowpath for errors e.g. "ignore"
     or an obscure errorhandler
     """
-    res = StringBuilder()
-    i = 0
+    size = len(s)
+    result = StringBuilder(size)
+    index = 0
     pos = 0
-    while i < len(utf8):
-        ch = rutf8.codepoint_at_pos(utf8, i)
-        if ch > 0x7F:
-            endpos = pos + 1
-            end_i = rutf8.next_codepoint_pos(utf8, i)
-            while end_i < len(utf8) and rutf8.codepoint_at_pos(utf8, end_i) > 0x7F:
-                endpos += 1
-                end_i = rutf8.next_codepoint_pos(utf8, end_i)
+    while pos < size:
+        ch = rutf8.codepoint_at_pos(s, pos)
+        if ch <= 0x7F:
+            result.append(chr(ch))
+            index += 1
+            pos = rutf8.next_codepoint_pos(s, pos)
+        else:
+            startindex = index
+            pos = rutf8.next_codepoint_pos(s, pos)
+            index += 1
+            while pos < size and rutf8.codepoint_at_pos(s, pos) > 0x7F:
+                pos = rutf8.next_codepoint_pos(s, pos)
+                index += 1
             msg = "ordinal not in range(128)"
-            r, newpos = errorhandler(errors, 'ascii', msg, utf8,
-                pos, endpos)
-            for j in range(newpos - pos):
-                i = rutf8.next_codepoint_pos(utf8, i)
-
-            j = 0
-            while j < len(r):
-                c = rutf8.codepoint_at_pos(r, j)
-                if c > 0x7F:
-                    errorhandler("strict", 'ascii',
-                                 'ordinal not in range(128)', utf8,
-                                 pos, pos + 1)
-                j = rutf8.next_codepoint_pos(r, j)
-            pos = newpos
-            res.append(r)
-        else:
-            res.append(chr(ch))
-            i = rutf8.next_codepoint_pos(utf8, i)
-            pos += 1
-
-    s = res.build()
-    return s
+            res_8, newindex = errorhandler(
+                errors, 'ascii', msg, s, startindex, index)
+            for cp in rutf8.Utf8StringIterator(res_8):
+                if cp > 0x7F:
+                    errorhandler("strict", 'ascii', msg, s, startindex, index)
+                result.append(chr(cp))
+            if index != newindex:  # Should be uncommon
+                index = newindex
+                pos = rutf8._pos_at_index(s, newindex)
+    return result.build()
 
 def str_decode_utf8(s, errors, final, errorhandler):
     """ Same as checking for the valid utf8, but we know the utf8 is not
diff --git a/pypy/module/_codecs/test/test_codecs.py b/pypy/module/_codecs/test/test_codecs.py
--- a/pypy/module/_codecs/test/test_codecs.py
+++ b/pypy/module/_codecs/test/test_codecs.py
@@ -760,3 +760,25 @@
         assert r == '&#4660;\x80&#9029;y\xab'
         r = u'\u1234\u0080\u2345\u0079\u00AB'.encode('ascii', 'xmlcharrefreplace')
         assert r == '&#4660;&#128;&#9029;y&#171;'
+
+    def test_errorhandler_collection(self):
+        import _codecs
+        errors = []
+        def record_error(exc):
+            if not isinstance(exc, UnicodeEncodeError):
+                raise TypeError("don't know how to handle %r" % exc)
+            errors.append(exc.object[exc.start:exc.end])
+            return (u'', exc.end)
+        _codecs.register_error("test.record", record_error)
+
+        sin = u"\xac\u1234\u1234\u20ac\u8000"
+        assert sin.encode("ascii", "test.record") == ""
+        assert errors == [sin]
+
+        errors = []
+        assert sin.encode("latin-1", "test.record") == "\xac"
+        assert errors == [u'\u1234\u1234\u20ac\u8000']
+
+        errors = []
+        assert sin.encode("iso-8859-15", "test.record") == "\xac\xa4"
+        assert errors == [u'\u1234\u1234', u'\u8000']
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to