https://github.com/python/cpython/commit/81675941fb51f09b89ba46bce237222a87ea60f6
commit: 81675941fb51f09b89ba46bce237222a87ea60f6
branch: 3.14
author: Miss Islington (bot) <31488909+miss-isling...@users.noreply.github.com>
committer: colesbury <colesb...@gmail.com>
date: 2025-05-23T10:00:38Z
summary:

[3.14] gh-133885: Use locks instead of critical sections for _zstd (gh-134289) (gh-134560)

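Move from using critical sections to locks for the (de)compression methods
and for the digested-dictionary cache of ZstdDict.
Since these methods allow other threads to run (they release the thread
state, e.g. via Py_BEGIN_ALLOW_THREADS), and an active critical section may
be temporarily released while its thread is detached, a critical section
cannot keep other threads out for the whole operation; we should use a lock
rather than a critical section.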
(cherry picked from commit 8dbc11971974a725dc8a11c0dc65d8f6fcb4d902)

Co-authored-by: Emma Smith <e...@emmatyping.dev>

files:
M Lib/test/test_zstd.py
M Modules/_zstd/clinic/decompressor.c.h
M Modules/_zstd/clinic/zstddict.c.h
M Modules/_zstd/compressor.c
M Modules/_zstd/decompressor.c
M Modules/_zstd/zstddict.c
M Modules/_zstd/zstddict.h

diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
index 53ca592ea38828..084f8f24fc009c 100644
--- a/Lib/test/test_zstd.py
+++ b/Lib/test/test_zstd.py
@@ -2430,10 +2430,8 @@ def test_buffer_protocol(self):
             self.assertEqual(f.write(arr), LENGTH)
             self.assertEqual(f.tell(), LENGTH)
 
-@unittest.skip("it fails for now, see gh-133885")
 class FreeThreadingMethodTests(unittest.TestCase):
 
-    @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
     @threading_helper.reap_threads
     @threading_helper.requires_working_threading()
     def test_compress_locking(self):
@@ -2470,7 +2468,6 @@ def run_method(method, input_data, output_data):
         actual = b''.join(output) + rest2
         self.assertEqual(expected, actual)
 
-    @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
     @threading_helper.reap_threads
     @threading_helper.requires_working_threading()
     def test_decompress_locking(self):
@@ -2506,6 +2503,59 @@ def run_method(method, input_data, output_data):
         actual = b''.join(output)
         self.assertEqual(expected, actual)
 
+    @threading_helper.reap_threads
+    @threading_helper.requires_working_threading()
+    def test_compress_shared_dict(self):
+        num_threads = 8
+
+        def run_method(b):
+            level = threading.get_ident() % 4
+            # sync threads to increase chance of contention on
+            # capsule storing dictionary levels
+            b.wait()
+            ZstdCompressor(level=level,
+                           zstd_dict=TRAINED_DICT.as_digested_dict)
+            b.wait()
+            ZstdCompressor(level=level,
+                           zstd_dict=TRAINED_DICT.as_undigested_dict)
+            b.wait()
+            ZstdCompressor(level=level,
+                           zstd_dict=TRAINED_DICT.as_prefix)
+        threads = []
+
+        b = threading.Barrier(num_threads)
+        for i in range(num_threads):
+            thread = threading.Thread(target=run_method, args=(b,))
+
+            threads.append(thread)
+
+        with threading_helper.start_threads(threads):
+            pass
+
+    @threading_helper.reap_threads
+    @threading_helper.requires_working_threading()
+    def test_decompress_shared_dict(self):
+        num_threads = 8
+
+        def run_method(b):
+            # sync threads to increase chance of contention on
+            # decompression dictionary
+            b.wait()
+            ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict)
+            b.wait()
+            ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict)
+            b.wait()
+            ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix)
+        threads = []
+
+        b = threading.Barrier(num_threads)
+        for i in range(num_threads):
+            thread = threading.Thread(target=run_method, args=(b,))
+
+            threads.append(thread)
+
+        with threading_helper.start_threads(threads):
+            pass
 
 
 if __name__ == "__main__":
diff --git a/Modules/_zstd/clinic/decompressor.c.h b/Modules/_zstd/clinic/decompressor.c.h
index 4ecb19e9bde6ed..c6fdae74ab0447 100644
--- a/Modules/_zstd/clinic/decompressor.c.h
+++ b/Modules/_zstd/clinic/decompressor.c.h
@@ -7,7 +7,6 @@ preserve
 #  include "pycore_runtime.h"     // _Py_ID()
 #endif
 #include "pycore_abstract.h"      // _PyNumber_Index()
-#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
 #include "pycore_modsupport.h"    // _PyArg_UnpackKeywords()
 
 PyDoc_STRVAR(_zstd_ZstdDecompressor_new__doc__,
@@ -114,13 +113,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self);
 static PyObject *
 _zstd_ZstdDecompressor_unused_data_get(PyObject *self, void *Py_UNUSED(context))
 {
-    PyObject *return_value = NULL;
-
-    Py_BEGIN_CRITICAL_SECTION(self);
-    return_value = _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self);
-    Py_END_CRITICAL_SECTION();
-
-    return return_value;
+    return _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self);
 }
 
 PyDoc_STRVAR(_zstd_ZstdDecompressor_decompress__doc__,
@@ -227,4 +220,4 @@ _zstd_ZstdDecompressor_decompress(PyObject *self, PyObject *const *args, Py_ssiz
 
     return return_value;
 }
-/*[clinic end generated code: output=7a4d278f9244e684 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=30c12ef047027ede input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/clinic/zstddict.c.h b/Modules/_zstd/clinic/zstddict.c.h
index 34e0e4b3ecfe72..aaa29e491bc1bb 100644
--- a/Modules/_zstd/clinic/zstddict.c.h
+++ b/Modules/_zstd/clinic/zstddict.c.h
@@ -6,7 +6,6 @@ preserve
 #  include "pycore_gc.h"          // PyGC_Head
 #  include "pycore_runtime.h"     // _Py_ID()
 #endif
-#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
 #include "pycore_modsupport.h"    // _PyArg_UnpackKeywords()
 
 PyDoc_STRVAR(_zstd_ZstdDict_new__doc__,
@@ -118,13 +117,7 @@ _zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self);
 static PyObject *
 _zstd_ZstdDict_as_digested_dict_get(PyObject *self, void *Py_UNUSED(context))
 {
-    PyObject *return_value = NULL;
-
-    Py_BEGIN_CRITICAL_SECTION(self);
-    return_value = _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self);
-    Py_END_CRITICAL_SECTION();
-
-    return return_value;
+    return _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self);
 }
 
 PyDoc_STRVAR(_zstd_ZstdDict_as_undigested_dict__doc__,
@@ -156,13 +149,7 @@ _zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self);
 static PyObject *
 _zstd_ZstdDict_as_undigested_dict_get(PyObject *self, void *Py_UNUSED(context))
 {
-    PyObject *return_value = NULL;
-
-    Py_BEGIN_CRITICAL_SECTION(self);
-    return_value = _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self);
-    Py_END_CRITICAL_SECTION();
-
-    return return_value;
+    return _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self);
 }
 
 PyDoc_STRVAR(_zstd_ZstdDict_as_prefix__doc__,
@@ -194,12 +181,6 @@ _zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self);
 static PyObject *
 _zstd_ZstdDict_as_prefix_get(PyObject *self, void *Py_UNUSED(context))
 {
-    PyObject *return_value = NULL;
-
-    Py_BEGIN_CRITICAL_SECTION(self);
-    return_value = _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self);
-    Py_END_CRITICAL_SECTION();
-
-    return return_value;
+    return _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self);
 }
-/*[clinic end generated code: output=bfb31c1187477afd input=a9049054013a1b77]*/
+/*[clinic end generated code: output=8692eabee4e0d1fe input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c
index 38baee2be1e95b..8f934858ef784f 100644
--- a/Modules/_zstd/compressor.c
+++ b/Modules/_zstd/compressor.c
@@ -17,6 +17,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec"
 #include "_zstdmodule.h"
 #include "buffer.h"
 #include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
 
 #include <stddef.h>               // offsetof()
 #include <zstd.h>                 // ZSTD_*()
@@ -38,6 +39,9 @@ typedef struct {
 
     /* Compression level */
     int compression_level;
+
+    /* Lock to protect the compression context */
+    PyMutex lock;
 } ZstdCompressor;
 
 #define ZstdCompressor_CAST(op) ((ZstdCompressor *)op)
@@ -149,12 +153,12 @@ capsule_free_cdict(PyObject *capsule)
 ZSTD_CDict *
 _get_CDict(ZstdDict *self, int compressionLevel)
 {
+    assert(PyMutex_IsLocked(&self->lock));
     PyObject *level = NULL;
-    PyObject *capsule;
+    PyObject *capsule = NULL;
     ZSTD_CDict *cdict;
+    int ret;
 
-    // TODO(emmatyping): refactor critical section code into a lock_held function
-    Py_BEGIN_CRITICAL_SECTION(self);
 
     /* int level object */
     level = PyLong_FromLong(compressionLevel);
@@ -163,12 +167,11 @@ _get_CDict(ZstdDict *self, int compressionLevel)
     }
 
     /* Get PyCapsule object from self->c_dicts */
-    capsule = PyDict_GetItemWithError(self->c_dicts, level);
+    ret = PyDict_GetItemRef(self->c_dicts, level, &capsule);
+    if (ret < 0) {
+        goto error;
+    }
     if (capsule == NULL) {
-        if (PyErr_Occurred()) {
-            goto error;
-        }
-
         /* Create ZSTD_CDict instance */
         char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
         Py_ssize_t dict_len = Py_SIZE(self->dict_content);
@@ -196,11 +199,10 @@ _get_CDict(ZstdDict *self, int compressionLevel)
         }
 
         /* Add PyCapsule object to self->c_dicts */
-        if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) {
-            Py_DECREF(capsule);
+        ret = PyDict_SetItem(self->c_dicts, level, capsule);
+        if (ret < 0) {
             goto error;
         }
-        Py_DECREF(capsule);
     }
     else {
         /* ZSTD_CDict instance already exists */
@@ -212,15 +214,55 @@ _get_CDict(ZstdDict *self, int compressionLevel)
     cdict = NULL;
 success:
     Py_XDECREF(level);
-    Py_END_CRITICAL_SECTION();
+    Py_XDECREF(capsule);
     return cdict;
 }
 
 static int
-_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
+_zstd_load_impl(ZstdCompressor *self, ZstdDict *zd,
+                _zstd_state *mod_state, int type)
 {
-
     size_t zstd_ret;
+    if (type == DICT_TYPE_DIGESTED) {
+        /* Get ZSTD_CDict */
+        ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
+        if (c_dict == NULL) {
+            return -1;
+        }
+        /* Reference a prepared dictionary.
+           It overrides some compression context's parameters. */
+        zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
+    }
+    else if (type == DICT_TYPE_UNDIGESTED) {
+        /* Load a dictionary.
+           It doesn't override compression context's parameters. */
+        zstd_ret = ZSTD_CCtx_loadDictionary(
+                            self->cctx,
+                            PyBytes_AS_STRING(zd->dict_content),
+                            Py_SIZE(zd->dict_content));
+    }
+    else if (type == DICT_TYPE_PREFIX) {
+        /* Load a prefix */
+        zstd_ret = ZSTD_CCtx_refPrefix(
+                            self->cctx,
+                            PyBytes_AS_STRING(zd->dict_content),
+                            Py_SIZE(zd->dict_content));
+    }
+    else {
+        Py_UNREACHABLE();
+    }
+
+    /* Check error */
+    if (ZSTD_isError(zstd_ret)) {
+        set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret);
+        return -1;
+    }
+    return 0;
+}
+
+static int
+_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
+{
     _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
     if (mod_state == NULL) {
         return -1;
@@ -237,7 +279,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
         /* When compressing, use undigested dictionary by default. */
         zd = (ZstdDict*)dict;
         type = DICT_TYPE_UNDIGESTED;
-        goto load;
+        PyMutex_Lock(&zd->lock);
+        ret = _zstd_load_impl(self, zd, mod_state, type);
+        PyMutex_Unlock(&zd->lock);
+        return ret;
     }
 
     /* Check (ZstdDict, type) */
@@ -257,7 +302,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
             {
                 assert(type >= 0);
                 zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
-                goto load;
+                PyMutex_Lock(&zd->lock);
+                ret = _zstd_load_impl(self, zd, mod_state, type);
+                PyMutex_Unlock(&zd->lock);
+                return ret;
             }
         }
     }
@@ -266,49 +314,6 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
     PyErr_SetString(PyExc_TypeError,
                     "zstd_dict argument should be ZstdDict object.");
     return -1;
-
-load:
-    if (type == DICT_TYPE_DIGESTED) {
-        /* Get ZSTD_CDict */
-        ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
-        if (c_dict == NULL) {
-            return -1;
-        }
-        /* Reference a prepared dictionary.
-           It overrides some compression context's parameters. */
-        Py_BEGIN_CRITICAL_SECTION(self);
-        zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
-        Py_END_CRITICAL_SECTION();
-    }
-    else if (type == DICT_TYPE_UNDIGESTED) {
-        /* Load a dictionary.
-           It doesn't override compression context's parameters. */
-        Py_BEGIN_CRITICAL_SECTION2(self, zd);
-        zstd_ret = ZSTD_CCtx_loadDictionary(
-                            self->cctx,
-                            PyBytes_AS_STRING(zd->dict_content),
-                            Py_SIZE(zd->dict_content));
-        Py_END_CRITICAL_SECTION2();
-    }
-    else if (type == DICT_TYPE_PREFIX) {
-        /* Load a prefix */
-        Py_BEGIN_CRITICAL_SECTION2(self, zd);
-        zstd_ret = ZSTD_CCtx_refPrefix(
-                            self->cctx,
-                            PyBytes_AS_STRING(zd->dict_content),
-                            Py_SIZE(zd->dict_content));
-        Py_END_CRITICAL_SECTION2();
-    }
-    else {
-        Py_UNREACHABLE();
-    }
-
-    /* Check error */
-    if (ZSTD_isError(zstd_ret)) {
-        set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret);
-        return -1;
-    }
-    return 0;
 }
 
 /*[clinic input]
@@ -339,6 +344,7 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
 
     self->use_multithread = 0;
     self->dict = NULL;
+    self->lock = (PyMutex){0};
 
     /* Compression context */
     self->cctx = ZSTD_createCCtx();
@@ -403,6 +409,8 @@ ZstdCompressor_dealloc(PyObject *ob)
         ZSTD_freeCCtx(self->cctx);
     }
 
+    assert(!PyMutex_IsLocked(&self->lock));
+
     /* Py_XDECREF the dict after free the compression context */
     Py_CLEAR(self->dict);
 
@@ -412,9 +420,10 @@ ZstdCompressor_dealloc(PyObject *ob)
 }
 
 static PyObject *
-compress_impl(ZstdCompressor *self, Py_buffer *data,
-              ZSTD_EndDirective end_directive)
+compress_lock_held(ZstdCompressor *self, Py_buffer *data,
+                   ZSTD_EndDirective end_directive)
 {
+    assert(PyMutex_IsLocked(&self->lock));
     ZSTD_inBuffer in;
     ZSTD_outBuffer out;
     _BlocksOutputBuffer buffer = {.list = NULL};
@@ -495,8 +504,9 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out)
 #endif
 
 static PyObject *
-compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
+compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
 {
+    assert(PyMutex_IsLocked(&self->lock));
     ZSTD_inBuffer in;
     ZSTD_outBuffer out;
     _BlocksOutputBuffer buffer = {.list = NULL};
@@ -529,7 +539,7 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
             goto error;
         }
 
-        /* Like compress_impl(), output as much as possible. */
+        /* Like compress_lock_held(), output as much as possible. */
         if (out.pos == out.size) {
             if (_OutputBuffer_Grow(&buffer, &out) < 0) {
                 goto error;
@@ -588,14 +598,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
     }
 
     /* Thread-safe code */
-    Py_BEGIN_CRITICAL_SECTION(self);
+    PyMutex_Lock(&self->lock);
 
     /* Compress */
     if (self->use_multithread && mode == ZSTD_e_continue) {
-        ret = compress_mt_continue_impl(self, data);
+        ret = compress_mt_continue_lock_held(self, data);
     }
     else {
-        ret = compress_impl(self, data, mode);
+        ret = compress_lock_held(self, data, mode);
     }
 
     if (ret) {
@@ -607,7 +617,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
         /* Resetting cctx's session never fail */
         ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
     }
-    Py_END_CRITICAL_SECTION();
+    PyMutex_Unlock(&self->lock);
 
     return ret;
 }
@@ -642,8 +652,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
     }
 
     /* Thread-safe code */
-    Py_BEGIN_CRITICAL_SECTION(self);
-    ret = compress_impl(self, NULL, mode);
+    PyMutex_Lock(&self->lock);
+
+    ret = compress_lock_held(self, NULL, mode);
 
     if (ret) {
         self->last_mode = mode;
@@ -654,7 +665,7 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
         /* Resetting cctx's session never fail */
         ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
     }
-    Py_END_CRITICAL_SECTION();
+    PyMutex_Unlock(&self->lock);
 
     return ret;
 }
diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c
index 58f9c9f804e549..e299f73b071353 100644
--- a/Modules/_zstd/decompressor.c
+++ b/Modules/_zstd/decompressor.c
@@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec"
 #include "_zstdmodule.h"
 #include "buffer.h"
 #include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
 
 #include <stdbool.h>              // bool
 #include <stddef.h>               // offsetof()
@@ -45,6 +46,9 @@ typedef struct {
     /* For ZstdDecompressor, 0 or 1.
        1 means the end of the first frame has been reached. */
     bool eof;
+
+    /* Lock to protect the decompression context */
+    PyMutex lock;
 } ZstdDecompressor;
 
 #define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op)
@@ -54,6 +58,7 @@ typedef struct {
 static inline ZSTD_DDict *
 _get_DDict(ZstdDict *self)
 {
+    assert(PyMutex_IsLocked(&self->lock));
     ZSTD_DDict *ret;
 
     /* Already created */
@@ -61,15 +66,14 @@ _get_DDict(ZstdDict *self)
         return self->d_dict;
     }
 
-    Py_BEGIN_CRITICAL_SECTION(self);
     if (self->d_dict == NULL) {
         /* Create ZSTD_DDict instance from dictionary content */
         char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
         Py_ssize_t dict_len = Py_SIZE(self->dict_content);
         Py_BEGIN_ALLOW_THREADS
-        self->d_dict = ZSTD_createDDict(dict_buffer,
-                                        dict_len);
+        ret = ZSTD_createDDict(dict_buffer, dict_len);
         Py_END_ALLOW_THREADS
+        self->d_dict = ret;
 
         if (self->d_dict == NULL) {
             _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
@@ -81,11 +85,7 @@ _get_DDict(ZstdDict *self)
         }
     }
 
-    /* Don't lose any exception */
-    ret = self->d_dict;
-    Py_END_CRITICAL_SECTION();
-
-    return ret;
+    return self->d_dict;
 }
 
 /* Set decompression parameters to decompression context */
@@ -134,9 +134,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
         }
 
         /* Set parameter to compression context */
-        Py_BEGIN_CRITICAL_SECTION(self);
         zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v);
-        Py_END_CRITICAL_SECTION();
 
         /* Check error */
         if (ZSTD_isError(zstd_ret)) {
@@ -147,11 +145,53 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
     return 0;
 }
 
+static int
+_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd,
+                _zstd_state *mod_state, int type)
+{
+    size_t zstd_ret;
+    if (type == DICT_TYPE_DIGESTED) {
+        /* Get ZSTD_DDict */
+        ZSTD_DDict *d_dict = _get_DDict(zd);
+        if (d_dict == NULL) {
+            return -1;
+        }
+        /* Reference a prepared dictionary */
+        zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
+    }
+    else if (type == DICT_TYPE_UNDIGESTED) {
+        /* Load a dictionary */
+        zstd_ret = ZSTD_DCtx_loadDictionary(
+                            self->dctx,
+                            PyBytes_AS_STRING(zd->dict_content),
+                            Py_SIZE(zd->dict_content));
+    }
+    else if (type == DICT_TYPE_PREFIX) {
+        /* Load a prefix */
+        zstd_ret = ZSTD_DCtx_refPrefix(
+                            self->dctx,
+                            PyBytes_AS_STRING(zd->dict_content),
+                            Py_SIZE(zd->dict_content));
+    }
+    else {
+        /* Impossible code path */
+        PyErr_SetString(PyExc_SystemError,
+                        "load_d_dict() impossible code path");
+        return -1;
+    }
+
+    /* Check error */
+    if (ZSTD_isError(zstd_ret)) {
+        set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
+        return -1;
+    }
+    return 0;
+}
+
 /* Load dictionary or prefix to decompression context */
 static int
 _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
 {
-    size_t zstd_ret;
     _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
     if (mod_state == NULL) {
         return -1;
@@ -168,7 +208,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
         /* When decompressing, use digested dictionary by default. */
         zd = (ZstdDict*)dict;
         type = DICT_TYPE_DIGESTED;
-        goto load;
+        PyMutex_Lock(&zd->lock);
+        ret = _zstd_load_impl(self, zd, mod_state, type);
+        PyMutex_Unlock(&zd->lock);
+        return ret;
     }
 
     /* Check (ZstdDict, type) */
@@ -188,7 +231,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
             {
                 assert(type >= 0);
                 zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
-                goto load;
+                PyMutex_Lock(&zd->lock);
+                ret = _zstd_load_impl(self, zd, mod_state, type);
+                PyMutex_Unlock(&zd->lock);
+                return ret;
             }
         }
     }
@@ -197,50 +243,6 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
     PyErr_SetString(PyExc_TypeError,
                     "zstd_dict argument should be ZstdDict object.");
     return -1;
-
-load:
-    if (type == DICT_TYPE_DIGESTED) {
-        /* Get ZSTD_DDict */
-        ZSTD_DDict *d_dict = _get_DDict(zd);
-        if (d_dict == NULL) {
-            return -1;
-        }
-        /* Reference a prepared dictionary */
-        Py_BEGIN_CRITICAL_SECTION(self);
-        zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
-        Py_END_CRITICAL_SECTION();
-    }
-    else if (type == DICT_TYPE_UNDIGESTED) {
-        /* Load a dictionary */
-        Py_BEGIN_CRITICAL_SECTION2(self, zd);
-        zstd_ret = ZSTD_DCtx_loadDictionary(
-                            self->dctx,
-                            PyBytes_AS_STRING(zd->dict_content),
-                            Py_SIZE(zd->dict_content));
-        Py_END_CRITICAL_SECTION2();
-    }
-    else if (type == DICT_TYPE_PREFIX) {
-        /* Load a prefix */
-        Py_BEGIN_CRITICAL_SECTION2(self, zd);
-        zstd_ret = ZSTD_DCtx_refPrefix(
-                            self->dctx,
-                            PyBytes_AS_STRING(zd->dict_content),
-                            Py_SIZE(zd->dict_content));
-        Py_END_CRITICAL_SECTION2();
-    }
-    else {
-        /* Impossible code path */
-        PyErr_SetString(PyExc_SystemError,
-                        "load_d_dict() impossible code path");
-        return -1;
-    }
-
-    /* Check error */
-    if (ZSTD_isError(zstd_ret)) {
-        set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
-        return -1;
-    }
-    return 0;
 }
 
 /*
@@ -268,8 +270,8 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
       Note, decompressing "an empty input" in any case will make it > 0.
 */
 static PyObject *
-decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
-                Py_ssize_t max_length)
+decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
+                     Py_ssize_t max_length)
 {
     size_t zstd_ret;
     ZSTD_outBuffer out;
@@ -339,10 +341,9 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
 }
 
 static void
-decompressor_reset_session(ZstdDecompressor *self)
+decompressor_reset_session_lock_held(ZstdDecompressor *self)
 {
-    // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here
-    // and ensure lock is always held
+    assert(PyMutex_IsLocked(&self->lock));
 
     /* Reset variables */
     self->in_begin = 0;
@@ -359,8 +360,10 @@ decompressor_reset_session(ZstdDecompressor *self)
 }
 
 static PyObject *
-stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length)
+stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data,
+                            Py_ssize_t max_length)
 {
+    assert(PyMutex_IsLocked(&self->lock));
     ZSTD_inBuffer in;
     PyObject *ret = NULL;
     int use_input_buffer;
@@ -456,7 +459,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
     assert(in.pos == 0);
 
     /* Decompress */
-    ret = decompress_impl(self, &in, max_length);
+    ret = decompress_lock_held(self, &in, max_length);
     if (ret == NULL) {
         goto error;
     }
@@ -517,7 +520,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
 
 error:
     /* Reset decompressor's states/session */
-    decompressor_reset_session(self);
+    decompressor_reset_session_lock_held(self);
 
     Py_CLEAR(ret);
     return NULL;
@@ -555,6 +558,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
     self->unused_data = NULL;
     self->eof = 0;
     self->dict = NULL;
+    self->lock = (PyMutex){0};
 
     /* needs_input flag */
     self->needs_input = 1;
@@ -608,6 +612,8 @@ ZstdDecompressor_dealloc(PyObject *ob)
         ZSTD_freeDCtx(self->dctx);
     }
 
+    assert(!PyMutex_IsLocked(&self->lock));
+
     /* Py_CLEAR the dict after free decompression context */
     Py_CLEAR(self->dict);
 
@@ -623,7 +629,6 @@ ZstdDecompressor_dealloc(PyObject *ob)
 }
 
 /*[clinic input]
-@critical_section
 @getter
 _zstd.ZstdDecompressor.unused_data
 
@@ -635,11 +640,14 @@ decompressed, unused input data after the frame. Otherwise this will be b''.
 
 static PyObject *
 _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
-/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/
+/*[clinic end generated code: output=f3a20940f11b6b09 input=54d41ecd681a3444]*/
 {
     PyObject *ret;
 
+    PyMutex_Lock(&self->lock);
+
     if (!self->eof) {
+        PyMutex_Unlock(&self->lock);
         return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
     }
     else {
@@ -656,6 +664,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
         }
     }
 
+    PyMutex_Unlock(&self->lock);
     return ret;
 }
 
@@ -693,10 +702,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
 {
     PyObject *ret;
     /* Thread-safe code */
-    Py_BEGIN_CRITICAL_SECTION(self);
-
-    ret = stream_decompress(self, data, max_length);
-    Py_END_CRITICAL_SECTION();
+    PyMutex_Lock(&self->lock);
+    ret = stream_decompress_lock_held(self, data, max_length);
+    PyMutex_Unlock(&self->lock);
     return ret;
 }
 
diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c
index 7df187a6fa69d7..39828c9b36b5c2 100644
--- a/Modules/_zstd/zstddict.c
+++ b/Modules/_zstd/zstddict.c
@@ -17,6 +17,7 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec"
 #include "_zstdmodule.h"
 #include "zstddict.h"
 #include "clinic/zstddict.c.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
 
 #include <zstd.h>                 // ZSTD_freeDDict(), ZSTD_getDictID_fromDict()
 
@@ -53,6 +54,7 @@ _zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject *dict_content,
     self->dict_content = NULL;
     self->d_dict = NULL;
     self->dict_id = 0;
+    self->lock = (PyMutex){0};
 
     /* ZSTD_CDict dict */
     self->c_dicts = PyDict_New();
@@ -109,6 +111,8 @@ ZstdDict_dealloc(PyObject *ob)
         ZSTD_freeDDict(self->d_dict);
     }
 
+    assert(!PyMutex_IsLocked(&self->lock));
+
     /* Release dict_content after Free ZSTD_CDict/ZSTD_DDict instances */
     Py_CLEAR(self->dict_content);
     Py_CLEAR(self->c_dicts);
@@ -143,7 +147,6 @@ static PyMemberDef ZstdDict_members[] = {
 };
 
 /*[clinic input]
-@critical_section
 @getter
 _zstd.ZstdDict.as_digested_dict
 
@@ -160,13 +163,12 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_digeste
 
 static PyObject *
 _zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=09b086e7a7320dbb input=585448c79f31f74a]*/
+/*[clinic end generated code: output=09b086e7a7320dbb input=10cd2b6165931b77]*/
 {
     return Py_BuildValue("Oi", self, DICT_TYPE_DIGESTED);
 }
 
 /*[clinic input]
-@critical_section
 @getter
 _zstd.ZstdDict.as_undigested_dict
 
@@ -181,13 +183,12 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_undiges
 
 static PyObject *
 _zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=43c7a989e6d4253a input=022b0829ffb1c220]*/
+/*[clinic end generated code: output=43c7a989e6d4253a input=11e5f5df690a85b4]*/
 {
     return Py_BuildValue("Oi", self, DICT_TYPE_UNDIGESTED);
 }
 
 /*[clinic input]
-@critical_section
 @getter
 _zstd.ZstdDict.as_prefix
 
@@ -202,7 +203,7 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix)
 
 static PyObject *
 _zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=6f7130c356595a16 input=09fb82a6a5407e87]*/
+/*[clinic end generated code: output=6f7130c356595a16 input=b028e0ae6ec4292b]*/
 {
     return Py_BuildValue("Oi", self, DICT_TYPE_PREFIX);
 }
diff --git a/Modules/_zstd/zstddict.h b/Modules/_zstd/zstddict.h
index e8a55a3670b869..dcba0f21852087 100644
--- a/Modules/_zstd/zstddict.h
+++ b/Modules/_zstd/zstddict.h
@@ -19,6 +19,9 @@ typedef struct {
     PyObject *dict_content;
     /* Dictionary id */
     uint32_t dict_id;
+
+    /* Lock to protect the digested dictionaries */
+    PyMutex lock;
 } ZstdDict;
 
 #endif  // !ZSTD_DICT_H

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: arch...@mail-archive.com

Reply via email to