Author: Matti Picus <matti.pi...@gmail.com>
Branch: unicode-utf8-py3
Changeset: r95339:9dc3de0b34d5
Date: 2018-11-18 18:36 -0800
http://bitbucket.org/pypy/pypy/changeset/9dc3de0b34d5/

Log:    distinguish between errorhandler returning unicode and bytes

diff --git a/pypy/interpreter/unicodehelper.py b/pypy/interpreter/unicodehelper.py
--- a/pypy/interpreter/unicodehelper.py
+++ b/pypy/interpreter/unicodehelper.py
@@ -27,7 +27,7 @@
 def decode_never_raise(errors, encoding, msg, s, startingpos, endingpos):
     assert startingpos >= 0
     ux = ['\ux' + hex(ord(x))[2:].upper() for x in s[startingpos:endingpos]]
-    return ''.join(ux), endingpos
+    return ''.join(ux), endingpos, 'b'
 
 @specialize.memo()
 def encode_error_handler(space):
@@ -199,7 +199,7 @@
     while i < len(s):
         ch = s[i]
         if ord(ch) > 0x7F:
-            r, i = errorhandler(errors, 'ascii', 'ordinal not in range(128)',
+            r, i, rettype = errorhandler(errors, 'ascii', 'ordinal not in range(128)',
                 s, i, i + 1)
             res.append(r)
         else:
@@ -242,7 +242,7 @@
         assert pos >= 0
         start = s[:pos]
         upos = rutf8.codepoints_in_utf8(s, end=pos)
-        ru, lgt = errorhandler(errors, 'utf8',
+        ru, lgt, rettype = errorhandler(errors, 'utf8',
                     'surrogates not allowed', s, upos, upos + 1)
         end = utf8_encode_utf_8(s[pos+3:], errors, errorhandler,
                                 allow_surrogates=allow_surrogates)
@@ -275,13 +275,20 @@
                 pos = rutf8.next_codepoint_pos(s, pos)
                 index += 1
             msg = "ordinal not in range(256)"
-            res_8, newindex = errorhandler(
+            res, newindex, rettype = errorhandler(
                 errors, 'latin1', msg, s, startindex, index)
-            for ch in res_8:
-                if ord(ch) > 0xFF:
-                    errorhandler("strict", 'latin1', msg, s, startindex, index)
-                    raise RuntimeError('error handler should not have returned')
-                result.append(ch)
+            if rettype == 'u':
+                for cp in rutf8.Utf8StringIterator(res):
+                    if cp > 0xFF:
+                        errorhandler("strict", 'latin1', msg, s, startindex, index)
+                        raise RuntimeError('error handler should not have returned')
+                    result.append(chr(cp))
+            else:
+                for ch in res:
+                    if ord(ch) > 0xFF:
+                        errorhandler("strict", 'latin1', msg, s, startindex, index)
+                        raise RuntimeError('error handler should not have returned')
+                    result.append(ch)
             if index != newindex:  # Should be uncommon
                 index = newindex
                 pos = rutf8._pos_at_index(s, newindex)
@@ -309,9 +316,20 @@
                 pos = rutf8.next_codepoint_pos(s, pos)
                 index += 1
             msg = "ordinal not in range(128)"
-            res_8, newindex = errorhandler(
+            res, newindex, rettype = errorhandler(
                 errors, 'ascii', msg, s, startindex, index)
-            result.append(res_8)
+            if rettype == 'u':
+                for cp in rutf8.Utf8StringIterator(res):
+                    if cp > 0x80:
+                        errorhandler("strict", 'ascii', msg, s, startindex, index)
+                        raise RuntimeError('error handler should not have returned')
+                    result.append(chr(cp))
+            else:
+                for ch in res:
+                    if ord(ch) > 0x80:
+                        errorhandler("strict", 'ascii', msg, s, startindex, index)
+                        raise RuntimeError('error handler should not have returned')
+                    result.append(ch)
             pos = rutf8._pos_at_index(s, newindex)
     return result.build()
 
@@ -346,7 +364,7 @@
             continue
 
         if ordch1 <= 0xC1:
-            r, pos = errorhandler(errors, "utf8", "invalid start byte",
+            r, pos, rettype = errorhandler(errors, "utf8", "invalid start byte",
                     s, pos, pos + 1)
             res.append(r)
             continue
@@ -358,14 +376,14 @@
                 if not final:
                     pos -= 1
                     break
-                r, pos = errorhandler(errors, "utf8", "unexpected end of data",
+                r, pos, rettype = errorhandler(errors, "utf8", "unexpected end of data",
                     s, pos - 1, pos)
                 res.append(r)
                 continue
             ordch2 = ord(s[pos])
 
             if rutf8._invalid_byte_2_of_2(ordch2):
-                r, pos = errorhandler(errors, "utf8", "invalid continuation byte",
+                r, pos, rettype = errorhandler(errors, "utf8", "invalid continuation byte",
                     s, pos - 1, pos)
                 res.append(r)
                 continue
@@ -380,7 +398,7 @@
                 if not final:
                     pos -= 1
                     break
-                r, pos = errorhandler(errors, "utf8", "unexpected end of data",
+                r, pos, rettype = errorhandler(errors, "utf8", "unexpected end of data",
                     s, pos - 1, pos + 1)
                 res.append(r)
                 continue
@@ -388,12 +406,12 @@
             ordch3 = ord(s[pos + 1])
 
             if rutf8._invalid_byte_2_of_3(ordch1, ordch2, allow_surrogates):
-                r, pos = errorhandler(errors, "utf8", "invalid continuation byte",
+                r, pos, rettype = errorhandler(errors, "utf8", "invalid continuation byte",
                     s, pos - 1, pos)
                 res.append(r)
                 continue
             elif rutf8._invalid_byte_3_of_3(ordch3):
-                r, pos = errorhandler(errors, "utf8", "invalid continuation byte",
+                r, pos, rettype = errorhandler(errors, "utf8", "invalid continuation byte",
                     s, pos - 1, pos + 1)
                 res.append(r)
                 continue
@@ -410,26 +428,25 @@
                 if not final:
                     pos -= 1
                     break
-                r, pos = errorhandler(errors, "utf8", "unexpected end of data",
+                r, pos, rettype = errorhandler(errors, "utf8", "unexpected end of data",
                     s, pos - 1, pos)
-                res.append(r)
                 continue
             ordch2 = ord(s[pos])
             ordch3 = ord(s[pos + 1])
             ordch4 = ord(s[pos + 2])
 
             if rutf8._invalid_byte_2_of_4(ordch1, ordch2):
-                r, pos = errorhandler(errors, "utf8", "invalid continuation byte",
+                r, pos, rettype = errorhandler(errors, "utf8", "invalid continuation byte",
                     s, pos - 1, pos)
                 res.append(r)
                 continue
             elif rutf8._invalid_byte_3_of_4(ordch3):
-                r, pos = errorhandler(errors, "utf8", "invalid continuation byte",
+                r, pos, rettype = errorhandler(errors, "utf8", "invalid continuation byte",
                     s, pos - 1, pos + 1)
                 res.append(r)
                 continue
             elif rutf8._invalid_byte_4_of_4(ordch4):
-                r, pos = errorhandler(errors, "utf8", "invalid continuation byte",
+                r, pos, rettype = errorhandler(errors, "utf8", "invalid continuation byte",
                     s, pos - 1, pos + 2)
                 res.append(r)
                 continue
@@ -442,7 +459,7 @@
             res.append(chr(ordch4))
             continue
 
-        r, pos = errorhandler(errors, "utf8", "invalid start byte",
+        r, pos, rettype = errorhandler(errors, "utf8", "invalid start byte",
                 s, pos - 1, pos)
         res.append(r)
 
@@ -458,9 +475,9 @@
         endinpos = pos
         while endinpos < len(s) and s[endinpos] in hexdigits:
             endinpos += 1
-        res, pos = errorhandler(
+        r, pos, rettype = errorhandler(
             errors, encoding, message, s, pos - 2, endinpos)
-        builder.append(res)
+        builder.append(r)
     else:
         try:
             chr = int(s[pos:pos + digits], 16)
@@ -468,9 +485,9 @@
             endinpos = pos
             while s[endinpos] in hexdigits:
                 endinpos += 1
-            res, pos = errorhandler(
+            r, pos, rettype = errorhandler(
                 errors, encoding, message, s, pos - 2, endinpos)
-            builder.append(res)
+            builder.append(r)
         else:
             # when we get here, chr is a 32-bit unicode character
             try:
@@ -478,9 +495,9 @@
                 pos += digits
             except ValueError:
                 message = "illegal Unicode character"
-                res, pos = errorhandler(
+                r, pos, rettype = errorhandler(
                     errors, encoding, message, s, pos - 2, pos + digits)
-                builder.append(res)
+                builder.append(r)
     return pos
 
 def str_decode_unicode_escape(s, errors, final, errorhandler, ud_handler):
@@ -506,9 +523,9 @@
         pos += 1
         if pos >= size:
             message = "\\ at end of string"
-            res, pos = errorhandler(errors, "unicodeescape",
+            r, pos, rettype = errorhandler(errors, "unicodeescape",
                                     message, s, pos - 1, size)
-            builder.append(res)
+            builder.append(r)
             continue
 
         ch = s[pos]
@@ -586,21 +603,21 @@
                     name = s[pos + 1:look]
                     code = ud_handler.call(name)
                     if code < 0:
-                        res, pos = errorhandler(
+                        r, pos, rettype = errorhandler(
                             errors, "unicodeescape", message,
                             s, pos - 1, look + 1)
-                        builder.append(res)
+                        builder.append(r)
                         continue
                     pos = look + 1
                     builder.append_code(code)
                 else:
-                    res, pos = errorhandler(errors, "unicodeescape",
+                    r, pos, rettype = errorhandler(errors, "unicodeescape",
                                             message, s, pos - 1, look + 1)
-                    builder.append(res)
+                    builder.append(r)
             else:
-                res, pos = errorhandler(errors, "unicodeescape",
+                r, pos, rettype = errorhandler(errors, "unicodeescape",
                                         message, s, pos - 1, look + 1)
-                builder.append(res)
+                builder.append(r)
         else:
             builder.append_char('\\')
             builder.append_code(ord(ch))
@@ -867,22 +884,22 @@
                         # We've seen at least one base-64 character
                         pos += 1
                         msg = "partial character in shift sequence"
-                        res, pos = errorhandler(errors, 'utf7',
+                        r, pos, rettype = errorhandler(errors, 'utf7',
                                                 msg, s, pos-1, pos)
-                        reslen = rutf8.check_utf8(res, True)
+                        reslen = rutf8.check_utf8(r, True)
                         outsize += reslen
-                        result.append(res)
+                        result.append(r)
                         continue
                     else:
                         # Some bits remain; they should be zero
                         if base64buffer != 0:
                             pos += 1
                             msg = "non-zero padding bits in shift sequence"
-                            res, pos = errorhandler(errors, 'utf7',
+                            r, pos, rettype = errorhandler(errors, 'utf7',
                                                     msg, s, pos-1, pos)
-                            reslen = rutf8.check_utf8(res, True)
+                            reslen = rutf8.check_utf8(r, True)
                             outsize += reslen
-                            result.append(res)
+                            result.append(r)
                             continue
 
                 if surrogate and _utf7_DECODE_DIRECT(ord(ch)):
@@ -917,10 +934,10 @@
             startinpos = pos
             pos += 1
             msg = "unexpected special character"
-            res, pos = errorhandler(errors, 'utf7', msg, s, pos-1, pos)
-            reslen = rutf8.check_utf8(res, True)
+            r, pos, rettype = errorhandler(errors, 'utf7', msg, s, pos-1, pos)
+            reslen = rutf8.check_utf8(r, True)
             outsize += reslen
-            result.append(res)
+            result.append(r)
 
     # end of string
     final_length = result.getlength()
@@ -931,10 +948,10 @@
             base64bits >= 6 or
             (base64bits > 0 and base64buffer != 0)):
             msg = "unterminated shift sequence"
-            res, pos = errorhandler(errors, 'utf7', msg, s, shiftOutStartPos, pos)
-            reslen = rutf8.check_utf8(res, True)
+            r, pos, rettype = errorhandler(errors, 'utf7', msg, s, shiftOutStartPos, pos)
+            reslen = rutf8.check_utf8(r, True)
             outsize += reslen
-            result.append(res)
+            result.append(r)
             final_length = result.getlength()
     elif inShift:
         pos = startinpos
@@ -1101,7 +1118,7 @@
         if len(s) - pos < 2:
             if not final:
                 break
-            r, pos = errorhandler(errors, public_encoding_name,
+            r, pos, rettype = errorhandler(errors, public_encoding_name,
                                   "truncated data",
                                   s, pos, len(s))
             result.append(r)
@@ -1118,7 +1135,7 @@
             if not final:
                 break
             errmsg = "unexpected end of data"
-            r, pos = errorhandler(errors, public_encoding_name,
+            r, pos, rettype = errorhandler(errors, public_encoding_name,
                                   errmsg, s, pos, len(s))
             result.append(r)
             if len(s) - pos < 2:
@@ -1131,12 +1148,12 @@
                 rutf8.unichr_as_utf8_append(result, ch)
                 continue
             else:
-                r, pos = errorhandler(errors, public_encoding_name,
+                r, pos, rettype = errorhandler(errors, public_encoding_name,
                                       "illegal UTF-16 surrogate",
                                       s, pos - 4, pos - 2)
                 result.append(r)
         else:
-            r, pos = errorhandler(errors, public_encoding_name,
+            r, pos, rettype = errorhandler(errors, public_encoding_name,
                                   "illegal encoding",
                                   s, pos - 2, pos)
             result.append(r)
@@ -1176,44 +1193,62 @@
     index = 0
     while pos < size:
         try:
-            ch = rutf8.codepoint_at_pos(s, pos)
+            cp = rutf8.codepoint_at_pos(s, pos)
         except IndexError:
             # malformed codepoint, blindly use ch
-            ch = ord(s[pos])
             pos += 1
             if errorhandler:
-                res_8, newindex = errorhandler(
+                r, newindex, rettype = errorhandler(
                     errors, public_encoding_name, 'malformed unicode',
                     s, pos - 1, pos)
-                for cp in rutf8.Utf8StringIterator(res_8):
-                    if cp < 0xD800:
+                if rettype == 'u':
+                    for cp in rutf8.Utf8StringIterator(r):
+                        if cp < 0xD800:
+                            _STORECHAR(result, cp, byteorder)
+                        else:
+                            errorhandler('strict', public_encoding_name,
+                                         'malformed unicode',
+                                     s, pos-1, pos)
+                else:
+                    for ch in r:
+                        cp = ord(ch)
+                        if cp < 0xD800:
+                            _STORECHAR(result, cp, byteorder)
+                        else:
+                            errorhandler('strict', public_encoding_name,
+                                         'malformed unicode',
+                                     s, pos-1, pos)
+            else:
+                cp = ord(s[pos])
+                _STORECHAR(result, cp, byteorder)
+            continue
+        if cp < 0xD800:
+            _STORECHAR(result, cp, byteorder)
+        elif cp >= 0x10000:
+            _STORECHAR(result, 0xD800 | ((cp-0x10000) >> 10), byteorder)
+            _STORECHAR(result, 0xDC00 | ((cp-0x10000) & 0x3FF), byteorder)
+        elif cp >= 0xE000 or allow_surrogates:
+            _STORECHAR(result, cp, byteorder)
+        else:
+            r, newindex, rettype = errorhandler(
+                errors, public_encoding_name, 'surrogates not allowed',
+                s, pos, pos+1)
+            if rettype == 'u':
+                for cp in rutf8.Utf8StringIterator(r):
+                    if cp < 0xD800 or allow_surrogates:
                         _STORECHAR(result, cp, byteorder)
                     else:
                         errorhandler('strict', public_encoding_name,
-                                     'malformed unicode',
-                                 s, pos-1, pos)
+                                     'surrogates not allowed',
+                                     s, pos, pos+1)
             else:
-                _STORECHAR(result, ch, byteorder)
-                continue
-        if ch < 0xD800:
-            _STORECHAR(result, ch, byteorder)
-        elif ch >= 0x10000:
-            _STORECHAR(result, 0xD800 | ((ch-0x10000) >> 10), byteorder)
-            _STORECHAR(result, 0xDC00 | ((ch-0x10000) & 0x3FF), byteorder)
-        elif ch >= 0xE000 or allow_surrogates:
-            _STORECHAR(result, ch, byteorder)
-        else:
-            res_8, newindex = errorhandler(
-                errors, public_encoding_name, 'surrogates not allowed',
-                s, pos, pos+1)
-            #for cp in rutf8.Utf8StringIterator(res_8):
-            for ch in res_8:
-                cp = ord(ch)
-                if cp < 0xD800 or allow_surrogates:
-                    _STORECHAR(result, cp, byteorder)
-                else:
-                    errorhandler('strict', public_encoding_name,
-                                 'surrogates not allowed',
+                for ch in r:
+                    cp = ord(ch)
+                    if cp < 0xD800 or allow_surrogates:
+                        _STORECHAR(result, cp, byteorder)
+                    else:
+                        errorhandler('strict', public_encoding_name,
+                                     'surrogates not allowed',
                                  s, pos, pos+1)
             if index != newindex:  # Should be uncommon
                 index = newindex
@@ -1329,7 +1364,7 @@
         if len(s) - pos < 4:
             if not final:
                 break
-            r, pos = errorhandler(errors, public_encoding_name,
+            r, pos, rettype = errorhandler(errors, public_encoding_name,
                                   "truncated data",
                                   s, pos, len(s))
             result.append(r)
@@ -1339,14 +1374,14 @@
         ch = ((ord(s[pos + iorder[3]]) << 24) | (ord(s[pos + iorder[2]]) << 16) |
               (ord(s[pos + iorder[1]]) << 8)  | ord(s[pos + iorder[0]]))
         if not allow_surrogates and 0xD800 <= ch <= 0xDFFF:
-            r, pos = errorhandler(errors, public_encoding_name,
+            r, pos, rettype = errorhandler(errors, public_encoding_name,
                                   "code point in surrogate code point "
                                   "range(0xd800, 0xe000)",
                                   s, pos, pos + 4)
             result.append(r)
             continue
         elif ch >= 0x110000:
-            r, pos = errorhandler(errors, public_encoding_name,
+            r, pos, rettype = errorhandler(errors, public_encoding_name,
                                   "codepoint not in range(0x110000)",
                                   s, pos, len(s))
             result.append(r)
@@ -1404,11 +1439,20 @@
             ch = ord(s[pos])
             pos += 1
             if errorhandler:
-                res_8, newindex = errorhandler(
+                r, newindex, rettype = errorhandler(
                     errors, public_encoding_name, 'malformed unicode',
                     s, index, index+1)
-                if res_8:
-                    for cp in rutf8.Utf8StringIterator(res_8):
+                if rettype == 'u' and r:
+                    for cp in rutf8.Utf8StringIterator(r):
+                        if cp < 0xD800:
+                            _STORECHAR32(result, cp, byteorder)
+                        else:
+                            errorhandler('strict', public_encoding_name,
+                                     'malformed unicode',
+                                 s, index, index+1)
+                elif r:
+                    for ch in r:
+                        cp = ord(ch)
                         if cp < 0xD800:
                             _STORECHAR32(result, cp, byteorder)
                         else:
@@ -1422,16 +1466,26 @@
             index += 1
             continue
         if not allow_surrogates and 0xD800 <= ch < 0xE000:
-            res_8, newindex = errorhandler(
+            r, newindex, rettype = errorhandler(
                 errors, public_encoding_name, 'surrogates not allowed',
                 s, index, index+1)
-            for ch in rutf8.Utf8StringIterator(res_8):
-                if ch < 0xD800:
-                    _STORECHAR32(result, ch, byteorder)
-                else:
-                    errorhandler(
-                        'strict', public_encoding_name, 'surrogates not allowed',
-                        s, index, index+1)
+            if rettype == 'u':
+                for ch in rutf8.Utf8StringIterator(r):
+                    if ch < 0xD800:
+                        _STORECHAR32(result, ch, byteorder)
+                    else:
+                        errorhandler(
+                            'strict', public_encoding_name, 'surrogates not allowed',
+                            s, index, index+1)
+            else:
+                for ch in r:
+                    cp = ord(ch)
+                    if cp < 0xD800:
+                        _STORECHAR32(result, cp, byteorder)
+                    else:
+                        errorhandler(
+                            'strict', public_encoding_name, 'surrogates not allowed',
+                            s, index, index+1)
             if index != newindex:  # Should be uncommon
                 index = newindex
                 pos = rutf8._pos_at_index(s, newindex)
@@ -1471,11 +1525,20 @@
             ch = ord(s[pos])
             pos += 1
             if errorhandler:
-                res_8, newindex = errorhandler(
+                r, newindex, rettype = errorhandler(
                     errors, public_encoding_name, 'malformed unicode',
                     s, pos - 1, pos)
-                if res_8:
-                    for cp in rutf8.Utf8StringIterator(res_8):
+                if rettype == 'u' and r:
+                    for cp in rutf8.Utf8StringIterator(r):
+                        if cp < 0xD800:
+                            _STORECHAR32(result, cp, byteorder)
+                        else:
+                            errorhandler('strict', public_encoding_name,
+                                     'malformed unicode',
+                                 s, pos-1, pos)
+                elif r:
+                    for ch in r:
+                        cp = ord(ch)
                         if cp < 0xD800:
                             _STORECHAR32(result, cp, byteorder)
                         else:
@@ -1489,16 +1552,26 @@
             index += 1
             continue
         if not allow_surrogates and 0xD800 <= ch < 0xE000:
-            res_8, newindex = errorhandler(
+            r, newindex, rettype = errorhandler(
                 errors, public_encoding_name, 'surrogates not allowed',
                 s, pos - 1, pos)
-            for ch in rutf8.Utf8StringIterator(res_8):
-                if ch < 0xD800:
-                    _STORECHAR32(result, ch, byteorder)
-                else:
-                    errorhandler(
-                        'strict', public_encoding_name, 'surrogates not allowed',
-                        s, pos - 1, pos)
+            if rettype == 'u':
+                for ch in rutf8.Utf8StringIterator(res_8):
+                    if ch < 0xD800:
+                        _STORECHAR32(result, ch, byteorder)
+                    else:
+                        errorhandler(
+                            'strict', public_encoding_name, 'surrogates not allowed',
+                            s, pos - 1, pos)
+            else:
+                for ch in res_8:
+                    cp = ord(ch)
+                    if cp < 0xD800:
+                        _STORECHAR32(result, cp, byteorder)
+                    else:
+                        errorhandler(
+                            'strict', public_encoding_name, 'surrogates not allowed',
+                            s, pos - 1, pos)
             if index != newindex:  # Should be uncommon
                 index = newindex
                 pos = rutf8._pos_at_index(s, newindex)
@@ -1551,10 +1624,10 @@
     pos = 0
     while pos < size:
         if pos > size - unicode_bytes:
-            res, pos = errorhandler(errors, "unicode_internal",
+            r, pos, rettype = errorhandler(errors, "unicode_internal",
                                     "truncated input",
                                     s, pos, size)
-            result.append(res)
+            result.append(r)
             continue
         t = r_uint(0)
         h = 0
@@ -1562,10 +1635,10 @@
             t += r_uint(ord(s[pos + j])) << (h*8)
             h += 1
         if t > runicode.MAXUNICODE:
-            res, pos = errorhandler(errors, "unicode_internal",
+            r, pos, rettype = errorhandler(errors, "unicode_internal",
                                     "unichr(%d) not in range" % (t,),
                                     s, pos, pos + unicode_bytes)
-            result.append(res)
+            result.append(r)
             continue
         rutf8.unichr_as_utf8_append(result, intmask(t), allow_surrogates=True)
         pos += unicode_bytes
@@ -1627,7 +1700,7 @@
 
         c = mapping.get(ord(ch), ERROR_CHAR)
         if c == ERROR_CHAR:
-            r, pos = errorhandler(errors, "charmap",
+            r, pos, rettype = errorhandler(errors, "charmap",
                                   "character maps to <undefined>",
                                   s,  pos, pos + 1)
             result.append(r)
@@ -1659,10 +1732,10 @@
                    mapping.get(rutf8.codepoint_at_pos(s, pos), '') == ''):
                 pos = rutf8.next_codepoint_pos(s, pos)
                 index += 1
-            res_8, newindex = errorhandler(errors, "charmap",
+            r, newindex, rettype = errorhandler(errors, "charmap",
                                    "character maps to <undefined>",
                                    s, startindex, index)
-            for cp2 in rutf8.Utf8StringIterator(res_8):
+            for cp2 in rutf8.Utf8StringIterator(r):
                 ch2 = mapping.get(cp2, '')
                 if not ch2:
                     errorhandler(
@@ -1727,7 +1800,7 @@
             i += 1
         end_index = i
         msg = "invalid decimal Unicode string"
-        r, pos = errorhandler(
+        r, pos, retype = errorhandler(
             errors, 'decimal', msg, s, start_index, end_index)
         for ch in rutf8.Utf8StringIterator(r):
             if unicodedb.isspace(ch):
diff --git a/pypy/module/_codecs/interp_codecs.py b/pypy/module/_codecs/interp_codecs.py
--- a/pypy/module/_codecs/interp_codecs.py
+++ b/pypy/module/_codecs/interp_codecs.py
@@ -72,8 +72,11 @@
                 raise OperationError(space.w_TypeError, space.newtext(msg))
 
             w_replace, w_newpos = space.fixedview(w_res, 2)
-            if not (space.isinstance_w(w_replace, space.w_unicode) or
-                (not decode and space.isinstance_w(w_replace, space.w_bytes))):
+            if space.isinstance_w(w_replace, space.w_unicode):
+                rettype = 'u'
+            elif encode and space.isinstance_w(w_replace, space.w_bytes):
+                rettype = 'b'
+            else:
                 if decode:
                     msg = ("decoding error handler must return "
                            "(str, int) tuple")
@@ -94,7 +97,7 @@
                 raise oefmt(space.w_IndexError,
                             "position %d from error handler out of bounds",
                             newpos)
-            return space.utf8_w(w_replace), newpos
+            return space.utf8_w(w_replace), newpos, rettype
         return call_errorhandler
 
     def make_decode_errorhandler(self, space):
diff --git a/pypy/module/_codecs/test/test_codecs.py b/pypy/module/_codecs/test/test_codecs.py
--- a/pypy/module/_codecs/test/test_codecs.py
+++ b/pypy/module/_codecs/test/test_codecs.py
@@ -826,11 +826,14 @@
         repl = "\u00E9"
         s = "\u5678".encode("latin-1", "test.bad_handler")
         assert s == b'\xe9'
+        raises(UnicodeEncodeError, "\u5678".encode, "ascii",
+               "test.bad_handler")
 
     def test_lone_surrogates(self):
         encodings = ('utf-8', 'utf-16', 'utf-16-le', 'utf-16-be',
             'utf-32', 'utf-32-le', 'utf-32-be')
         for encoding in encodings:
+            print('encoding', encoding)
             raises(UnicodeEncodeError, u'\ud800'.encode, encoding)
             assert (u'[\udc80]'.encode(encoding, "backslashreplace") ==
                 '[\\udc80]'.encode(encoding))
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -1200,14 +1200,16 @@
     if errors is None or errors == 'strict':
         utf8 = space.utf8_w(w_object)
         if encoding is None or encoding == 'utf-8':
-            #if rutf8.has_surrogates(utf8):
-            #    utf8 = rutf8.reencode_utf8_with_surrogates(utf8)
+            if rutf8.has_surrogates(utf8):
+                # slow path
+                return encode_text(space, w_object, encoding, errors)
             return space.newbytes(utf8)
         elif encoding == 'ascii':
             try:
                 rutf8.check_ascii(utf8)
             except rutf8.CheckError as a:
-                eh = unicodehelper.encode_error_handler(space)
+                state = space.fromcache(CodecState)
+                eh = state.encode_error_handler
                 eh(None, "ascii", "ordinal not in range(128)", utf8,
                     a.pos, a.pos + 1)
                 assert False, "always raises"
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to