https://github.com/python/cpython/commit/4b44b3409ac026e7f13054a3daa18ab7ee14d85c
commit: 4b44b3409ac026e7f13054a3daa18ab7ee14d85c
branch: main
author: Emma Smith <e...@emmatyping.dev>
committer: serhiy-storchaka <storch...@gmail.com>
date: 2025-06-05T14:31:49+03:00
summary:

gh-134938: Add set_pledged_input_size() to ZstdCompressor (GH-135010)

files:
M Doc/library/compression.zstd.rst
M Lib/test/test_zstd.py
M Modules/_zstd/_zstdmodule.c
M Modules/_zstd/_zstdmodule.h
M Modules/_zstd/clinic/compressor.c.h
M Modules/_zstd/compressor.c

diff --git a/Doc/library/compression.zstd.rst b/Doc/library/compression.zstd.rst
index 35bcbc2bfd8eac..57ad8e3377fc67 100644
--- a/Doc/library/compression.zstd.rst
+++ b/Doc/library/compression.zstd.rst
@@ -247,6 +247,27 @@ Compressing and decompressing data in memory
       The *mode* argument is a :class:`ZstdCompressor` attribute, either
       :attr:`~.FLUSH_BLOCK`, or :attr:`~.FLUSH_FRAME`.
 
+   .. method:: set_pledged_input_size(size)
+
+      Specify the amount of uncompressed data *size* that will be provided for
+      the next frame. *size* will be written into the frame header of the next
+      frame unless :attr:`CompressionParameter.content_size_flag` is ``False``
+      or ``0``. A size of ``0`` means that the frame is empty. If *size* is
+      ``None``, the frame header will omit the frame size. Frames that include
+      the uncompressed data size require less memory to decompress, especially
+      at higher compression levels.
+
+      If :attr:`last_mode` is not :attr:`FLUSH_FRAME`, a
+      :exc:`ValueError` is raised as the compressor is not at the start of
+      a frame. If the pledged size does not match the actual size of data
+      provided to :meth:`.compress`, future calls to :meth:`!compress` or
+      :meth:`flush` may raise :exc:`ZstdError` and the last chunk of data may
+      be lost.
+
+      After :meth:`flush` or :meth:`.compress` are called with mode
+      :attr:`FLUSH_FRAME`, the next frame will not include the frame size in
+      the header unless :meth:`!set_pledged_input_size` is called again.
+
    .. attribute:: CONTINUE
 
       Collect more data for compression, which may or may not generate output
@@ -266,6 +287,13 @@ Compressing and decompressing data in memory
       :meth:`~.compress` will be written into a new frame and
       *cannot* reference past data.
 
+   .. attribute:: last_mode
+
+      The last mode passed to either :meth:`~.compress` or :meth:`~.flush`.
+      The value can be one of :attr:`~.CONTINUE`, :attr:`~.FLUSH_BLOCK`, or
+      :attr:`~.FLUSH_FRAME`. The initial value is :attr:`~.FLUSH_FRAME`,
+      signifying that the compressor is at the start of a new frame.
+
 
 .. class:: ZstdDecompressor(zstd_dict=None, options=None)
 
@@ -620,12 +648,17 @@ Advanced parameter control
       Write the size of the data to be compressed into the Zstandard frame
       header when known prior to compressing.
 
-      This flag only takes effect under the following two scenarios:
+      This flag only takes effect under the following scenarios:
 
       * Calling :func:`compress` for one-shot compression
       * Providing all of the data to be compressed in the frame in a single
         :meth:`ZstdCompressor.compress` call, with the
         :attr:`ZstdCompressor.FLUSH_FRAME` mode.
+      * Calling :meth:`ZstdCompressor.set_pledged_input_size` with the exact
+        amount of data that will be provided to the compressor prior to any
+        calls to :meth:`ZstdCompressor.compress` for the current frame.
+        :meth:`!ZstdCompressor.set_pledged_input_size` must be called for each
+        new frame.
 
       All other compression calls may not write the size information into the
       frame header.
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
index e475d9346b9594..14a09a886046f7 100644
--- a/Lib/test/test_zstd.py
+++ b/Lib/test/test_zstd.py
@@ -395,6 +395,115 @@ def test_compress_empty(self):
         c = ZstdCompressor()
         self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')
 
+    def test_set_pledged_input_size(self):
+        DAT = DECOMPRESSED_100_PLUS_32KB
+        CHUNK_SIZE = len(DAT) // 3
+
+        # wrong value
+        c = ZstdCompressor()
+        with self.assertRaisesRegex(ValueError,
+                                    r'should be a positive int less than \d+'):
+            c.set_pledged_input_size(-300)
+        # overflow
+        with self.assertRaisesRegex(ValueError,
+                                    r'should be a positive int less than \d+'):
+            c.set_pledged_input_size(2**64)
+        # ZSTD_CONTENTSIZE_ERROR is invalid
+        with self.assertRaisesRegex(ValueError,
+                                    r'should be a positive int less than \d+'):
+            c.set_pledged_input_size(2**64-2)
+        # ZSTD_CONTENTSIZE_UNKNOWN should use None
+        with self.assertRaisesRegex(ValueError,
+                                    r'should be a positive int less than \d+'):
+            c.set_pledged_input_size(2**64-1)
+
+        # check valid values are settable
+        c.set_pledged_input_size(2**63)
+        c.set_pledged_input_size(2**64-3)
+
+        # check that zero means empty frame
+        c = ZstdCompressor(level=1)
+        c.set_pledged_input_size(0)
+        c.compress(b'')
+        dat = c.flush()
+        ret = get_frame_info(dat)
+        self.assertEqual(ret.decompressed_size, 0)
+
+
+        # wrong mode
+        c = ZstdCompressor(level=1)
+        c.compress(b'123456')
+        self.assertEqual(c.last_mode, c.CONTINUE)
+        with self.assertRaisesRegex(ValueError,
+                                    r'last_mode == FLUSH_FRAME'):
+            c.set_pledged_input_size(300)
+
+        # None value
+        c = ZstdCompressor(level=1)
+        c.set_pledged_input_size(None)
+        dat = c.compress(DAT) + c.flush()
+
+        ret = get_frame_info(dat)
+        self.assertEqual(ret.decompressed_size, None)
+
+        # correct value
+        c = ZstdCompressor(level=1)
+        c.set_pledged_input_size(len(DAT))
+
+        chunks = []
+        posi = 0
+        while posi < len(DAT):
+            dat = c.compress(DAT[posi:posi+CHUNK_SIZE])
+            posi += CHUNK_SIZE
+            chunks.append(dat)
+
+        dat = c.flush()
+        chunks.append(dat)
+        chunks = b''.join(chunks)
+
+        ret = get_frame_info(chunks)
+        self.assertEqual(ret.decompressed_size, len(DAT))
+        self.assertEqual(decompress(chunks), DAT)
+
+        c.set_pledged_input_size(len(DAT)) # the second frame
+        dat = c.compress(DAT) + c.flush()
+
+        ret = get_frame_info(dat)
+        self.assertEqual(ret.decompressed_size, len(DAT))
+        self.assertEqual(decompress(dat), DAT)
+
+        # not enough data
+        c = ZstdCompressor(level=1)
+        c.set_pledged_input_size(len(DAT)+1)
+
+        for start in range(0, len(DAT), CHUNK_SIZE):
+            end = min(start+CHUNK_SIZE, len(DAT))
+            _dat = c.compress(DAT[start:end])
+
+        with self.assertRaises(ZstdError):
+            c.flush()
+
+        # too much data
+        c = ZstdCompressor(level=1)
+        c.set_pledged_input_size(len(DAT))
+
+        for start in range(0, len(DAT), CHUNK_SIZE):
+            end = min(start+CHUNK_SIZE, len(DAT))
+            _dat = c.compress(DAT[start:end])
+
+        with self.assertRaises(ZstdError):
+            c.compress(b'extra', ZstdCompressor.FLUSH_FRAME)
+
+        # content size not set if content_size_flag == 0
+        c = ZstdCompressor(options={CompressionParameter.content_size_flag: 0})
+        c.set_pledged_input_size(10)
+        dat1 = c.compress(b"hello")
+        dat2 = c.compress(b"world")
+        dat3 = c.flush()
+        frame_data = get_frame_info(dat1 + dat2 + dat3)
+        self.assertIsNone(frame_data.decompressed_size)
+
+
 class DecompressorTestCase(unittest.TestCase):
 
     def test_simple_decompress_bad_args(self):
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c
index b0e50f873f4ca6..d75c0779474a82 100644
--- a/Modules/_zstd/_zstdmodule.c
+++ b/Modules/_zstd/_zstdmodule.c
@@ -72,6 +72,9 @@ set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
         case ERR_COMPRESS:
             msg = "Unable to compress Zstandard data: %s";
             break;
+        case ERR_SET_PLEDGED_INPUT_SIZE:
+            msg = "Unable to set pledged uncompressed content size: %s";
+            break;
 
         case ERR_LOAD_D_DICT:
             msg = "Unable to load Zstandard dictionary or prefix for "
diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h
index c73f15b3c5299b..4e8f708f2232c7 100644
--- a/Modules/_zstd/_zstdmodule.h
+++ b/Modules/_zstd/_zstdmodule.h
@@ -27,6 +27,7 @@ typedef struct {
 typedef enum {
     ERR_DECOMPRESS,
     ERR_COMPRESS,
+    ERR_SET_PLEDGED_INPUT_SIZE,
 
     ERR_LOAD_D_DICT,
     ERR_LOAD_C_DICT,
diff --git a/Modules/_zstd/clinic/compressor.c.h b/Modules/_zstd/clinic/compressor.c.h
index f69161b590e5b7..4f8d93fd9e867c 100644
--- a/Modules/_zstd/clinic/compressor.c.h
+++ b/Modules/_zstd/clinic/compressor.c.h
@@ -252,4 +252,43 @@ _zstd_ZstdCompressor_flush(PyObject *self, PyObject *const *args, Py_ssize_t nar
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=ee2d1dc298de790c input=a9049054013a1b77]*/
+
+PyDoc_STRVAR(_zstd_ZstdCompressor_set_pledged_input_size__doc__,
+"set_pledged_input_size($self, size, /)\n"
+"--\n"
+"\n"
+"Set the uncompressed content size to be written into the frame header.\n"
+"\n"
+"  size\n"
+"    The size of the uncompressed data to be provided to the compressor.\n"
+"\n"
+"This method can be used to ensure the header of the frame about to be written\n"
+"includes the size of the data, unless the CompressionParameter.content_size_flag\n"
+"is set to False. If last_mode != FLUSH_FRAME, then a RuntimeError is raised.\n"
+"\n"
+"It is important to ensure that the pledged data size matches the actual data\n"
+"size. If they do not match the compressed output data may be corrupted and the\n"
+"final chunk written may be lost.");
+
+#define _ZSTD_ZSTDCOMPRESSOR_SET_PLEDGED_INPUT_SIZE_METHODDEF    \
+    {"set_pledged_input_size", (PyCFunction)_zstd_ZstdCompressor_set_pledged_input_size, METH_O, _zstd_ZstdCompressor_set_pledged_input_size__doc__},
+
+static PyObject *
+_zstd_ZstdCompressor_set_pledged_input_size_impl(ZstdCompressor *self,
+                                                 unsigned long long size);
+
+static PyObject *
+_zstd_ZstdCompressor_set_pledged_input_size(PyObject *self, PyObject *arg)
+{
+    PyObject *return_value = NULL;
+    unsigned long long size;
+
+    if (!zstd_contentsize_converter(arg, &size)) {
+        goto exit;
+    }
+    return_value = _zstd_ZstdCompressor_set_pledged_input_size_impl((ZstdCompressor *)self, size);
+
+exit:
+    return return_value;
+}
+/*[clinic end generated code: output=c1d5c2cf06a8becd input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c
index e1217635f60cb0..bc9e6eff89af68 100644
--- a/Modules/_zstd/compressor.c
+++ b/Modules/_zstd/compressor.c
@@ -45,6 +45,52 @@ typedef struct {
 
 #define ZstdCompressor_CAST(op) ((ZstdCompressor *)op)
 
+/*[python input]
+
+class zstd_contentsize_converter(CConverter):
+    type = 'unsigned long long'
+    converter = 'zstd_contentsize_converter'
+
+[python start generated code]*/
+/*[python end generated code: output=da39a3ee5e6b4b0d input=0932c350d633c7de]*/
+
+/* None -> ZSTD_CONTENTSIZE_UNKNOWN; int in [0, ZSTD_CONTENTSIZE_ERROR) as-is. */
+static int
+zstd_contentsize_converter(PyObject *size, unsigned long long *p)
+{
+    // None means the user indicates the size is unknown.
+    if (size == Py_None) {
+        *p = ZSTD_CONTENTSIZE_UNKNOWN;
+    }
+    else {
+        /* ZSTD_CONTENTSIZE_UNKNOWN is 0ULL - 1
+           ZSTD_CONTENTSIZE_ERROR   is 0ULL - 2
+           Users should only pass values < ZSTD_CONTENTSIZE_ERROR */
+        unsigned long long pledged_size = PyLong_AsUnsignedLongLong(size);
+        /* (unsigned long long)-1 plus a pending exception signals an error in
+           PyLong_AsUnsignedLongLong (OverflowError if negative or too large) */
+        if (pledged_size == (unsigned long long)-1 && PyErr_Occurred()) {
+            *p = ZSTD_CONTENTSIZE_ERROR;
+            if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+                PyErr_Format(PyExc_ValueError,
+                             "size argument should be a positive int less "
+                             "than %llu", ZSTD_CONTENTSIZE_ERROR);
+            }
+            /* Any other error (e.g. TypeError) propagates unchanged. */
+            return 0;
+        }
+        if (pledged_size >= ZSTD_CONTENTSIZE_ERROR) {
+            *p = ZSTD_CONTENTSIZE_ERROR;
+            PyErr_Format(PyExc_ValueError,
+                         "size argument should be a positive int less "
+                         "than %llu", ZSTD_CONTENTSIZE_ERROR);
+            return 0;
+        }
+        *p = pledged_size;
+    }
+    return 1;
+}
+
 #include "clinic/compressor.c.h"
 
 static int
@@ -643,9 +689,61 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
     return ret;
 }
 
+
+/*[clinic input]
+_zstd.ZstdCompressor.set_pledged_input_size
+
+    size: zstd_contentsize
+        The size of the uncompressed data to be provided to the compressor.
+    /
+
+Set the uncompressed content size to be written into the frame header.
+
+This method can be used to ensure the header of the frame about to be written
+includes the size of the data, unless the CompressionParameter.content_size_flag
+is set to False. If last_mode != FLUSH_FRAME, then a ValueError is raised.
+
+It is important to ensure that the pledged data size matches the actual data
+size. If they do not match the compressed output data may be corrupted and the
+final chunk written may be lost.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdCompressor_set_pledged_input_size_impl(ZstdCompressor *self,
+                                                 unsigned long long size)
+/*[clinic end generated code: output=3a09e55cc0e3b4f9 input=afd8a7d78cff2eb5]*/
+{
+    // zstd_contentsize_converter() rejects ZSTD_CONTENTSIZE_ERROR: unreachable
+    assert(size != ZSTD_CONTENTSIZE_ERROR);
+
+    /* Hold self->lock while inspecting last_mode and configuring cctx */
+    PyMutex_Lock(&self->lock);
+
+    /* A pledge is only accepted at a frame boundary (after FLUSH_FRAME) */
+    if (self->last_mode != ZSTD_e_end) {
+        PyErr_SetString(PyExc_ValueError,
+                        "set_pledged_input_size() method must be called "
+                        "when last_mode == FLUSH_FRAME");
+        PyMutex_Unlock(&self->lock);
+        return NULL;
+    }
+
+    /* Set pledged content size (None arrives as ZSTD_CONTENTSIZE_UNKNOWN) */
+    size_t zstd_ret = ZSTD_CCtx_setPledgedSrcSize(self->cctx, size);
+    PyMutex_Unlock(&self->lock);
+    if (ZSTD_isError(zstd_ret)) {
+        _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+        set_zstd_error(mod_state, ERR_SET_PLEDGED_INPUT_SIZE, zstd_ret);
+        return NULL;
+    }
+
+    Py_RETURN_NONE;
+}
+
 static PyMethodDef ZstdCompressor_methods[] = {
     _ZSTD_ZSTDCOMPRESSOR_COMPRESS_METHODDEF
     _ZSTD_ZSTDCOMPRESSOR_FLUSH_METHODDEF
+    _ZSTD_ZSTDCOMPRESSOR_SET_PLEDGED_INPUT_SIZE_METHODDEF  /* clinic/compressor.c.h */
     {NULL, NULL}
 };
 

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: arch...@mail-archive.com

Reply via email to