https://github.com/python/cpython/commit/649857a1574a02235ccfac9e2ac1c12914cf8fe0
commit: 649857a1574a02235ccfac9e2ac1c12914cf8fe0
branch: main
author: John Sloboda <slob...@gmail.com>
committer: methane <songofaca...@gmail.com>
date: 2024-03-17T04:58:42Z
summary:

gh-85287: Change codecs to raise precise UnicodeEncodeError and UnicodeDecodeError (#113674)

Co-authored-by: Inada Naoki <songofaca...@gmail.com>

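The visible effect is that callers now get the encoding name and the offending
position on the exception, instead of a bare UnicodeError. A minimal sketch of
the intended behaviour, based on the new test cases added to
Lib/test/test_codecs.py in this commit (exact reason strings may differ):

    # '&' at index 5 of b"xn--w&" is not valid punycode input.
    try:
        b"xn--w&".decode("punycode")
    except UnicodeDecodeError as exc:
        print(exc.encoding, exc.start, exc.end)   # punycode 5 6
        print(exc.object)                         # b'xn--w&'

Since UnicodeDecodeError and UnicodeEncodeError are subclasses of
UnicodeError, existing "except UnicodeError" handlers keep working.
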
files:
A Misc/NEWS.d/next/Library/2024-01-02-22-47-12.gh-issue-85287.ZC5DLj.rst
M Lib/encodings/idna.py
M Lib/encodings/punycode.py
M Lib/encodings/undefined.py
M Lib/encodings/utf_16.py
M Lib/encodings/utf_32.py
M Lib/test/test_codecs.py
M Lib/test/test_multibytecodec.py
M Modules/cjkcodecs/multibytecodec.c

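Most of the new bookkeeping is in Lib/encodings/idna.py, where per-label errors
are re-raised with offsets relative to the whole input. A hedged sketch of the
encode side, mirroring the added invalid_encode_testcases (the comment values
come from those tests; reason text is illustrative only):

    # U+034F is a prohibited character; it sits at index 3 of the input,
    # so the idna codec reports start=3, end=4 against the full string.
    try:
        "あさ.\u034f".encode("idna")
    except UnicodeEncodeError as exc:
        print(exc.encoding, exc.start, exc.end)   # idna 3 4
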
diff --git a/Lib/encodings/idna.py b/Lib/encodings/idna.py
index d0f70c00f0ab66..60a8d5eb227f82 100644
--- a/Lib/encodings/idna.py
+++ b/Lib/encodings/idna.py
@@ -11,7 +11,7 @@
 sace_prefix = "xn--"
 
 # This assumes query strings, so AllowUnassigned is true
-def nameprep(label):
+def nameprep(label):  # type: (str) -> str
     # Map
     newlabel = []
     for c in label:
@@ -25,7 +25,7 @@ def nameprep(label):
     label = unicodedata.normalize("NFKC", label)
 
     # Prohibit
-    for c in label:
+    for i, c in enumerate(label):
         if stringprep.in_table_c12(c) or \
            stringprep.in_table_c22(c) or \
            stringprep.in_table_c3(c) or \
@@ -35,7 +35,7 @@ def nameprep(label):
            stringprep.in_table_c7(c) or \
            stringprep.in_table_c8(c) or \
            stringprep.in_table_c9(c):
-            raise UnicodeError("Invalid character %r" % c)
+            raise UnicodeEncodeError("idna", label, i, i+1, f"Invalid 
character {c!r}")
 
     # Check bidi
     RandAL = [stringprep.in_table_d1(x) for x in label]
@@ -46,29 +46,38 @@ def nameprep(label):
         # This is table C.8, which was already checked
         # 2) If a string contains any RandALCat character, the string
         # MUST NOT contain any LCat character.
-        if any(stringprep.in_table_d2(x) for x in label):
-            raise UnicodeError("Violation of BIDI requirement 2")
+        for i, x in enumerate(label):
+            if stringprep.in_table_d2(x):
+                raise UnicodeEncodeError("idna", label, i, i+1,
+                                         "Violation of BIDI requirement 2")
         # 3) If a string contains any RandALCat character, a
         # RandALCat character MUST be the first character of the
         # string, and a RandALCat character MUST be the last
         # character of the string.
-        if not RandAL[0] or not RandAL[-1]:
-            raise UnicodeError("Violation of BIDI requirement 3")
+        if not RandAL[0]:
+            raise UnicodeEncodeError("idna", label, 0, 1,
+                                     "Violation of BIDI requirement 3")
+        if not RandAL[-1]:
+            raise UnicodeEncodeError("idna", label, len(label)-1, len(label),
+                                     "Violation of BIDI requirement 3")
 
     return label
 
-def ToASCII(label):
+def ToASCII(label):  # type: (str) -> bytes
     try:
         # Step 1: try ASCII
-        label = label.encode("ascii")
-    except UnicodeError:
+        label_ascii = label.encode("ascii")
+    except UnicodeEncodeError:
         pass
     else:
         # Skip to step 3: UseSTD3ASCIIRules is false, so
         # Skip to step 8.
-        if 0 < len(label) < 64:
-            return label
-        raise UnicodeError("label empty or too long")
+        if 0 < len(label_ascii) < 64:
+            return label_ascii
+        if len(label) == 0:
+            raise UnicodeEncodeError("idna", label, 0, 1, "label empty")
+        else:
+            raise UnicodeEncodeError("idna", label, 0, len(label), "label too 
long")
 
     # Step 2: nameprep
     label = nameprep(label)
@@ -76,29 +85,34 @@ def ToASCII(label):
     # Step 3: UseSTD3ASCIIRules is false
     # Step 4: try ASCII
     try:
-        label = label.encode("ascii")
-    except UnicodeError:
+        label_ascii = label.encode("ascii")
+    except UnicodeEncodeError:
         pass
     else:
         # Skip to step 8.
         if 0 < len(label) < 64:
-            return label
-        raise UnicodeError("label empty or too long")
+            return label_ascii
+        if len(label) == 0:
+            raise UnicodeEncodeError("idna", label, 0, 1, "label empty")
+        else:
+            raise UnicodeEncodeError("idna", label, 0, len(label), "label too 
long")
 
     # Step 5: Check ACE prefix
-    if label[:4].lower() == sace_prefix:
-        raise UnicodeError("Label starts with ACE prefix")
+    if label.lower().startswith(sace_prefix):
+        raise UnicodeEncodeError(
+            "idna", label, 0, len(sace_prefix), "Label starts with ACE prefix")
 
     # Step 6: Encode with PUNYCODE
-    label = label.encode("punycode")
+    label_ascii = label.encode("punycode")
 
     # Step 7: Prepend ACE prefix
-    label = ace_prefix + label
+    label_ascii = ace_prefix + label_ascii
 
     # Step 8: Check size
-    if 0 < len(label) < 64:
-        return label
-    raise UnicodeError("label empty or too long")
+    # do not check for empty as we prepend ace_prefix.
+    if len(label_ascii) < 64:
+        return label_ascii
+    raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
 
 def ToUnicode(label):
     if len(label) > 1024:
@@ -110,7 +124,9 @@ def ToUnicode(label):
         # per https://www.rfc-editor.org/rfc/rfc3454#section-3.1 while still
         # preventing us from wasting time decoding a big thing that'll just
         # hit the actual <= 63 length limit in Step 6.
-        raise UnicodeError("label way too long")
+        if isinstance(label, str):
+            label = label.encode("utf-8", errors="backslashreplace")
+        raise UnicodeDecodeError("idna", label, 0, len(label), "label way too 
long")
     # Step 1: Check for ASCII
     if isinstance(label, bytes):
         pure_ascii = True
@@ -118,25 +134,32 @@ def ToUnicode(label):
         try:
             label = label.encode("ascii")
             pure_ascii = True
-        except UnicodeError:
+        except UnicodeEncodeError:
             pure_ascii = False
     if not pure_ascii:
+        assert isinstance(label, str)
         # Step 2: Perform nameprep
         label = nameprep(label)
         # It doesn't say this, but apparently, it should be ASCII now
         try:
             label = label.encode("ascii")
-        except UnicodeError:
-            raise UnicodeError("Invalid character in IDN label")
+        except UnicodeEncodeError as exc:
+            raise UnicodeEncodeError("idna", label, exc.start, exc.end,
+                                     "Invalid character in IDN label")
     # Step 3: Check for ACE prefix
-    if not label[:4].lower() == ace_prefix:
+    assert isinstance(label, bytes)
+    if not label.lower().startswith(ace_prefix):
         return str(label, "ascii")
 
     # Step 4: Remove ACE prefix
     label1 = label[len(ace_prefix):]
 
     # Step 5: Decode using PUNYCODE
-    result = label1.decode("punycode")
+    try:
+        result = label1.decode("punycode")
+    except UnicodeDecodeError as exc:
+        offset = len(ace_prefix)
+        raise UnicodeDecodeError("idna", label, offset+exc.start, 
offset+exc.end, exc.reason)
 
     # Step 6: Apply ToASCII
     label2 = ToASCII(result)
@@ -144,7 +167,8 @@ def ToUnicode(label):
     # Step 7: Compare the result of step 6 with the one of step 3
     # label2 will already be in lower case.
     if str(label, "ascii").lower() != str(label2, "ascii"):
-        raise UnicodeError("IDNA does not round-trip", label, label2)
+        raise UnicodeDecodeError("idna", label, 0, len(label),
+                                 f"IDNA does not round-trip, '{label!r}' != 
'{label2!r}'")
 
     # Step 8: return the result of step 5
     return result
@@ -156,7 +180,7 @@ def encode(self, input, errors='strict'):
 
         if errors != 'strict':
             # IDNA is quite clear that implementations must be strict
-            raise UnicodeError("unsupported error handling "+errors)
+            raise UnicodeError(f"Unsupported error handling: {errors}")
 
         if not input:
             return b'', 0
@@ -168,11 +192,16 @@ def encode(self, input, errors='strict'):
         else:
             # ASCII name: fast path
             labels = result.split(b'.')
-            for label in labels[:-1]:
-                if not (0 < len(label) < 64):
-                    raise UnicodeError("label empty or too long")
-            if len(labels[-1]) >= 64:
-                raise UnicodeError("label too long")
+            for i, label in enumerate(labels[:-1]):
+                if len(label) == 0:
+                    offset = sum(len(l) for l in labels[:i]) + i
+                    raise UnicodeEncodeError("idna", input, offset, offset+1,
+                                             "label empty")
+            for i, label in enumerate(labels):
+                if len(label) >= 64:
+                    offset = sum(len(l) for l in labels[:i]) + i
+                    raise UnicodeEncodeError("idna", input, offset, 
offset+len(label),
+                                             "label too long")
             return result, len(input)
 
         result = bytearray()
@@ -182,17 +211,27 @@ def encode(self, input, errors='strict'):
             del labels[-1]
         else:
             trailing_dot = b''
-        for label in labels:
+        for i, label in enumerate(labels):
             if result:
                 # Join with U+002E
                 result.extend(b'.')
-            result.extend(ToASCII(label))
+            try:
+                result.extend(ToASCII(label))
+            except (UnicodeEncodeError, UnicodeDecodeError) as exc:
+                offset = sum(len(l) for l in labels[:i]) + i
+                raise UnicodeEncodeError(
+                    "idna",
+                    input,
+                    offset + exc.start,
+                    offset + exc.end,
+                    exc.reason,
+                )
         return bytes(result+trailing_dot), len(input)
 
     def decode(self, input, errors='strict'):
 
         if errors != 'strict':
-            raise UnicodeError("Unsupported error handling "+errors)
+            raise UnicodeError(f"Unsupported error handling: {errors}")
 
         if not input:
             return "", 0
@@ -218,8 +257,15 @@ def decode(self, input, errors='strict'):
             trailing_dot = ''
 
         result = []
-        for label in labels:
-            result.append(ToUnicode(label))
+        for i, label in enumerate(labels):
+            try:
+                u_label = ToUnicode(label)
+            except (UnicodeEncodeError, UnicodeDecodeError) as exc:
+                offset = sum(len(x) for x in labels[:i]) + len(labels[:i])
+                raise UnicodeDecodeError(
+                    "idna", input, offset+exc.start, offset+exc.end, 
exc.reason)
+            else:
+                result.append(u_label)
 
         return ".".join(result)+trailing_dot, len(input)
 
@@ -227,7 +273,7 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
     def _buffer_encode(self, input, errors, final):
         if errors != 'strict':
             # IDNA is quite clear that implementations must be strict
-            raise UnicodeError("unsupported error handling "+errors)
+            raise UnicodeError(f"Unsupported error handling: {errors}")
 
         if not input:
             return (b'', 0)
@@ -251,7 +297,16 @@ def _buffer_encode(self, input, errors, final):
                 # Join with U+002E
                 result.extend(b'.')
                 size += 1
-            result.extend(ToASCII(label))
+            try:
+                result.extend(ToASCII(label))
+            except (UnicodeEncodeError, UnicodeDecodeError) as exc:
+                raise UnicodeEncodeError(
+                    "idna",
+                    input,
+                    size + exc.start,
+                    size + exc.end,
+                    exc.reason,
+                )
             size += len(label)
 
         result += trailing_dot
@@ -261,7 +316,7 @@ def _buffer_encode(self, input, errors, final):
 class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
     def _buffer_decode(self, input, errors, final):
         if errors != 'strict':
-            raise UnicodeError("Unsupported error handling "+errors)
+            raise UnicodeError("Unsupported error handling: {errors}")
 
         if not input:
             return ("", 0)
@@ -271,7 +326,11 @@ def _buffer_decode(self, input, errors, final):
             labels = dots.split(input)
         else:
             # Must be ASCII string
-            input = str(input, "ascii")
+            try:
+                input = str(input, "ascii")
+            except (UnicodeEncodeError, UnicodeDecodeError) as exc:
+                raise UnicodeDecodeError("idna", input,
+                                         exc.start, exc.end, exc.reason)
             labels = input.split(".")
 
         trailing_dot = ''
@@ -288,7 +347,18 @@ def _buffer_decode(self, input, errors, final):
         result = []
         size = 0
         for label in labels:
-            result.append(ToUnicode(label))
+            try:
+                u_label = ToUnicode(label)
+            except (UnicodeEncodeError, UnicodeDecodeError) as exc:
+                raise UnicodeDecodeError(
+                    "idna",
+                    input.encode("ascii", errors="backslashreplace"),
+                    size + exc.start,
+                    size + exc.end,
+                    exc.reason,
+                )
+            else:
+                result.append(u_label)
             if size:
                 size += 1
             size += len(label)
diff --git a/Lib/encodings/punycode.py b/Lib/encodings/punycode.py
index 1c5726447077b1..4622fc8c9206f3 100644
--- a/Lib/encodings/punycode.py
+++ b/Lib/encodings/punycode.py
@@ -1,4 +1,4 @@
-""" Codec for the Punicode encoding, as specified in RFC 3492
+""" Codec for the Punycode encoding, as specified in RFC 3492
 
 Written by Martin v. Löwis.
 """
@@ -131,10 +131,11 @@ def decode_generalized_number(extended, extpos, bias, errors):
     j = 0
     while 1:
         try:
-            char = ord(extended[extpos])
+            char = extended[extpos]
         except IndexError:
             if errors == "strict":
-                raise UnicodeError("incomplete punicode string")
+                raise UnicodeDecodeError("punycode", extended, extpos, 
extpos+1,
+                                         "incomplete punycode string")
             return extpos + 1, None
         extpos += 1
         if 0x41 <= char <= 0x5A: # A-Z
@@ -142,8 +143,8 @@ def decode_generalized_number(extended, extpos, bias, errors):
         elif 0x30 <= char <= 0x39:
             digit = char - 22 # 0x30-26
         elif errors == "strict":
-            raise UnicodeError("Invalid extended code point '%s'"
-                               % extended[extpos-1])
+            raise UnicodeDecodeError("punycode", extended, extpos-1, extpos,
+                                     f"Invalid extended code point 
'{extended[extpos-1]}'")
         else:
             return extpos, None
         t = T(j, bias)
@@ -155,11 +156,14 @@ def decode_generalized_number(extended, extpos, bias, errors):
 
 
 def insertion_sort(base, extended, errors):
-    """3.2 Insertion unsort coding"""
+    """3.2 Insertion sort coding"""
+    # This function raises UnicodeDecodeError with position in the extended.
+    # Caller should add the offset.
     char = 0x80
     pos = -1
     bias = 72
     extpos = 0
+
     while extpos < len(extended):
         newpos, delta = decode_generalized_number(extended, extpos,
                                                   bias, errors)
@@ -171,7 +175,9 @@ def insertion_sort(base, extended, errors):
         char += pos // (len(base) + 1)
         if char > 0x10FFFF:
             if errors == "strict":
-                raise UnicodeError("Invalid character U+%x" % char)
+                raise UnicodeDecodeError(
+                    "punycode", extended, pos-1, pos,
+                    f"Invalid character U+{char:x}")
             char = ord('?')
         pos = pos % (len(base) + 1)
         base = base[:pos] + chr(char) + base[pos:]
@@ -187,11 +193,21 @@ def punycode_decode(text, errors):
     pos = text.rfind(b"-")
     if pos == -1:
         base = ""
-        extended = str(text, "ascii").upper()
+        extended = text.upper()
     else:
-        base = str(text[:pos], "ascii", errors)
-        extended = str(text[pos+1:], "ascii").upper()
-    return insertion_sort(base, extended, errors)
+        try:
+            base = str(text[:pos], "ascii", errors)
+        except UnicodeDecodeError as exc:
+            raise UnicodeDecodeError("ascii", text, exc.start, exc.end,
+                                     exc.reason) from None
+        extended = text[pos+1:].upper()
+    try:
+        return insertion_sort(base, extended, errors)
+    except UnicodeDecodeError as exc:
+        offset = pos + 1
+        raise UnicodeDecodeError("punycode", text,
+                                 offset+exc.start, offset+exc.end,
+                                 exc.reason) from None
 
 ### Codec APIs
 
@@ -203,7 +219,7 @@ def encode(self, input, errors='strict'):
 
     def decode(self, input, errors='strict'):
         if errors not in ('strict', 'replace', 'ignore'):
-            raise UnicodeError("Unsupported error handling "+errors)
+            raise UnicodeError(f"Unsupported error handling: {errors}")
         res = punycode_decode(input, errors)
         return res, len(input)
 
@@ -214,7 +230,7 @@ def encode(self, input, final=False):
 class IncrementalDecoder(codecs.IncrementalDecoder):
     def decode(self, input, final=False):
         if self.errors not in ('strict', 'replace', 'ignore'):
-            raise UnicodeError("Unsupported error handling "+self.errors)
+            raise UnicodeError(f"Unsupported error handling: {self.errors}")
         return punycode_decode(input, self.errors)
 
 class StreamWriter(Codec,codecs.StreamWriter):
diff --git a/Lib/encodings/undefined.py b/Lib/encodings/undefined.py
index 4690288355c710..082771e1c86677 100644
--- a/Lib/encodings/undefined.py
+++ b/Lib/encodings/undefined.py
@@ -1,6 +1,6 @@
 """ Python 'undefined' Codec
 
-    This codec will always raise a ValueError exception when being
+    This codec will always raise a UnicodeError exception when being
     used. It is intended for use by the site.py file to switch off
     automatic string to Unicode coercion.
 
diff --git a/Lib/encodings/utf_16.py b/Lib/encodings/utf_16.py
index c61248242be8c7..d3b9980026666f 100644
--- a/Lib/encodings/utf_16.py
+++ b/Lib/encodings/utf_16.py
@@ -64,7 +64,7 @@ def _buffer_decode(self, input, errors, final):
             elif byteorder == 1:
                 self.decoder = codecs.utf_16_be_decode
             elif consumed >= 2:
-                raise UnicodeError("UTF-16 stream does not start with BOM")
+                raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does 
not start with BOM")
             return (output, consumed)
         return self.decoder(input, self.errors, final)
 
@@ -138,7 +138,7 @@ def decode(self, input, errors='strict'):
         elif byteorder == 1:
             self.decode = codecs.utf_16_be_decode
         elif consumed>=2:
-            raise UnicodeError("UTF-16 stream does not start with BOM")
+            raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not 
start with BOM")
         return (object, consumed)
 
 ### encodings module API
diff --git a/Lib/encodings/utf_32.py b/Lib/encodings/utf_32.py
index cdf84d14129a62..1924bedbb74c68 100644
--- a/Lib/encodings/utf_32.py
+++ b/Lib/encodings/utf_32.py
@@ -59,7 +59,7 @@ def _buffer_decode(self, input, errors, final):
             elif byteorder == 1:
                 self.decoder = codecs.utf_32_be_decode
             elif consumed >= 4:
-                raise UnicodeError("UTF-32 stream does not start with BOM")
+                raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does 
not start with BOM")
             return (output, consumed)
         return self.decoder(input, self.errors, final)
 
@@ -132,8 +132,8 @@ def decode(self, input, errors='strict'):
             self.decode = codecs.utf_32_le_decode
         elif byteorder == 1:
             self.decode = codecs.utf_32_be_decode
-        elif consumed>=4:
-            raise UnicodeError("UTF-32 stream does not start with BOM")
+        elif consumed >= 4:
+            raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not 
start with BOM")
         return (object, consumed)
 
 ### encodings module API
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py
index 9585f947877142..fe3776d6dd9337 100644
--- a/Lib/test/test_codecs.py
+++ b/Lib/test/test_codecs.py
@@ -482,11 +482,11 @@ def test_only_one_bom(self):
     def test_badbom(self):
         s = io.BytesIO(4*b"\xff")
         f = codecs.getreader(self.encoding)(s)
-        self.assertRaises(UnicodeError, f.read)
+        self.assertRaises(UnicodeDecodeError, f.read)
 
         s = io.BytesIO(8*b"\xff")
         f = codecs.getreader(self.encoding)(s)
-        self.assertRaises(UnicodeError, f.read)
+        self.assertRaises(UnicodeDecodeError, f.read)
 
     def test_partial(self):
         self.check_partial(
@@ -666,11 +666,11 @@ def test_only_one_bom(self):
     def test_badbom(self):
         s = io.BytesIO(b"\xff\xff")
         f = codecs.getreader(self.encoding)(s)
-        self.assertRaises(UnicodeError, f.read)
+        self.assertRaises(UnicodeDecodeError, f.read)
 
         s = io.BytesIO(b"\xff\xff\xff\xff")
         f = codecs.getreader(self.encoding)(s)
-        self.assertRaises(UnicodeError, f.read)
+        self.assertRaises(UnicodeDecodeError, f.read)
 
     def test_partial(self):
         self.check_partial(
@@ -1356,13 +1356,29 @@ def test_decode(self):
 
     def test_decode_invalid(self):
         testcases = [
-            (b"xn--w&", "strict", UnicodeError()),
+            (b"xn--w&", "strict", UnicodeDecodeError("punycode", b"", 5, 6, 
"")),
+            (b"&egbpdaj6bu4bxfgehfvwxn", "strict", 
UnicodeDecodeError("punycode", b"", 0, 1, "")),
+            (b"egbpdaj6bu&4bx&fgehfvwxn", "strict", 
UnicodeDecodeError("punycode", b"", 10, 11, "")),
+            (b"egbpdaj6bu4bxfgehfvwxn&", "strict", 
UnicodeDecodeError("punycode", b"", 22, 23, "")),
+            (b"\xFFProprostnemluvesky-uyb24dma41a", "strict", 
UnicodeDecodeError("ascii", b"", 0, 1, "")),
+            (b"Pro\xFFprostnemluvesky-uyb24dma41a", "strict", 
UnicodeDecodeError("ascii", b"", 3, 4, "")),
+            (b"Proprost&nemluvesky-uyb24&dma41a", "strict", 
UnicodeDecodeError("punycode", b"", 25, 26, "")),
+            (b"Proprostnemluvesky&-&uyb24dma41a", "strict", 
UnicodeDecodeError("punycode", b"", 20, 21, "")),
+            (b"Proprostnemluvesky-&uyb24dma41a", "strict", 
UnicodeDecodeError("punycode", b"", 19, 20, "")),
+            (b"Proprostnemluvesky-uyb24d&ma41a", "strict", 
UnicodeDecodeError("punycode", b"", 25, 26, "")),
+            (b"Proprostnemluvesky-uyb24dma41a&", "strict", 
UnicodeDecodeError("punycode", b"", 30, 31, "")),
             (b"xn--w&", "ignore", "xn-"),
         ]
         for puny, errors, expected in testcases:
             with self.subTest(puny=puny, errors=errors):
                 if isinstance(expected, Exception):
-                    self.assertRaises(UnicodeError, puny.decode, "punycode", 
errors)
+                    with self.assertRaises(UnicodeDecodeError) as cm:
+                        puny.decode("punycode", errors)
+                    exc = cm.exception
+                    self.assertEqual(exc.encoding, expected.encoding)
+                    self.assertEqual(exc.object, puny)
+                    self.assertEqual(exc.start, expected.start)
+                    self.assertEqual(exc.end, expected.end)
                 else:
                     self.assertEqual(puny.decode("punycode", errors), expected)
 
@@ -1532,7 +1548,7 @@ def test_nameprep(self):
             orig = str(orig, "utf-8", "surrogatepass")
             if prepped is None:
                 # Input contains prohibited characters
-                self.assertRaises(UnicodeError, nameprep, orig)
+                self.assertRaises(UnicodeEncodeError, nameprep, orig)
             else:
                 prepped = str(prepped, "utf-8", "surrogatepass")
                 try:
@@ -1542,6 +1558,23 @@ def test_nameprep(self):
 
 
 class IDNACodecTest(unittest.TestCase):
+
+    invalid_decode_testcases = [
+        (b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFFpython.org", 0, 
1, "")),
+        (b"pyt\xFFhon.org", UnicodeDecodeError("idna", b"pyt\xFFhon.org", 3, 
4, "")),
+        (b"python\xFF.org", UnicodeDecodeError("idna", b"python\xFF.org", 6, 
7, "")),
+        (b"python.\xFForg", UnicodeDecodeError("idna", b"python.\xFForg", 7, 
8, "")),
+        (b"python.o\xFFrg", UnicodeDecodeError("idna", b"python.o\xFFrg", 8, 
9, "")),
+        (b"python.org\xFF", UnicodeDecodeError("idna", b"python.org\xFF", 10, 
11, "")),
+        (b"xn--pythn-&mua.org", UnicodeDecodeError("idna", 
b"xn--pythn-&mua.org", 10, 11, "")),
+        (b"xn--pythn-m&ua.org", UnicodeDecodeError("idna", 
b"xn--pythn-m&ua.org", 11, 12, "")),
+        (b"xn--pythn-mua&.org", UnicodeDecodeError("idna", 
b"xn--pythn-mua&.org", 13, 14, "")),
+    ]
+    invalid_encode_testcases = [
+        (f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"foo.{'\xff'*60}", 4, 
64, "")),
+        ("あさ.\u034f", UnicodeEncodeError("idna", "あさ.\u034f", 3, 4, "")),
+    ]
+
     def test_builtin_decode(self):
         self.assertEqual(str(b"python.org", "idna"), "python.org")
         self.assertEqual(str(b"python.org.", "idna"), "python.org.")
@@ -1555,16 +1588,38 @@ def test_builtin_decode(self):
         self.assertEqual(str(b"bugs.XN--pythn-mua.org.", "idna"),
                          "bugs.pyth\xf6n.org.")
 
+    def test_builtin_decode_invalid(self):
+        for case, expected in self.invalid_decode_testcases:
+            with self.subTest(case=case, expected=expected):
+                with self.assertRaises(UnicodeDecodeError) as cm:
+                    case.decode("idna")
+                exc = cm.exception
+                self.assertEqual(exc.encoding, expected.encoding)
+                self.assertEqual(exc.object, expected.object)
+                self.assertEqual(exc.start, expected.start, msg=f'reason: {exc.reason}')
+                self.assertEqual(exc.end, expected.end)
+
     def test_builtin_encode(self):
         self.assertEqual("python.org".encode("idna"), b"python.org")
         self.assertEqual("python.org.".encode("idna"), b"python.org.")
         self.assertEqual("pyth\xf6n.org".encode("idna"), b"xn--pythn-mua.org")
         self.assertEqual("pyth\xf6n.org.".encode("idna"), 
b"xn--pythn-mua.org.")
 
+    def test_builtin_encode_invalid(self):
+        for case, expected in self.invalid_encode_testcases:
+            with self.subTest(case=case, expected=expected):
+                with self.assertRaises(UnicodeEncodeError) as cm:
+                    case.encode("idna")
+                exc = cm.exception
+                self.assertEqual(exc.encoding, expected.encoding)
+                self.assertEqual(exc.object, expected.object)
+                self.assertEqual(exc.start, expected.start)
+                self.assertEqual(exc.end, expected.end)
+
     def test_builtin_decode_length_limit(self):
-        with self.assertRaisesRegex(UnicodeError, "way too long"):
+        with self.assertRaisesRegex(UnicodeDecodeError, "way too long"):
             (b"xn--016c"+b"a"*1100).decode("idna")
-        with self.assertRaisesRegex(UnicodeError, "too long"):
+        with self.assertRaisesRegex(UnicodeDecodeError, "too long"):
             (b"xn--016c"+b"a"*70).decode("idna")
 
     def test_stream(self):
@@ -1602,6 +1657,39 @@ def test_incremental_decode(self):
         self.assertEqual(decoder.decode(b"rg."), "org.")
         self.assertEqual(decoder.decode(b"", True), "")
 
+    def test_incremental_decode_invalid(self):
+        iterdecode_testcases = [
+            (b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")),
+            (b"pyt\xFFhon.org", UnicodeDecodeError("idna", b"pyt\xFF", 3, 4, 
"")),
+            (b"python\xFF.org", UnicodeDecodeError("idna", b"python\xFF", 6, 
7, "")),
+            (b"python.\xFForg", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")),
+            (b"python.o\xFFrg", UnicodeDecodeError("idna", b"o\xFF", 1, 2, 
"")),
+            (b"python.org\xFF", UnicodeDecodeError("idna", b"org\xFF", 3, 4, 
"")),
+            (b"xn--pythn-&mua.org", UnicodeDecodeError("idna", 
b"xn--pythn-&mua.", 10, 11, "")),
+            (b"xn--pythn-m&ua.org", UnicodeDecodeError("idna", 
b"xn--pythn-m&ua.", 11, 12, "")),
+            (b"xn--pythn-mua&.org", UnicodeDecodeError("idna", 
b"xn--pythn-mua&.", 13, 14, "")),
+        ]
+        for case, expected in iterdecode_testcases:
+            with self.subTest(case=case, expected=expected):
+                with self.assertRaises(UnicodeDecodeError) as cm:
+                    list(codecs.iterdecode((bytes([c]) for c in case), "idna"))
+                exc = cm.exception
+                self.assertEqual(exc.encoding, expected.encoding)
+                self.assertEqual(exc.object, expected.object)
+                self.assertEqual(exc.start, expected.start)
+                self.assertEqual(exc.end, expected.end)
+
+        decoder = codecs.getincrementaldecoder("idna")()
+        for case, expected in self.invalid_decode_testcases:
+            with self.subTest(case=case, expected=expected):
+                with self.assertRaises(UnicodeDecodeError) as cm:
+                    decoder.decode(case)
+                exc = cm.exception
+                self.assertEqual(exc.encoding, expected.encoding)
+                self.assertEqual(exc.object, expected.object)
+                self.assertEqual(exc.start, expected.start)
+                self.assertEqual(exc.end, expected.end)
+
     def test_incremental_encode(self):
         self.assertEqual(
             b"".join(codecs.iterencode("python.org", "idna")),
@@ -1630,6 +1718,23 @@ def test_incremental_encode(self):
         self.assertEqual(encoder.encode("ample.org."), b"xn--xample-9ta.org.")
         self.assertEqual(encoder.encode("", True), b"")
 
+    def test_incremental_encode_invalid(self):
+        iterencode_testcases = [
+            (f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"{'\xff'*60}", 0, 
60, "")),
+            ("あさ.\u034f", UnicodeEncodeError("idna", "\u034f", 0, 1, "")),
+        ]
+        for case, expected in iterencode_testcases:
+            with self.subTest(case=case, expected=expected):
+                with self.assertRaises(UnicodeEncodeError) as cm:
+                    list(codecs.iterencode(case, "idna"))
+                exc = cm.exception
+                self.assertEqual(exc.encoding, expected.encoding)
+                self.assertEqual(exc.object, expected.object)
+                self.assertEqual(exc.start, expected.start)
+                self.assertEqual(exc.end, expected.end)
+
+        # codecs.getincrementalencoder.encode() does not throw an error
+
     def test_errors(self):
         """Only supports "strict" error handler"""
         "python.org".encode("idna", "strict")
diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py
index 6451df14696933..ccdf3a6cdc0dc7 100644
--- a/Lib/test/test_multibytecodec.py
+++ b/Lib/test/test_multibytecodec.py
@@ -303,7 +303,7 @@ def test_setstate_validates_input(self):
         self.assertRaises(TypeError, decoder.setstate, 123)
         self.assertRaises(TypeError, decoder.setstate, ("invalid", 0))
         self.assertRaises(TypeError, decoder.setstate, (b"1234", "invalid"))
-        self.assertRaises(UnicodeError, decoder.setstate, (b"123456789", 0))
+        self.assertRaises(UnicodeDecodeError, decoder.setstate, (b"123456789", 
0))
 
 class Test_StreamReader(unittest.TestCase):
     def test_bug1728403(self):
diff --git a/Misc/NEWS.d/next/Library/2024-01-02-22-47-12.gh-issue-85287.ZC5DLj.rst b/Misc/NEWS.d/next/Library/2024-01-02-22-47-12.gh-issue-85287.ZC5DLj.rst
new file mode 100644
index 00000000000000..e6d031fbc93e83
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-01-02-22-47-12.gh-issue-85287.ZC5DLj.rst
@@ -0,0 +1,2 @@
+Changes Unicode codecs to return UnicodeEncodeError or UnicodeDecodeError,
+rather than just UnicodeError.
diff --git a/Modules/cjkcodecs/multibytecodec.c b/Modules/cjkcodecs/multibytecodec.c
index 2125da437963d2..e5433d7dd85306 100644
--- a/Modules/cjkcodecs/multibytecodec.c
+++ b/Modules/cjkcodecs/multibytecodec.c
@@ -825,8 +825,15 @@ encoder_encode_stateful(MultibyteStatefulEncoderContext *ctx,
     if (inpos < datalen) {
         if (datalen - inpos > MAXENCPENDING) {
             /* normal codecs can't reach here */
-            PyErr_SetString(PyExc_UnicodeError,
-                            "pending buffer overflow");
+            PyObject *excobj = PyObject_CallFunction(PyExc_UnicodeEncodeError,
+                                                     "sOnns",
+                                                     ctx->codec->encoding,
+                                                     inbuf,
+                                                     inpos, datalen,
+                                                     "pending buffer 
overflow");
+            if (excobj == NULL) goto errorexit;
+            PyErr_SetObject(PyExc_UnicodeEncodeError, excobj);
+            Py_DECREF(excobj);
             goto errorexit;
         }
         ctx->pending = PyUnicode_Substring(inbuf, inpos, datalen);
@@ -857,7 +864,16 @@ decoder_append_pending(MultibyteStatefulDecoderContext *ctx,
     npendings = (Py_ssize_t)(buf->inbuf_end - buf->inbuf);
     if (npendings + ctx->pendingsize > MAXDECPENDING ||
         npendings > PY_SSIZE_T_MAX - ctx->pendingsize) {
-            PyErr_SetString(PyExc_UnicodeError, "pending buffer overflow");
+            Py_ssize_t bufsize = (Py_ssize_t)(buf->inbuf_end - buf->inbuf_top);
+            PyObject *excobj = PyUnicodeDecodeError_Create(ctx->codec->encoding,
+                                                           (const char *)buf->inbuf_top,
+                                                           bufsize,
+                                                           0,
+                                                           bufsize,
+                                                           "pending buffer overflow");
+            if (excobj == NULL) return -1;
+            PyErr_SetObject(PyExc_UnicodeDecodeError, excobj);
+            Py_DECREF(excobj);
             return -1;
     }
     memcpy(ctx->pending + ctx->pendingsize, buf->inbuf, npendings);
@@ -938,7 +954,17 @@ _multibytecodec_MultibyteIncrementalEncoder_getstate_impl(MultibyteIncrementalEn
             return NULL;
         }
         if (pendingsize > MAXENCPENDING*4) {
-            PyErr_SetString(PyExc_UnicodeError, "pending buffer too large");
+            PyObject *excobj = PyObject_CallFunction(PyExc_UnicodeEncodeError,
+                                                     "sOnns",
+                                                     self->codec->encoding,
+                                                     self->pending,
+                                                     0, PyUnicode_GET_LENGTH(self->pending),
+                                                     "pending buffer too large");
+            if (excobj == NULL) {
+                return NULL;
+            }
+            PyErr_SetObject(PyExc_UnicodeEncodeError, excobj);
+            Py_DECREF(excobj);
             return NULL;
         }
         statebytes[0] = (unsigned char)pendingsize;
@@ -1267,7 +1293,13 @@ _multibytecodec_MultibyteIncrementalDecoder_setstate_impl(MultibyteIncrementalDe
     }
 
     if (buffersize > MAXDECPENDING) {
-        PyErr_SetString(PyExc_UnicodeError, "pending buffer too large");
+        PyObject *excobj = PyUnicodeDecodeError_Create(self->codec->encoding,
+                                                       PyBytes_AS_STRING(buffer), buffersize,
+                                                       0, buffersize,
+                                                       "pending buffer too large");
+        if (excobj == NULL) return NULL;
+        PyErr_SetObject(PyExc_UnicodeDecodeError, excobj);
+        Py_DECREF(excobj);
         return NULL;
     }
 
