https://github.com/python/cpython/commit/ab47892c32e6361f2180e7d86682650f0850c1c4
commit: ab47892c32e6361f2180e7d86682650f0850c1c4
branch: main
author: Serhiy Storchaka <[email protected]>
committer: serhiy-storchaka <[email protected]>
date: 2026-03-19T19:15:30+02:00
summary:

Improve tests for the PyUnicodeWriter C API (GH-146157)

Add tests for corner cases: NULL pointers and out of range values.

files:
M Lib/test/test_capi/test_unicode.py
M Modules/_testcapi/unicode.c

diff --git a/Lib/test/test_capi/test_unicode.py b/Lib/test/test_capi/test_unicode.py
index 55120448a8a3ec..5dee25756fe289 100644
--- a/Lib/test/test_capi/test_unicode.py
+++ b/Lib/test/test_capi/test_unicode.py
@@ -1765,13 +1765,15 @@ def test_basic(self):
         writer.write_utf8(b'var', -1)
 
         # test PyUnicodeWriter_WriteChar()
-        writer.write_char('=')
+        writer.write_char(ord('='))
 
         # test PyUnicodeWriter_WriteSubstring()
         writer.write_substring("[long]", 1, 5)
+        # CRASHES writer.write_substring(NULL, 0, 0)
 
         # test PyUnicodeWriter_WriteStr()
         writer.write_str(" value ")
+        # CRASHES writer.write_str(NULL)
 
         # test PyUnicodeWriter_WriteRepr()
         writer.write_repr("repr")
@@ -1786,14 +1788,28 @@ def test_repr_null(self):
         self.assertEqual(writer.finish(),
                          "var=<NULL>")
 
+    def test_write_char(self):
+        writer = self.create_writer(0)
+        writer.write_char(0)
+        writer.write_char(ord('$'))
+        writer.write_char(0x20ac)
+        writer.write_char(0x10_ffff)
+        self.assertRaises(ValueError, writer.write_char, 0x11_0000)
+        self.assertRaises(ValueError, writer.write_char, 0xFFFF_FFFF)
+        self.assertEqual(writer.finish(),
+                         "\0$\u20AC\U0010FFFF")
+
     def test_utf8(self):
         writer = self.create_writer(0)
         writer.write_utf8(b"ascii", -1)
-        writer.write_char('-')
+        writer.write_char(ord('-'))
         writer.write_utf8(b"latin1=\xC3\xA9", -1)
-        writer.write_char('-')
+        writer.write_char(ord('-'))
         writer.write_utf8(b"euro=\xE2\x82\xAC", -1)
-        writer.write_char('.')
+        writer.write_char(ord('.'))
+        writer.write_utf8(NULL, 0)
+        # CRASHES writer.write_utf8(NULL, 1)
+        # CRASHES writer.write_utf8(NULL, -1)
         self.assertEqual(writer.finish(),
                          "ascii-latin1=\xE9-euro=\u20AC.")
 
@@ -1801,6 +1817,9 @@ def test_ascii(self):
         writer = self.create_writer(0)
         writer.write_ascii(b"Hello ", -1)
         writer.write_ascii(b"", 0)
+        writer.write_ascii(NULL, 0)
+        # CRASHES writer.write_ascii(NULL, 1)
+        # CRASHES writer.write_ascii(NULL, -1)
         writer.write_ascii(b"Python! <truncated>", 6)
         self.assertEqual(writer.finish(), "Hello Python")
 
@@ -1817,6 +1836,9 @@ def test_recover_utf8_error(self):
         # write fails with an invalid string
         with self.assertRaises(UnicodeDecodeError):
             writer.write_utf8(b"invalid\xFF", -1)
+        with self.assertRaises(UnicodeDecodeError):
+            s = "truncated\u20AC".encode()
+            writer.write_utf8(s, len(s) - 1)
 
         # retry write with a valid string
         writer.write_utf8(b"valid", -1)
@@ -1828,13 +1850,19 @@ def test_decode_utf8(self):
         # test PyUnicodeWriter_DecodeUTF8Stateful()
         writer = self.create_writer(0)
         writer.decodeutf8stateful(b"ign\xFFore", -1, b"ignore")
-        writer.write_char('-')
+        writer.write_char(ord('-'))
         writer.decodeutf8stateful(b"replace\xFF", -1, b"replace")
-        writer.write_char('-')
+        writer.write_char(ord('-'))
 
         # incomplete trailing UTF-8 sequence
         writer.decodeutf8stateful(b"incomplete\xC3", -1, b"replace")
 
+        writer.decodeutf8stateful(NULL, 0, b"replace")
+        # CRASHES writer.decodeutf8stateful(NULL, 1, b"replace")
+        # CRASHES writer.decodeutf8stateful(NULL, -1, b"replace")
+        with self.assertRaises(UnicodeDecodeError):
+            writer.decodeutf8stateful(b"default\xFF", -1, NULL)
+
         self.assertEqual(writer.finish(),
                          "ignore-replace\uFFFD-incomplete\uFFFD")
 
@@ -1845,12 +1873,12 @@ def test_decode_utf8_consumed(self):
         # valid string
         consumed = writer.decodeutf8stateful(b"text", -1, b"strict", True)
         self.assertEqual(consumed, 4)
-        writer.write_char('-')
+        writer.write_char(ord('-'))
 
         # non-ASCII
         consumed = writer.decodeutf8stateful(b"\xC3\xA9-\xE2\x82\xAC", 6, b"strict", True)
         self.assertEqual(consumed, 6)
-        writer.write_char('-')
+        writer.write_char(ord('-'))
 
         # invalid UTF-8 (consumed is 0 on error)
         with self.assertRaises(UnicodeDecodeError):
@@ -1859,54 +1887,92 @@ def test_decode_utf8_consumed(self):
         # ignore error handler
         consumed = writer.decodeutf8stateful(b"more\xFF", -1, b"ignore", True)
         self.assertEqual(consumed, 5)
-        writer.write_char('-')
+        writer.write_char(ord('-'))
 
         # incomplete trailing UTF-8 sequence
         consumed = writer.decodeutf8stateful(b"incomplete\xC3", -1, b"ignore", True)
         self.assertEqual(consumed, 10)
+        writer.write_char(ord('-'))
 
-        self.assertEqual(writer.finish(), "text-\xE9-\u20AC-more-incomplete")
+        consumed = writer.decodeutf8stateful(NULL, 0, b"replace", True)
+        self.assertEqual(consumed, 0)
+        # CRASHES writer.decodeutf8stateful(NULL, 1, b"replace", True)
+        # CRASHES writer.decodeutf8stateful(NULL, -1, b"replace", True)
+        consumed = writer.decodeutf8stateful(b"default\xC3", -1, NULL, True)
+        self.assertEqual(consumed, 7)
+
+        self.assertEqual(writer.finish(), "text-\xE9-\u20AC-more-incomplete-default")
 
     def test_widechar(self):
+        from _testcapi import SIZEOF_WCHAR_T
+
+        if SIZEOF_WCHAR_T == 2:
+            encoding = 'utf-16le' if sys.byteorder == 'little' else 'utf-16be'
+        elif SIZEOF_WCHAR_T == 4:
+            encoding = 'utf-32le' if sys.byteorder == 'little' else 'utf-32be'
+
         writer = self.create_writer(0)
-        writer.write_widechar("latin1=\xE9")
-        writer.write_widechar("-")
-        writer.write_widechar("euro=\u20AC")
-        writer.write_char("-")
-        writer.write_widechar("max=\U0010ffff")
-        writer.write_char('.')
+        writer.write_widechar("latin1=\xE9".encode(encoding))
+        writer.write_char(ord("-"))
+        writer.write_widechar("euro=\u20AC".encode(encoding))
+        writer.write_char(ord("-"))
+        writer.write_widechar("max=\U0010ffff".encode(encoding))
+        writer.write_char(ord("-"))
+        writer.write_widechar("zeroes=".encode(encoding).ljust(SIZEOF_WCHAR_T * 10, b'\0'),
+                              10)
+        writer.write_char(ord('.'))
+
+        if SIZEOF_WCHAR_T == 4:
+            invalid = (b'\x00\x00\x11\x00' if sys.byteorder == 'little' else
+                       b'\x00\x11\x00\x00')
+            with self.assertRaises(ValueError):
+                writer.write_widechar("invalid=".encode(encoding) + invalid)
+        writer.write_widechar(b'', -5)
+        writer.write_widechar(NULL, 0)
+        # CRASHES writer.write_widechar(NULL, 1)
+        # CRASHES writer.write_widechar(NULL, -1)
+
         self.assertEqual(writer.finish(),
-                         "latin1=\xE9-euro=\u20AC-max=\U0010ffff.")
+                         "latin1=\xE9-euro=\u20AC-max=\U0010ffff-zeroes=\0\0\0.")
 
     def test_ucs4(self):
+        encoding = 'utf-32le' if sys.byteorder == 'little' else 'utf-32be'
+
         writer = self.create_writer(0)
-        writer.write_ucs4("ascii IGNORED", 5)
-        writer.write_char("-")
-        writer.write_ucs4("latin1=\xe9", 8)
-        writer.write_char("-")
-        writer.write_ucs4("euro=\u20ac", 6)
-        writer.write_char("-")
-        writer.write_ucs4("max=\U0010ffff", 5)
-        writer.write_char(".")
+        writer.write_ucs4("ascii IGNORED".encode(encoding), 5)
+        writer.write_char(ord("-"))
+        writer.write_ucs4("latin1=\xe9".encode(encoding))
+        writer.write_char(ord("-"))
+        writer.write_ucs4("euro=\u20ac".encode(encoding))
+        writer.write_char(ord("-"))
+        writer.write_ucs4("max=\U0010ffff".encode(encoding))
+        writer.write_char(ord("."))
         self.assertEqual(writer.finish(),
                          "ascii-latin1=\xE9-euro=\u20AC-max=\U0010ffff.")
 
         # Test some special characters
         writer = self.create_writer(0)
         # Lone surrogate character
-        writer.write_ucs4("lone\uDC80", 5)
-        writer.write_char("-")
+        writer.write_ucs4("lone\uDC80".encode(encoding, 'surrogatepass'))
+        writer.write_char(ord("-"))
         # Surrogate pair
-        writer.write_ucs4("pair\uDBFF\uDFFF", 5)
-        writer.write_char("-")
-        writer.write_ucs4("null[\0]", 7)
+        writer.write_ucs4("pair\uD83D\uDC0D".encode(encoding, 'surrogatepass'))
+        writer.write_char(ord("-"))
+        writer.write_ucs4("null[\0]".encode(encoding), 7)
+        invalid = (b'\x00\x00\x11\x00' if sys.byteorder == 'little' else
+                   b'\x00\x11\x00\x00')
+        # CRASHES writer.write_ucs4("invalid".encode(encoding) + invalid)
+        writer.write_ucs4(NULL, 0)
+        # CRASHES writer.write_ucs4(NULL, 1)
         self.assertEqual(writer.finish(),
-                         "lone\udc80-pair\udbff-null[\0]")
+                         "lone\udc80-pair\ud83d\udc0d-null[\x00]")
 
         # invalid size
         writer = self.create_writer(0)
         with self.assertRaises(ValueError):
-            writer.write_ucs4("text", -1)
+            writer.write_ucs4("text".encode(encoding), -1)
+        self.assertRaises(ValueError, writer.write_ucs4, b'', -1)
+        self.assertRaises(ValueError, writer.write_ucs4, NULL, -1)
 
     def test_substring_empty(self):
         writer = self.create_writer(0)
@@ -1932,7 +1998,7 @@ def test_format(self):
         from ctypes import c_int
         writer = self.create_writer(0)
         self.writer_format(writer, b'%s %i', b'abc', c_int(123))
-        writer.write_char('.')
+        writer.write_char(ord('.'))
         self.assertEqual(writer.finish(), 'abc 123.')
 
     def test_recover_error(self):
diff --git a/Modules/_testcapi/unicode.c b/Modules/_testcapi/unicode.c
index 668adc5085b4fe..915c9230f66b52 100644
--- a/Modules/_testcapi/unicode.c
+++ b/Modules/_testcapi/unicode.c
@@ -301,16 +301,12 @@ writer_write_char(PyObject *self_raw, PyObject *args)
         return NULL;
     }
 
-    PyObject *str;
-    if (!PyArg_ParseTuple(args, "U", &str)) {
+    unsigned int ch;
+    if (!PyArg_ParseTuple(args, "I", &ch)) {
         return NULL;
     }
-    if (PyUnicode_GET_LENGTH(str) != 1) {
-        PyErr_SetString(PyExc_ValueError, "expect a single character");
-    }
-    Py_UCS4 ch = PyUnicode_READ_CHAR(str, 0);
 
-    if (PyUnicodeWriter_WriteChar(self->writer, ch) < 0) {
+    if (PyUnicodeWriter_WriteChar(self->writer, (Py_UCS4)ch) < 0) {
         return NULL;
     }
     Py_RETURN_NONE;
@@ -325,9 +321,9 @@ writer_write_utf8(PyObject *self_raw, PyObject *args)
         return NULL;
     }
 
-    char *str;
-    Py_ssize_t size;
-    if (!PyArg_ParseTuple(args, "yn", &str, &size)) {
+    const char *str;
+    Py_ssize_t bsize, size;
+    if (!PyArg_ParseTuple(args, "z#n", &str, &bsize, &size)) {
         return NULL;
     }
 
@@ -346,9 +342,9 @@ writer_write_ascii(PyObject *self_raw, PyObject *args)
         return NULL;
     }
 
-    char *str;
-    Py_ssize_t size;
-    if (!PyArg_ParseTuple(args, "yn", &str, &size)) {
+    const char *str;
+    Py_ssize_t bsize, size;
+    if (!PyArg_ParseTuple(args, "z#n", &str, &bsize, &size)) {
         return NULL;
     }
 
@@ -367,19 +363,23 @@ writer_write_widechar(PyObject *self_raw, PyObject *args)
         return NULL;
     }
 
-    PyObject *str;
-    if (!PyArg_ParseTuple(args, "U", &str)) {
-        return NULL;
-    }
+    const char *s;
+    Py_ssize_t bsize;
+    Py_ssize_t size = -100;
 
-    Py_ssize_t size;
-    wchar_t *wstr = PyUnicode_AsWideCharString(str, &size);
-    if (wstr == NULL) {
+    if (!PyArg_ParseTuple(args, "z#|n", &s, &bsize, &size)) {
         return NULL;
     }
+    if (size == -100) {
+        if (bsize % SIZEOF_WCHAR_T) {
+            PyErr_SetString(PyExc_AssertionError,
+                            "invalid size in writer.write_widechar()");
+            return NULL;
+        }
+        size = bsize / SIZEOF_WCHAR_T;
+    }
 
-    int res = PyUnicodeWriter_WriteWideChar(self->writer, wstr, size);
-    PyMem_Free(wstr);
+    int res = PyUnicodeWriter_WriteWideChar(self->writer, (const wchar_t *)s, size);
     if (res < 0) {
         return NULL;
     }
@@ -395,21 +395,23 @@ writer_write_ucs4(PyObject *self_raw, PyObject *args)
         return NULL;
     }
 
-    PyObject *str;
-    Py_ssize_t size;
-    if (!PyArg_ParseTuple(args, "Un", &str, &size)) {
-        return NULL;
-    }
-    Py_ssize_t len = PyUnicode_GET_LENGTH(str);
-    size = Py_MIN(size, len);
+    const char *s;
+    Py_ssize_t bsize;
+    Py_ssize_t size = -100;
 
-    Py_UCS4 *ucs4 = PyUnicode_AsUCS4Copy(str);
-    if (ucs4 == NULL) {
+    if (!PyArg_ParseTuple(args, "z#|n", &s, &bsize, &size)) {
         return NULL;
     }
+    if (size == -100) {
+        if (bsize % sizeof(Py_UCS4)) {
+            PyErr_SetString(PyExc_AssertionError,
+                            "invalid size in writer.write_ucs4()");
+            return NULL;
+        }
+        size = bsize / sizeof(Py_UCS4);
+    }
 
-    int res = PyUnicodeWriter_WriteUCS4(self->writer, ucs4, size);
-    PyMem_Free(ucs4);
+    int res = PyUnicodeWriter_WriteUCS4(self->writer, (const Py_UCS4 *)s, size);
     if (res < 0) {
         return NULL;
     }
@@ -418,18 +420,14 @@ writer_write_ucs4(PyObject *self_raw, PyObject *args)
 
 
 static PyObject*
-writer_write_str(PyObject *self_raw, PyObject *args)
+writer_write_str(PyObject *self_raw, PyObject *obj)
 {
     WriterObject *self = (WriterObject *)self_raw;
     if (writer_check(self) < 0) {
         return NULL;
     }
 
-    PyObject *obj;
-    if (!PyArg_ParseTuple(args, "O", &obj)) {
-        return NULL;
-    }
-
+    NULLABLE(obj);
     if (PyUnicodeWriter_WriteStr(self->writer, obj) < 0) {
         return NULL;
     }
@@ -438,19 +436,14 @@ writer_write_str(PyObject *self_raw, PyObject *args)
 
 
 static PyObject*
-writer_write_repr(PyObject *self_raw, PyObject *args)
+writer_write_repr(PyObject *self_raw, PyObject *obj)
 {
     WriterObject *self = (WriterObject *)self_raw;
     if (writer_check(self) < 0) {
         return NULL;
     }
 
-    PyObject *obj;
-    if (!PyArg_ParseTuple(args, "O", &obj)) {
-        return NULL;
-    }
     NULLABLE(obj);
-
     if (PyUnicodeWriter_WriteRepr(self->writer, obj) < 0) {
         return NULL;
     }
@@ -468,9 +461,10 @@ writer_write_substring(PyObject *self_raw, PyObject *args)
 
     PyObject *str;
     Py_ssize_t start, end;
-    if (!PyArg_ParseTuple(args, "Unn", &str, &start, &end)) {
+    if (!PyArg_ParseTuple(args, "Onn", &str, &start, &end)) {
         return NULL;
     }
+    NULLABLE(str);
 
     if (PyUnicodeWriter_WriteSubstring(self->writer, str, start, end) < 0) {
         return NULL;
@@ -488,10 +482,10 @@ writer_decodeutf8stateful(PyObject *self_raw, PyObject *args)
     }
 
     const char *str;
-    Py_ssize_t len;
+    Py_ssize_t bsize, len;
     const char *errors;
     int use_consumed = 0;
-    if (!PyArg_ParseTuple(args, "yny|i", &str, &len, &errors, &use_consumed)) {
+    if (!PyArg_ParseTuple(args, "z#nz#|p", &str, &bsize, &len, &errors, &bsize, &use_consumed)) {
         return NULL;
     }
 
@@ -544,8 +538,8 @@ static PyMethodDef writer_methods[] = {
     {"write_ascii", _PyCFunction_CAST(writer_write_ascii), METH_VARARGS},
     {"write_widechar", _PyCFunction_CAST(writer_write_widechar), METH_VARARGS},
     {"write_ucs4", _PyCFunction_CAST(writer_write_ucs4), METH_VARARGS},
-    {"write_str", _PyCFunction_CAST(writer_write_str), METH_VARARGS},
-    {"write_repr", _PyCFunction_CAST(writer_write_repr), METH_VARARGS},
+    {"write_str", _PyCFunction_CAST(writer_write_str), METH_O},
+    {"write_repr", _PyCFunction_CAST(writer_write_repr), METH_O},
     {"write_substring", _PyCFunction_CAST(writer_write_substring), METH_VARARGS},
     {"decodeutf8stateful", _PyCFunction_CAST(writer_decodeutf8stateful), METH_VARARGS},
     {"get_pointer", _PyCFunction_CAST(writer_get_pointer), METH_VARARGS},

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to