https://github.com/python/cpython/commit/ad3eac1963a5f195ef9b2c1dbb5e44fa3cce4c72
commit: ad3eac1963a5f195ef9b2c1dbb5e44fa3cce4c72
branch: main
author: Serhiy Storchaka <storch...@gmail.com>
committer: serhiy-storchaka <storch...@gmail.com>
date: 2024-10-17T15:46:59Z
summary:

gh-52551: Fix encoding issues in strftime() (GH-125193)

Fix time.strftime(), the strftime() method and formatting of the
datetime classes datetime, date and time.

* Characters not encodable in the current locale are now acceptable in
  the format string.
* Surrogate pairs and sequences of surrogateescape-encoded bytes are no
  longer recombined.
* An embedded null character no longer terminates the format string.

This also fixes gh-78662 and gh-124531.

files:
A Misc/NEWS.d/next/Library/2024-10-09-17-07-33.gh-issue-52551.PBakSY.rst
M Lib/test/datetimetester.py
M Lib/test/test_time.py
M Modules/_datetimemodule.c
M Modules/timemodule.c

diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py
index c81408b344968d..dbe25ef57dea83 100644
--- a/Lib/test/datetimetester.py
+++ b/Lib/test/datetimetester.py
@@ -2949,11 +2949,32 @@ def test_more_strftime(self):
             self.assertEqual(t.strftime("%z"), "-0200" + z)
             self.assertEqual(t.strftime("%:z"), "-02:00:" + z)
 
-        # bpo-34482: Check that surrogates don't cause a crash.
-        try:
-            t.strftime('%y\ud800%m %H\ud800%M')
-        except UnicodeEncodeError:
-            pass
+    def test_strftime_special(self):
+        t = self.theclass(2004, 12, 31, 6, 22, 33, 47)
+        s1 = t.strftime('%c')
+        s2 = t.strftime('%B')
+        # gh-52551, gh-78662: Unicode strings should pass through strftime,
+        # independently from locale.
+        self.assertEqual(t.strftime('\U0001f40d'), '\U0001f40d')
+        self.assertEqual(t.strftime('\U0001f4bb%c\U0001f40d%B'), 
f'\U0001f4bb{s1}\U0001f40d{s2}')
+        self.assertEqual(t.strftime('%c\U0001f4bb%B\U0001f40d'), 
f'{s1}\U0001f4bb{s2}\U0001f40d')
+        # Lone surrogates should pass through.
+        self.assertEqual(t.strftime('\ud83d'), '\ud83d')
+        self.assertEqual(t.strftime('\udc0d'), '\udc0d')
+        self.assertEqual(t.strftime('\ud83d%c\udc0d%B'), 
f'\ud83d{s1}\udc0d{s2}')
+        self.assertEqual(t.strftime('%c\ud83d%B\udc0d'), 
f'{s1}\ud83d{s2}\udc0d')
+        self.assertEqual(t.strftime('%c\udc0d%B\ud83d'), 
f'{s1}\udc0d{s2}\ud83d')
+        # Surrogate pairs should not recombine.
+        self.assertEqual(t.strftime('\ud83d\udc0d'), '\ud83d\udc0d')
+        self.assertEqual(t.strftime('%c\ud83d\udc0d%B'), 
f'{s1}\ud83d\udc0d{s2}')
+        # Surrogate-escaped bytes should not recombine.
+        self.assertEqual(t.strftime('\udcf0\udc9f\udc90\udc8d'), 
'\udcf0\udc9f\udc90\udc8d')
+        self.assertEqual(t.strftime('%c\udcf0\udc9f\udc90\udc8d%B'), 
f'{s1}\udcf0\udc9f\udc90\udc8d{s2}')
+        # gh-124531: The null character should not terminate the format string.
+        self.assertEqual(t.strftime('\0'), '\0')
+        self.assertEqual(t.strftime('\0'*1000), '\0'*1000)
+        self.assertEqual(t.strftime('\0%c\0%B'), f'\0{s1}\0{s2}')
+        self.assertEqual(t.strftime('%c\0%B\0'), f'{s1}\0{s2}\0')
 
     def test_extract(self):
         dt = self.theclass(2002, 3, 4, 18, 45, 3, 1234)
@@ -3736,6 +3757,33 @@ def test_strftime(self):
         # gh-85432: The parameter was named "fmt" in the pure-Python impl.
         t.strftime(format="%f")
 
+    def test_strftime_special(self):
+        t = self.theclass(1, 2, 3, 4)
+        s1 = t.strftime('%I%p%Z')
+        s2 = t.strftime('%X')
+        # gh-52551, gh-78662: Unicode strings should pass through strftime,
+        # independently from locale.
+        self.assertEqual(t.strftime('\U0001f40d'), '\U0001f40d')
+        self.assertEqual(t.strftime('\U0001f4bb%I%p%Z\U0001f40d%X'), 
f'\U0001f4bb{s1}\U0001f40d{s2}')
+        self.assertEqual(t.strftime('%I%p%Z\U0001f4bb%X\U0001f40d'), 
f'{s1}\U0001f4bb{s2}\U0001f40d')
+        # Lone surrogates should pass through.
+        self.assertEqual(t.strftime('\ud83d'), '\ud83d')
+        self.assertEqual(t.strftime('\udc0d'), '\udc0d')
+        self.assertEqual(t.strftime('\ud83d%I%p%Z\udc0d%X'), 
f'\ud83d{s1}\udc0d{s2}')
+        self.assertEqual(t.strftime('%I%p%Z\ud83d%X\udc0d'), 
f'{s1}\ud83d{s2}\udc0d')
+        self.assertEqual(t.strftime('%I%p%Z\udc0d%X\ud83d'), 
f'{s1}\udc0d{s2}\ud83d')
+        # Surrogate pairs should not recombine.
+        self.assertEqual(t.strftime('\ud83d\udc0d'), '\ud83d\udc0d')
+        self.assertEqual(t.strftime('%I%p%Z\ud83d\udc0d%X'), 
f'{s1}\ud83d\udc0d{s2}')
+        # Surrogate-escaped bytes should not recombine.
+        self.assertEqual(t.strftime('\udcf0\udc9f\udc90\udc8d'), 
'\udcf0\udc9f\udc90\udc8d')
+        self.assertEqual(t.strftime('%I%p%Z\udcf0\udc9f\udc90\udc8d%X'), 
f'{s1}\udcf0\udc9f\udc90\udc8d{s2}')
+        # gh-124531: The null character should not terminate the format string.
+        self.assertEqual(t.strftime('\0'), '\0')
+        self.assertEqual(t.strftime('\0'*1000), '\0'*1000)
+        self.assertEqual(t.strftime('\0%I%p%Z\0%X'), f'\0{s1}\0{s2}')
+        self.assertEqual(t.strftime('%I%p%Z\0%X\0'), f'{s1}\0{s2}\0')
+
     def test_format(self):
         t = self.theclass(1, 2, 3, 4)
         self.assertEqual(t.__format__(''), str(t))
@@ -4259,9 +4307,8 @@ def tzname(self, dt): return self.tz
         self.assertRaises(TypeError, t.strftime, "%Z")
 
         # Issue #6697:
-        if '_Fast' in self.__class__.__name__:
-            Badtzname.tz = '\ud800'
-            self.assertRaises(ValueError, t.strftime, "%Z")
+        Badtzname.tz = '\ud800'
+        self.assertEqual(t.strftime("%Z"), '\ud800')
 
     def test_hash_edge_cases(self):
         # Offsets that overflow a basic time.
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
index 27c0f51acc58ab..f8b99a9b6a63f5 100644
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -181,8 +181,33 @@ def test_strftime(self):
                 self.fail('conversion specifier: %r failed.' % format)
 
         self.assertRaises(TypeError, time.strftime, b'%S', tt)
-        # embedded null character
-        self.assertRaises(ValueError, time.strftime, '%S\0', tt)
+
+    def test_strftime_special(self):
+        tt = time.gmtime(self.t)
+        s1 = time.strftime('%c', tt)
+        s2 = time.strftime('%B', tt)
+        # gh-52551, gh-78662: Unicode strings should pass through strftime,
+        # independently from locale.
+        self.assertEqual(time.strftime('\U0001f40d', tt), '\U0001f40d')
+        self.assertEqual(time.strftime('\U0001f4bb%c\U0001f40d%B', tt), 
f'\U0001f4bb{s1}\U0001f40d{s2}')
+        self.assertEqual(time.strftime('%c\U0001f4bb%B\U0001f40d', tt), 
f'{s1}\U0001f4bb{s2}\U0001f40d')
+        # Lone surrogates should pass through.
+        self.assertEqual(time.strftime('\ud83d', tt), '\ud83d')
+        self.assertEqual(time.strftime('\udc0d', tt), '\udc0d')
+        self.assertEqual(time.strftime('\ud83d%c\udc0d%B', tt), 
f'\ud83d{s1}\udc0d{s2}')
+        self.assertEqual(time.strftime('%c\ud83d%B\udc0d', tt), 
f'{s1}\ud83d{s2}\udc0d')
+        self.assertEqual(time.strftime('%c\udc0d%B\ud83d', tt), 
f'{s1}\udc0d{s2}\ud83d')
+        # Surrogate pairs should not recombine.
+        self.assertEqual(time.strftime('\ud83d\udc0d', tt), '\ud83d\udc0d')
+        self.assertEqual(time.strftime('%c\ud83d\udc0d%B', tt), 
f'{s1}\ud83d\udc0d{s2}')
+        # Surrogate-escaped bytes should not recombine.
+        self.assertEqual(time.strftime('\udcf0\udc9f\udc90\udc8d', tt), 
'\udcf0\udc9f\udc90\udc8d')
+        self.assertEqual(time.strftime('%c\udcf0\udc9f\udc90\udc8d%B', tt), 
f'{s1}\udcf0\udc9f\udc90\udc8d{s2}')
+        # gh-124531: The null character should not terminate the format string.
+        self.assertEqual(time.strftime('\0', tt), '\0')
+        self.assertEqual(time.strftime('\0'*1000, tt), '\0'*1000)
+        self.assertEqual(time.strftime('\0%c\0%B', tt), f'\0{s1}\0{s2}')
+        self.assertEqual(time.strftime('%c\0%B\0', tt), f'{s1}\0{s2}\0')
 
     def _bounds_checking(self, func):
         # Make sure that strftime() checks the bounds of the various parts
diff --git 
a/Misc/NEWS.d/next/Library/2024-10-09-17-07-33.gh-issue-52551.PBakSY.rst 
b/Misc/NEWS.d/next/Library/2024-10-09-17-07-33.gh-issue-52551.PBakSY.rst
new file mode 100644
index 00000000000000..edc9ac5bb23117
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-10-09-17-07-33.gh-issue-52551.PBakSY.rst
@@ -0,0 +1,8 @@
+Fix encoding issues in :func:`time.strftime`, the
+:meth:`~datetime.datetime.strftime` method of the :mod:`datetime` classes
+:class:`~datetime.datetime`, :class:`~datetime.date` and
+:class:`~datetime.time` and formatting of these classes. Characters not
+encodable in the current locale are now acceptable in the format string.
+Surrogate pairs and sequences of surrogateescape-encoded bytes are no longer
+recombined. An embedded null character no longer terminates the format
+string.
diff --git a/Modules/_datetimemodule.c b/Modules/_datetimemodule.c
index 2ba46cddb4f558..e1bb98fcf05862 100644
--- a/Modules/_datetimemodule.c
+++ b/Modules/_datetimemodule.c
@@ -1747,7 +1747,7 @@ make_somezreplacement(PyObject *object, char *sep, 
PyObject *tzinfoarg)
     PyObject *tzinfo = get_tzinfo_member(object);
 
     if (tzinfo == Py_None || tzinfo == NULL) {
-        return PyBytes_FromStringAndSize(NULL, 0);
+        return PyUnicode_FromStringAndSize(NULL, 0);
     }
 
     assert(tzinfoarg != NULL);
@@ -1758,7 +1758,7 @@ make_somezreplacement(PyObject *object, char *sep, 
PyObject *tzinfoarg)
                          tzinfoarg) < 0)
         return NULL;
 
-    return PyBytes_FromStringAndSize(buf, strlen(buf));
+    return PyUnicode_FromString(buf);
 }
 
 static PyObject *
@@ -1815,7 +1815,7 @@ make_freplacement(PyObject *object)
     else
         sprintf(freplacement, "%06d", 0);
 
-    return PyBytes_FromStringAndSize(freplacement, strlen(freplacement));
+    return PyUnicode_FromString(freplacement);
 }
 
 /* I sure don't want to reproduce the strftime code from the time module,
@@ -1836,94 +1836,60 @@ wrap_strftime(PyObject *object, PyObject *format, 
PyObject *timetuple,
     PyObject *Zreplacement = NULL;      /* py string, replacement for %Z */
     PyObject *freplacement = NULL;      /* py string, replacement for %f */
 
-    const char *pin;            /* pointer to next char in input format */
-    Py_ssize_t flen;            /* length of input format */
-    char ch;                    /* next char in input format */
-
-    PyObject *newfmt = NULL;            /* py string, the output format */
-    char *pnew;         /* pointer to available byte in output format */
-    size_t totalnew;            /* number bytes total in output format buffer,
-                               exclusive of trailing \0 */
-    size_t usednew;     /* number bytes used so far in output format buffer */
-
-    const char *ptoappend;      /* ptr to string to append to output buffer */
-    Py_ssize_t ntoappend;       /* # of bytes to append to output buffer */
-
-#ifdef Py_NORMALIZE_CENTURY
-    /* Buffer of maximum size of formatted year permitted by long. */
-    char buf[SIZEOF_LONG * 5 / 2 + 2
-#ifdef Py_STRFTIME_C99_SUPPORT
-    /* Need 6 more to accommodate dashes, 2-digit month and day for %F. */
-             + 6
-#endif
-    ];
-#endif
-
     assert(object && format && timetuple);
     assert(PyUnicode_Check(format));
-    /* Convert the input format to a C string and size */
-    pin = PyUnicode_AsUTF8AndSize(format, &flen);
-    if (!pin)
-        return NULL;
 
     PyObject *strftime = _PyImport_GetModuleAttrString("time", "strftime");
     if (strftime == NULL) {
-        goto Done;
+        return NULL;
     }
 
     /* Scan the input format, looking for %z/%Z/%f escapes, building
      * a new format.  Since computing the replacements for those codes
      * is expensive, don't unless they're actually used.
      */
-    if (flen > INT_MAX - 1) {
-        PyErr_NoMemory();
-        goto Done;
-    }
-
-    totalnew = flen + 1;        /* realistic if no %z/%Z */
-    newfmt = PyBytes_FromStringAndSize(NULL, totalnew);
-    if (newfmt == NULL) goto Done;
-    pnew = PyBytes_AsString(newfmt);
-    usednew = 0;
 
-    while ((ch = *pin++) != '\0') {
-        if (ch != '%') {
-            ptoappend = pin - 1;
-            ntoappend = 1;
+    _PyUnicodeWriter writer;
+    _PyUnicodeWriter_Init(&writer);
+    writer.overallocate = 1;
+
+    Py_ssize_t flen = PyUnicode_GET_LENGTH(format);
+    Py_ssize_t i = 0;
+    Py_ssize_t start = 0;
+    Py_ssize_t end = 0;
+    while (i != flen) {
+        i = PyUnicode_FindChar(format, '%', i, flen, 1);
+        if (i < 0) {
+            assert(!PyErr_Occurred());
+            break;
         }
-        else if ((ch = *pin++) == '\0') {
-        /* Null byte follows %, copy only '%'.
-         *
-         * Back the pin up one char so that we catch the null check
-         * the next time through the loop.*/
-            pin--;
-            ptoappend = pin - 1;
-            ntoappend = 1;
+        end = i;
+        i++;
+        if (i == flen) {
+            break;
         }
+        Py_UCS4 ch = PyUnicode_READ_CHAR(format, i);
+        i++;
         /* A % has been seen and ch is the character after it. */
-        else if (ch == 'z') {
+        PyObject *replacement = NULL;
+        if (ch == 'z') {
             /* %z -> +HHMM */
             if (zreplacement == NULL) {
                 zreplacement = make_somezreplacement(object, "", tzinfoarg);
                 if (zreplacement == NULL)
-                    goto Done;
+                    goto Error;
             }
-            assert(zreplacement != NULL);
-            assert(PyBytes_Check(zreplacement));
-            ptoappend = PyBytes_AS_STRING(zreplacement);
-            ntoappend = PyBytes_GET_SIZE(zreplacement);
+            replacement = zreplacement;
         }
-        else if (ch == ':' && *pin == 'z' && pin++) {
+        else if (ch == ':' && i < flen && PyUnicode_READ_CHAR(format, i) == 
'z') {
             /* %:z -> +HH:MM */
+            i++;
             if (colonzreplacement == NULL) {
                 colonzreplacement = make_somezreplacement(object, ":", 
tzinfoarg);
                 if (colonzreplacement == NULL)
-                    goto Done;
+                    goto Error;
             }
-            assert(colonzreplacement != NULL);
-            assert(PyBytes_Check(colonzreplacement));
-            ptoappend = PyBytes_AS_STRING(colonzreplacement);
-            ntoappend = PyBytes_GET_SIZE(colonzreplacement);
+            replacement = colonzreplacement;
         }
         else if (ch == 'Z') {
             /* format tzname */
@@ -1931,26 +1897,18 @@ wrap_strftime(PyObject *object, PyObject *format, 
PyObject *timetuple,
                 Zreplacement = make_Zreplacement(object,
                                                  tzinfoarg);
                 if (Zreplacement == NULL)
-                    goto Done;
+                    goto Error;
             }
-            assert(Zreplacement != NULL);
-            assert(PyUnicode_Check(Zreplacement));
-            ptoappend = PyUnicode_AsUTF8AndSize(Zreplacement,
-                                                  &ntoappend);
-            if (ptoappend == NULL)
-                goto Done;
+            replacement = Zreplacement;
         }
         else if (ch == 'f') {
             /* format microseconds */
             if (freplacement == NULL) {
                 freplacement = make_freplacement(object);
                 if (freplacement == NULL)
-                    goto Done;
+                    goto Error;
             }
-            assert(freplacement != NULL);
-            assert(PyBytes_Check(freplacement));
-            ptoappend = PyBytes_AS_STRING(freplacement);
-            ntoappend = PyBytes_GET_SIZE(freplacement);
+            replacement = freplacement;
         }
 #ifdef Py_NORMALIZE_CENTURY
         else if (ch == 'Y' || ch == 'G'
@@ -1961,100 +1919,102 @@ wrap_strftime(PyObject *object, PyObject *format, 
PyObject *timetuple,
             /* 0-pad year with century as necessary */
             PyObject *item = PySequence_GetItem(timetuple, 0);
             if (item == NULL) {
-                goto Done;
+                goto Error;
             }
             long year_long = PyLong_AsLong(item);
             Py_DECREF(item);
             if (year_long == -1 && PyErr_Occurred()) {
-                goto Done;
+                goto Error;
             }
             /* Note that datetime(1000, 1, 1).strftime('%G') == '1000' so year
                1000 for %G can go on the fast path. */
             if (year_long >= 1000) {
-                goto PassThrough;
+                continue;
             }
             if (ch == 'G') {
                 PyObject *year_str = PyObject_CallFunction(strftime, "sO",
                                                            "%G", timetuple);
                 if (year_str == NULL) {
-                    goto Done;
+                    goto Error;
                 }
                 PyObject *year = PyNumber_Long(year_str);
                 Py_DECREF(year_str);
                 if (year == NULL) {
-                    goto Done;
+                    goto Error;
                 }
                 year_long = PyLong_AsLong(year);
                 Py_DECREF(year);
                 if (year_long == -1 && PyErr_Occurred()) {
-                    goto Done;
+                    goto Error;
                 }
             }
-            ntoappend = PyOS_snprintf(buf, sizeof(buf),
+            /* Buffer of maximum size of formatted year permitted by long.
+             * +6 to accommodate dashes, 2-digit month and day for %F. */
+            char buf[SIZEOF_LONG * 5 / 2 + 2 + 6];
+            Py_ssize_t n = PyOS_snprintf(buf, sizeof(buf),
 #ifdef Py_STRFTIME_C99_SUPPORT
                                       ch == 'F' ? "%04ld-%%m-%%d" :
 #endif
                                       "%04ld", year_long);
 #ifdef Py_STRFTIME_C99_SUPPORT
             if (ch == 'C') {
-                ntoappend -= 2;
+                n -= 2;
             }
 #endif
-            ptoappend = buf;
+            if (_PyUnicodeWriter_WriteSubstring(&writer, format, start, end) < 
0) {
+                goto Error;
+            }
+            start = i;
+            if (_PyUnicodeWriter_WriteASCIIString(&writer, buf, n) < 0) {
+                goto Error;
+            }
+            continue;
         }
 #endif
         else {
             /* percent followed by something else */
-#ifdef Py_NORMALIZE_CENTURY
- PassThrough:
-#endif
-            ptoappend = pin - 2;
-            ntoappend = 2;
-        }
-
-        /* Append the ntoappend chars starting at ptoappend to
-         * the new format.
-         */
-        if (ntoappend == 0)
             continue;
-        assert(ptoappend != NULL);
-        assert(ntoappend > 0);
-        while (usednew + ntoappend > totalnew) {
-            if (totalnew > (PY_SSIZE_T_MAX >> 1)) { /* overflow */
-                PyErr_NoMemory();
-                goto Done;
-            }
-            totalnew <<= 1;
-            if (_PyBytes_Resize(&newfmt, totalnew) < 0)
-                goto Done;
-            pnew = PyBytes_AsString(newfmt) + usednew;
         }
-        memcpy(pnew, ptoappend, ntoappend);
-        pnew += ntoappend;
-        usednew += ntoappend;
-        assert(usednew <= totalnew);
+        assert(replacement != NULL);
+        assert(PyUnicode_Check(replacement));
+        if (_PyUnicodeWriter_WriteSubstring(&writer, format, start, end) < 0) {
+            goto Error;
+        }
+        start = i;
+        if (_PyUnicodeWriter_WriteStr(&writer, replacement) < 0) {
+            goto Error;
+        }
     }  /* end while() */
 
-    if (_PyBytes_Resize(&newfmt, usednew) < 0)
-        goto Done;
-    {
-        PyObject *format;
-
-        format = PyUnicode_FromString(PyBytes_AS_STRING(newfmt));
-        if (format != NULL) {
-            result = PyObject_CallFunctionObjArgs(strftime,
-                                                   format, timetuple, NULL);
-            Py_DECREF(format);
+    PyObject *newformat;
+    if (start == 0) {
+        _PyUnicodeWriter_Dealloc(&writer);
+        newformat = Py_NewRef(format);
+    }
+    else {
+        if (_PyUnicodeWriter_WriteSubstring(&writer, format, start, flen) < 0) 
{
+            goto Error;
+        }
+        newformat = _PyUnicodeWriter_Finish(&writer);
+        if (newformat == NULL) {
+            goto Done;
         }
     }
+    result = PyObject_CallFunctionObjArgs(strftime,
+                                          newformat, timetuple, NULL);
+    Py_DECREF(newformat);
+
  Done:
     Py_XDECREF(freplacement);
     Py_XDECREF(zreplacement);
     Py_XDECREF(colonzreplacement);
     Py_XDECREF(Zreplacement);
-    Py_XDECREF(newfmt);
     Py_XDECREF(strftime);
     return result;
+
+ Error:
+    _PyUnicodeWriter_Dealloc(&writer);
+    goto Done;
 }
 
 /* ---------------------------------------------------------------------------
diff --git a/Modules/timemodule.c b/Modules/timemodule.c
index 9720c201a184a8..b9d114ada0dfcd 100644
--- a/Modules/timemodule.c
+++ b/Modules/timemodule.c
@@ -776,27 +776,100 @@ the C library strftime function.\n"
 #endif
 
 static PyObject *
-time_strftime(PyObject *module, PyObject *args)
+time_strftime1(time_char **outbuf, size_t *bufsize,
+               time_char *format, size_t fmtlen,
+               struct tm *tm)
 {
-    PyObject *tup = NULL;
-    struct tm buf;
-    const time_char *fmt;
+    size_t buflen;
+#if defined(MS_WINDOWS) && !defined(HAVE_WCSFTIME)
+    /* check that the format string contains only valid directives */
+    for (const time_char *f = strchr(format, '%');
+        f != NULL;
+        f = strchr(f + 2, '%'))
+    {
+        if (f[1] == '#')
+            ++f; /* not documented by python, */
+        if (f[1] == '\0')
+            break;
+        if ((f[1] == 'y') && tm->tm_year < 0) {
+            PyErr_SetString(PyExc_ValueError,
+                            "format %y requires year >= 1900 on Windows");
+            return NULL;
+        }
+    }
+#elif (defined(_AIX) || (defined(__sun) && defined(__SVR4))) && 
defined(HAVE_WCSFTIME)
+    for (const time_char *f = wcschr(format, '%');
+        f != NULL;
+        f = wcschr(f + 2, '%'))
+    {
+        if (f[1] == L'\0')
+            break;
+        /* Issue #19634: On AIX, wcsftime("y", (1899, 1, 1, 0, 0, 0, 0, 0, 0))
+           returns "0/" instead of "99" */
+        if (f[1] == L'y' && tm->tm_year < 0) {
+            PyErr_SetString(PyExc_ValueError,
+                            "format %y requires year >= 1900 on AIX");
+            return NULL;
+        }
+    }
+#endif
+
+    /* I hate these functions that presume you know how big the output
+     * will be ahead of time...
+     */
+    while (1) {
+        if (*bufsize > PY_SSIZE_T_MAX/sizeof(time_char)) {
+            PyErr_NoMemory();
+            return NULL;
+        }
+        *outbuf = (time_char *)PyMem_Realloc(*outbuf,
+                                             *bufsize*sizeof(time_char));
+        if (*outbuf == NULL) {
+            PyErr_NoMemory();
+            return NULL;
+        }
+#if defined _MSC_VER && _MSC_VER >= 1400 && defined(__STDC_SECURE_LIB__)
+        errno = 0;
+#endif
+        _Py_BEGIN_SUPPRESS_IPH
+        buflen = format_time(*outbuf, *bufsize, format, tm);
+        _Py_END_SUPPRESS_IPH
+#if defined _MSC_VER && _MSC_VER >= 1400 && defined(__STDC_SECURE_LIB__)
+        /* VisualStudio .NET 2005 does this properly */
+        if (buflen == 0 && errno == EINVAL) {
+            PyErr_SetString(PyExc_ValueError, "Invalid format string");
+            return NULL;
+        }
+#endif
+        if (buflen == 0 && *bufsize < 256 * fmtlen) {
+            *bufsize += *bufsize;
+            continue;
+        }
+        /* If the buffer is 256 times as long as the format,
+           it's probably not failing for lack of room!
+           More likely, the format yields an empty result,
+           e.g. an empty format, or %Z when the timezone
+           is unknown. */
 #ifdef HAVE_WCSFTIME
-    wchar_t *format;
+        return PyUnicode_FromWideChar(*outbuf, buflen);
 #else
-    PyObject *format;
+        return PyUnicode_DecodeLocaleAndSize(*outbuf, buflen, 
"surrogateescape");
 #endif
+    }
+}
+
+static PyObject *
+time_strftime(PyObject *module, PyObject *args)
+{
+    PyObject *tup = NULL;
+    struct tm buf;
     PyObject *format_arg;
-    size_t fmtlen, buflen;
-    time_char *outbuf = NULL;
-    size_t i;
-    PyObject *ret = NULL;
+    Py_ssize_t format_size;
+    time_char *format, *outbuf = NULL;
+    size_t fmtlen, bufsize = 1024;
 
     memset((void *) &buf, '\0', sizeof(buf));
 
-    /* Will always expect a unicode string to be passed as format.
-       Given that there's no str type anymore in py3k this seems safe.
-    */
     if (!PyArg_ParseTuple(args, "U|O:strftime", &format_arg, &tup))
         return NULL;
 
@@ -834,101 +907,63 @@ time_strftime(PyObject *module, PyObject *args)
     else if (buf.tm_isdst > 1)
         buf.tm_isdst = 1;
 
-#ifdef HAVE_WCSFTIME
-    format = PyUnicode_AsWideCharString(format_arg, NULL);
-    if (format == NULL)
+    format_size = PyUnicode_GET_LENGTH(format_arg);
+    if ((size_t)format_size > PY_SSIZE_T_MAX/sizeof(time_char) - 1) {
+        PyErr_NoMemory();
         return NULL;
-    fmt = format;
-#else
-    /* Convert the unicode string to an ascii one */
-    format = PyUnicode_EncodeLocale(format_arg, "surrogateescape");
-    if (format == NULL)
+    }
+    format = PyMem_Malloc((format_size + 1)*sizeof(time_char));
+    if (format == NULL) {
+        PyErr_NoMemory();
         return NULL;
-    fmt = PyBytes_AS_STRING(format);
-#endif
-
-#if defined(MS_WINDOWS) && !defined(HAVE_WCSFTIME)
-    /* check that the format string contains only valid directives */
-    for (outbuf = strchr(fmt, '%');
-        outbuf != NULL;
-        outbuf = strchr(outbuf+2, '%'))
-    {
-        if (outbuf[1] == '#')
-            ++outbuf; /* not documented by python, */
-        if (outbuf[1] == '\0')
-            break;
-        if ((outbuf[1] == 'y') && buf.tm_year < 0) {
-            PyErr_SetString(PyExc_ValueError,
-                        "format %y requires year >= 1900 on Windows");
-            Py_DECREF(format);
-            return NULL;
-        }
     }
-#elif (defined(_AIX) || (defined(__sun) && defined(__SVR4))) && 
defined(HAVE_WCSFTIME)
-    for (outbuf = wcschr(fmt, '%');
-        outbuf != NULL;
-        outbuf = wcschr(outbuf+2, '%'))
-    {
-        if (outbuf[1] == L'\0')
-            break;
-        /* Issue #19634: On AIX, wcsftime("y", (1899, 1, 1, 0, 0, 0, 0, 0, 0))
-           returns "0/" instead of "99" */
-        if (outbuf[1] == L'y' && buf.tm_year < 0) {
-            PyErr_SetString(PyExc_ValueError,
-                            "format %y requires year >= 1900 on AIX");
-            PyMem_Free(format);
-            return NULL;
+    _PyUnicodeWriter writer;
+    _PyUnicodeWriter_Init(&writer);
+    writer.overallocate = 1;
+    Py_ssize_t i = 0;
+    while (i < format_size) {
+        fmtlen = 0;
+        for (; i < format_size; i++) {
+            Py_UCS4 c = PyUnicode_READ_CHAR(format_arg, i);
+            if (!c || c > 127) {
+                break;
+            }
+            format[fmtlen++] = (char)c;
         }
-    }
-#endif
-
-    fmtlen = time_strlen(fmt);
-
-    /* I hate these functions that presume you know how big the output
-     * will be ahead of time...
-     */
-    for (i = 1024; ; i += i) {
-        outbuf = (time_char *)PyMem_Malloc(i*sizeof(time_char));
-        if (outbuf == NULL) {
-            PyErr_NoMemory();
-            break;
+        if (fmtlen) {
+            format[fmtlen] = 0;
+            PyObject *unicode = time_strftime1(&outbuf, &bufsize,
+                                               format, fmtlen, &buf);
+            if (unicode == NULL) {
+                goto error;
+            }
+            if (_PyUnicodeWriter_WriteStr(&writer, unicode) < 0) {
+                Py_DECREF(unicode);
+                goto error;
+            }
+            Py_DECREF(unicode);
         }
-#if defined _MSC_VER && _MSC_VER >= 1400 && defined(__STDC_SECURE_LIB__)
-        errno = 0;
-#endif
-        _Py_BEGIN_SUPPRESS_IPH
-        buflen = format_time(outbuf, i, fmt, &buf);
-        _Py_END_SUPPRESS_IPH
-#if defined _MSC_VER && _MSC_VER >= 1400 && defined(__STDC_SECURE_LIB__)
-        /* VisualStudio .NET 2005 does this properly */
-        if (buflen == 0 && errno == EINVAL) {
-            PyErr_SetString(PyExc_ValueError, "Invalid format string");
-            PyMem_Free(outbuf);
-            break;
+
+        Py_ssize_t start = i;
+        for (; i < format_size; i++) {
+            Py_UCS4 c = PyUnicode_READ_CHAR(format_arg, i);
+            if (c == '%') {
+                break;
+            }
         }
-#endif
-        if (buflen > 0 || i >= 256 * fmtlen) {
-            /* If the buffer is 256 times as long as the format,
-               it's probably not failing for lack of room!
-               More likely, the format yields an empty result,
-               e.g. an empty format, or %Z when the timezone
-               is unknown. */
-#ifdef HAVE_WCSFTIME
-            ret = PyUnicode_FromWideChar(outbuf, buflen);
-#else
-            ret = PyUnicode_DecodeLocaleAndSize(outbuf, buflen, 
"surrogateescape");
-#endif
-            PyMem_Free(outbuf);
-            break;
+        if (_PyUnicodeWriter_WriteSubstring(&writer, format_arg, start, i) < 
0) {
+            goto error;
         }
-        PyMem_Free(outbuf);
     }
-#ifdef HAVE_WCSFTIME
+
+    PyMem_Free(outbuf);
     PyMem_Free(format);
-#else
-    Py_DECREF(format);
-#endif
-    return ret;
+    return _PyUnicodeWriter_Finish(&writer);
+error:
+    PyMem_Free(outbuf);
+    PyMem_Free(format);
+    _PyUnicodeWriter_Dealloc(&writer);
+    return NULL;
 }
 
 #undef time_char

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to