https://github.com/python/cpython/commit/8dbc11971974a725dc8a11c0dc65d8f6fcb4d902
commit: 8dbc11971974a725dc8a11c0dc65d8f6fcb4d902
branch: main
author: Emma Smith <[email protected]>
committer: colesbury <[email protected]>
date: 2025-05-22T23:30:10-04:00
summary:
gh-133885: Use locks instead of critical sections for _zstd (gh-134289)
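Switch the (de)compression methods from critical sections to PyMutex locks.
These methods release the thread state (Py_BEGIN_ALLOW_THREADS) and so allow
other threads to run. A critical section is not held across such a region, so
it cannot guarantee exclusive access for the whole operation; a lock can.
ZstdCompressor, ZstdDecompressor and ZstdDict each gain a lock; the ZstdDict
lock protects its lazily created digested dictionaries.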
files:
M Lib/test/test_zstd.py
M Modules/_zstd/clinic/decompressor.c.h
M Modules/_zstd/clinic/zstddict.c.h
M Modules/_zstd/compressor.c
M Modules/_zstd/decompressor.c
M Modules/_zstd/zstddict.c
M Modules/_zstd/zstddict.h
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
index 53ca592ea38828..084f8f24fc009c 100644
--- a/Lib/test/test_zstd.py
+++ b/Lib/test/test_zstd.py
@@ -2430,10 +2430,8 @@ def test_buffer_protocol(self):
self.assertEqual(f.write(arr), LENGTH)
self.assertEqual(f.tell(), LENGTH)
[email protected]("it fails for now, see gh-133885")
class FreeThreadingMethodTests(unittest.TestCase):
- @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail
with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_compress_locking(self):
@@ -2470,7 +2468,6 @@ def run_method(method, input_data, output_data):
actual = b''.join(output) + rest2
self.assertEqual(expected, actual)
- @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail
with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_decompress_locking(self):
@@ -2506,6 +2503,59 @@ def run_method(method, input_data, output_data):
actual = b''.join(output)
self.assertEqual(expected, actual)
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_compress_shared_dict(self):
+ num_threads = 8
+
+ def run_method(b):
+ level = threading.get_ident() % 4
+ # sync threads to increase chance of contention on
+ # capsule storing dictionary levels
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_digested_dict)
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_undigested_dict)
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_prefix)
+ threads = []
+
+ b = threading.Barrier(num_threads)
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(b,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_decompress_shared_dict(self):
+ num_threads = 8
+
+ def run_method(b):
+ # sync threads to increase chance of contention on
+ # decompression dictionary
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict)
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict)
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix)
+ threads = []
+
+ b = threading.Barrier(num_threads)
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(b,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
if __name__ == "__main__":
diff --git a/Modules/_zstd/clinic/decompressor.c.h
b/Modules/_zstd/clinic/decompressor.c.h
index 4ecb19e9bde6ed..c6fdae74ab0447 100644
--- a/Modules/_zstd/clinic/decompressor.c.h
+++ b/Modules/_zstd/clinic/decompressor.c.h
@@ -7,7 +7,6 @@ preserve
# include "pycore_runtime.h" // _Py_ID()
#endif
#include "pycore_abstract.h" // _PyNumber_Index()
-#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
#include "pycore_modsupport.h" // _PyArg_UnpackKeywords()
PyDoc_STRVAR(_zstd_ZstdDecompressor_new__doc__,
@@ -114,13 +113,7 @@
_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self);
static PyObject *
_zstd_ZstdDecompressor_unused_data_get(PyObject *self, void
*Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value =
_zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor
*)self);
}
PyDoc_STRVAR(_zstd_ZstdDecompressor_decompress__doc__,
@@ -227,4 +220,4 @@ _zstd_ZstdDecompressor_decompress(PyObject *self, PyObject
*const *args, Py_ssiz
return return_value;
}
-/*[clinic end generated code: output=7a4d278f9244e684 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=30c12ef047027ede input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/clinic/zstddict.c.h
b/Modules/_zstd/clinic/zstddict.c.h
index 34e0e4b3ecfe72..aaa29e491bc1bb 100644
--- a/Modules/_zstd/clinic/zstddict.c.h
+++ b/Modules/_zstd/clinic/zstddict.c.h
@@ -6,7 +6,6 @@ preserve
# include "pycore_gc.h" // PyGC_Head
# include "pycore_runtime.h" // _Py_ID()
#endif
-#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
#include "pycore_modsupport.h" // _PyArg_UnpackKeywords()
PyDoc_STRVAR(_zstd_ZstdDict_new__doc__,
@@ -118,13 +117,7 @@ _zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self);
static PyObject *
_zstd_ZstdDict_as_digested_dict_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self);
}
PyDoc_STRVAR(_zstd_ZstdDict_as_undigested_dict__doc__,
@@ -156,13 +149,7 @@ _zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self);
static PyObject *
_zstd_ZstdDict_as_undigested_dict_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict
*)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self);
}
PyDoc_STRVAR(_zstd_ZstdDict_as_prefix__doc__,
@@ -194,12 +181,6 @@ _zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self);
static PyObject *
_zstd_ZstdDict_as_prefix_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self);
}
-/*[clinic end generated code: output=bfb31c1187477afd input=a9049054013a1b77]*/
+/*[clinic end generated code: output=8692eabee4e0d1fe input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c
index 38baee2be1e95b..8f934858ef784f 100644
--- a/Modules/_zstd/compressor.c
+++ b/Modules/_zstd/compressor.c
@@ -17,6 +17,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *"
"&zstd_compressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <stddef.h> // offsetof()
#include <zstd.h> // ZSTD_*()
@@ -38,6 +39,9 @@ typedef struct {
/* Compression level */
int compression_level;
+
+ /* Lock to protect the compression context */
+ PyMutex lock;
} ZstdCompressor;
#define ZstdCompressor_CAST(op) ((ZstdCompressor *)op)
@@ -149,12 +153,12 @@ capsule_free_cdict(PyObject *capsule)
ZSTD_CDict *
_get_CDict(ZstdDict *self, int compressionLevel)
{
+ assert(PyMutex_IsLocked(&self->lock));
PyObject *level = NULL;
- PyObject *capsule;
+ PyObject *capsule = NULL;
ZSTD_CDict *cdict;
+ int ret;
- // TODO(emmatyping): refactor critical section code into a lock_held
function
- Py_BEGIN_CRITICAL_SECTION(self);
/* int level object */
level = PyLong_FromLong(compressionLevel);
@@ -163,12 +167,11 @@ _get_CDict(ZstdDict *self, int compressionLevel)
}
/* Get PyCapsule object from self->c_dicts */
- capsule = PyDict_GetItemWithError(self->c_dicts, level);
+ ret = PyDict_GetItemRef(self->c_dicts, level, &capsule);
+ if (ret < 0) {
+ goto error;
+ }
if (capsule == NULL) {
- if (PyErr_Occurred()) {
- goto error;
- }
-
/* Create ZSTD_CDict instance */
char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
Py_ssize_t dict_len = Py_SIZE(self->dict_content);
@@ -196,11 +199,10 @@ _get_CDict(ZstdDict *self, int compressionLevel)
}
/* Add PyCapsule object to self->c_dicts */
- if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) {
- Py_DECREF(capsule);
+ ret = PyDict_SetItem(self->c_dicts, level, capsule);
+ if (ret < 0) {
goto error;
}
- Py_DECREF(capsule);
}
else {
/* ZSTD_CDict instance already exists */
@@ -212,15 +214,55 @@ _get_CDict(ZstdDict *self, int compressionLevel)
cdict = NULL;
success:
Py_XDECREF(level);
- Py_END_CRITICAL_SECTION();
+ Py_XDECREF(capsule);
return cdict;
}
static int
-_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
+_zstd_load_impl(ZstdCompressor *self, ZstdDict *zd,
+ _zstd_state *mod_state, int type)
{
-
size_t zstd_ret;
+ if (type == DICT_TYPE_DIGESTED) {
+ /* Get ZSTD_CDict */
+ ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
+ if (c_dict == NULL) {
+ return -1;
+ }
+ /* Reference a prepared dictionary.
+ It overrides some compression context's parameters. */
+ zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
+ }
+ else if (type == DICT_TYPE_UNDIGESTED) {
+ /* Load a dictionary.
+ It doesn't override compression context's parameters. */
+ zstd_ret = ZSTD_CCtx_loadDictionary(
+ self->cctx,
+ PyBytes_AS_STRING(zd->dict_content),
+ Py_SIZE(zd->dict_content));
+ }
+ else if (type == DICT_TYPE_PREFIX) {
+ /* Load a prefix */
+ zstd_ret = ZSTD_CCtx_refPrefix(
+ self->cctx,
+ PyBytes_AS_STRING(zd->dict_content),
+ Py_SIZE(zd->dict_content));
+ }
+ else {
+ Py_UNREACHABLE();
+ }
+
+ /* Check error */
+ if (ZSTD_isError(zstd_ret)) {
+ set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret);
+ return -1;
+ }
+ return 0;
+}
+
+static int
+_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
+{
_zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
@@ -237,7 +279,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
/* When compressing, use undigested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_UNDIGESTED;
- goto load;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
}
/* Check (ZstdDict, type) */
@@ -257,7 +302,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
{
assert(type >= 0);
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
- goto load;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
}
}
}
@@ -266,49 +314,6 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be ZstdDict object.");
return -1;
-
-load:
- if (type == DICT_TYPE_DIGESTED) {
- /* Get ZSTD_CDict */
- ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
- if (c_dict == NULL) {
- return -1;
- }
- /* Reference a prepared dictionary.
- It overrides some compression context's parameters. */
- Py_BEGIN_CRITICAL_SECTION(self);
- zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
- Py_END_CRITICAL_SECTION();
- }
- else if (type == DICT_TYPE_UNDIGESTED) {
- /* Load a dictionary.
- It doesn't override compression context's parameters. */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_CCtx_loadDictionary(
- self->cctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
- }
- else if (type == DICT_TYPE_PREFIX) {
- /* Load a prefix */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_CCtx_refPrefix(
- self->cctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
- }
- else {
- Py_UNREACHABLE();
- }
-
- /* Check error */
- if (ZSTD_isError(zstd_ret)) {
- set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret);
- return -1;
- }
- return 0;
}
/*[clinic input]
@@ -339,6 +344,7 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject
*level,
self->use_multithread = 0;
self->dict = NULL;
+ self->lock = (PyMutex){0};
/* Compression context */
self->cctx = ZSTD_createCCtx();
@@ -403,6 +409,8 @@ ZstdCompressor_dealloc(PyObject *ob)
ZSTD_freeCCtx(self->cctx);
}
+ assert(!PyMutex_IsLocked(&self->lock));
+
/* Py_XDECREF the dict after free the compression context */
Py_CLEAR(self->dict);
@@ -412,9 +420,10 @@ ZstdCompressor_dealloc(PyObject *ob)
}
static PyObject *
-compress_impl(ZstdCompressor *self, Py_buffer *data,
- ZSTD_EndDirective end_directive)
+compress_lock_held(ZstdCompressor *self, Py_buffer *data,
+ ZSTD_EndDirective end_directive)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
ZSTD_outBuffer out;
_BlocksOutputBuffer buffer = {.list = NULL};
@@ -495,8 +504,9 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer
*out)
#endif
static PyObject *
-compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
+compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
ZSTD_outBuffer out;
_BlocksOutputBuffer buffer = {.list = NULL};
@@ -529,7 +539,7 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer
*data)
goto error;
}
- /* Like compress_impl(), output as much as possible. */
+ /* Like compress_lock_held(), output as much as possible. */
if (out.pos == out.size) {
if (_OutputBuffer_Grow(&buffer, &out) < 0) {
goto error;
@@ -588,14 +598,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self,
Py_buffer *data,
}
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
+ PyMutex_Lock(&self->lock);
/* Compress */
if (self->use_multithread && mode == ZSTD_e_continue) {
- ret = compress_mt_continue_impl(self, data);
+ ret = compress_mt_continue_lock_held(self, data);
}
else {
- ret = compress_impl(self, data, mode);
+ ret = compress_lock_held(self, data, mode);
}
if (ret) {
@@ -607,7 +617,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self,
Py_buffer *data,
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
- Py_END_CRITICAL_SECTION();
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -642,8 +652,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int
mode)
}
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
- ret = compress_impl(self, NULL, mode);
+ PyMutex_Lock(&self->lock);
+
+ ret = compress_lock_held(self, NULL, mode);
if (ret) {
self->last_mode = mode;
@@ -654,7 +665,7 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int
mode)
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
- Py_END_CRITICAL_SECTION();
+ PyMutex_Unlock(&self->lock);
return ret;
}
diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c
index 58f9c9f804e549..e299f73b071353 100644
--- a/Modules/_zstd/decompressor.c
+++ b/Modules/_zstd/decompressor.c
@@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *"
"&zstd_decompressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <stdbool.h> // bool
#include <stddef.h> // offsetof()
@@ -45,6 +46,9 @@ typedef struct {
/* For ZstdDecompressor, 0 or 1.
1 means the end of the first frame has been reached. */
bool eof;
+
+ /* Lock to protect the decompression context */
+ PyMutex lock;
} ZstdDecompressor;
#define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op)
@@ -54,6 +58,7 @@ typedef struct {
static inline ZSTD_DDict *
_get_DDict(ZstdDict *self)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_DDict *ret;
/* Already created */
@@ -61,15 +66,14 @@ _get_DDict(ZstdDict *self)
return self->d_dict;
}
- Py_BEGIN_CRITICAL_SECTION(self);
if (self->d_dict == NULL) {
/* Create ZSTD_DDict instance from dictionary content */
char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
Py_ssize_t dict_len = Py_SIZE(self->dict_content);
Py_BEGIN_ALLOW_THREADS
- self->d_dict = ZSTD_createDDict(dict_buffer,
- dict_len);
+ ret = ZSTD_createDDict(dict_buffer, dict_len);
Py_END_ALLOW_THREADS
+ self->d_dict = ret;
if (self->d_dict == NULL) {
_zstd_state* const mod_state =
PyType_GetModuleState(Py_TYPE(self));
@@ -81,11 +85,7 @@ _get_DDict(ZstdDict *self)
}
}
- /* Don't lose any exception */
- ret = self->d_dict;
- Py_END_CRITICAL_SECTION();
-
- return ret;
+ return self->d_dict;
}
/* Set decompression parameters to decompression context */
@@ -134,9 +134,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject
*options)
}
/* Set parameter to compression context */
- Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v);
- Py_END_CRITICAL_SECTION();
/* Check error */
if (ZSTD_isError(zstd_ret)) {
@@ -147,11 +145,53 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject
*options)
return 0;
}
+static int
+_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd,
+ _zstd_state *mod_state, int type)
+{
+ size_t zstd_ret;
+ if (type == DICT_TYPE_DIGESTED) {
+ /* Get ZSTD_DDict */
+ ZSTD_DDict *d_dict = _get_DDict(zd);
+ if (d_dict == NULL) {
+ return -1;
+ }
+ /* Reference a prepared dictionary */
+ zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
+ }
+ else if (type == DICT_TYPE_UNDIGESTED) {
+ /* Load a dictionary */
+ zstd_ret = ZSTD_DCtx_loadDictionary(
+ self->dctx,
+ PyBytes_AS_STRING(zd->dict_content),
+ Py_SIZE(zd->dict_content));
+ }
+ else if (type == DICT_TYPE_PREFIX) {
+ /* Load a prefix */
+ zstd_ret = ZSTD_DCtx_refPrefix(
+ self->dctx,
+ PyBytes_AS_STRING(zd->dict_content),
+ Py_SIZE(zd->dict_content));
+ }
+ else {
+ /* Impossible code path */
+ PyErr_SetString(PyExc_SystemError,
+ "load_d_dict() impossible code path");
+ return -1;
+ }
+
+ /* Check error */
+ if (ZSTD_isError(zstd_ret)) {
+ set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
+ return -1;
+ }
+ return 0;
+}
+
/* Load dictionary or prefix to decompression context */
static int
_zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
{
- size_t zstd_ret;
_zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
@@ -168,7 +208,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
/* When decompressing, use digested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_DIGESTED;
- goto load;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
}
/* Check (ZstdDict, type) */
@@ -188,7 +231,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
{
assert(type >= 0);
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
- goto load;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
}
}
}
@@ -197,50 +243,6 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be ZstdDict object.");
return -1;
-
-load:
- if (type == DICT_TYPE_DIGESTED) {
- /* Get ZSTD_DDict */
- ZSTD_DDict *d_dict = _get_DDict(zd);
- if (d_dict == NULL) {
- return -1;
- }
- /* Reference a prepared dictionary */
- Py_BEGIN_CRITICAL_SECTION(self);
- zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
- Py_END_CRITICAL_SECTION();
- }
- else if (type == DICT_TYPE_UNDIGESTED) {
- /* Load a dictionary */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_DCtx_loadDictionary(
- self->dctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
- }
- else if (type == DICT_TYPE_PREFIX) {
- /* Load a prefix */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_DCtx_refPrefix(
- self->dctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
- }
- else {
- /* Impossible code path */
- PyErr_SetString(PyExc_SystemError,
- "load_d_dict() impossible code path");
- return -1;
- }
-
- /* Check error */
- if (ZSTD_isError(zstd_ret)) {
- set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
- return -1;
- }
- return 0;
}
/*
@@ -268,8 +270,8 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
Note, decompressing "an empty input" in any case will make it > 0.
*/
static PyObject *
-decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
- Py_ssize_t max_length)
+decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
+ Py_ssize_t max_length)
{
size_t zstd_ret;
ZSTD_outBuffer out;
@@ -339,10 +341,9 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
}
static void
-decompressor_reset_session(ZstdDecompressor *self)
+decompressor_reset_session_lock_held(ZstdDecompressor *self)
{
- // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here
- // and ensure lock is always held
+ assert(PyMutex_IsLocked(&self->lock));
/* Reset variables */
self->in_begin = 0;
@@ -359,8 +360,10 @@ decompressor_reset_session(ZstdDecompressor *self)
}
static PyObject *
-stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t
max_length)
+stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data,
+ Py_ssize_t max_length)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
PyObject *ret = NULL;
int use_input_buffer;
@@ -456,7 +459,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data,
Py_ssize_t max_length
assert(in.pos == 0);
/* Decompress */
- ret = decompress_impl(self, &in, max_length);
+ ret = decompress_lock_held(self, &in, max_length);
if (ret == NULL) {
goto error;
}
@@ -517,7 +520,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data,
Py_ssize_t max_length
error:
/* Reset decompressor's states/session */
- decompressor_reset_session(self);
+ decompressor_reset_session_lock_held(self);
Py_CLEAR(ret);
return NULL;
@@ -555,6 +558,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type,
PyObject *zstd_dict,
self->unused_data = NULL;
self->eof = 0;
self->dict = NULL;
+ self->lock = (PyMutex){0};
/* needs_input flag */
self->needs_input = 1;
@@ -608,6 +612,8 @@ ZstdDecompressor_dealloc(PyObject *ob)
ZSTD_freeDCtx(self->dctx);
}
+ assert(!PyMutex_IsLocked(&self->lock));
+
/* Py_CLEAR the dict after free decompression context */
Py_CLEAR(self->dict);
@@ -623,7 +629,6 @@ ZstdDecompressor_dealloc(PyObject *ob)
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDecompressor.unused_data
@@ -635,11 +640,14 @@ decompressed, unused input data after the frame.
Otherwise this will be b''.
static PyObject *
_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
-/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/
+/*[clinic end generated code: output=f3a20940f11b6b09 input=54d41ecd681a3444]*/
{
PyObject *ret;
+ PyMutex_Lock(&self->lock);
+
if (!self->eof) {
+ PyMutex_Unlock(&self->lock);
return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
}
else {
@@ -656,6 +664,7 @@
_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
}
}
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -693,10 +702,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor
*self,
{
PyObject *ret;
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
-
- ret = stream_decompress(self, data, max_length);
- Py_END_CRITICAL_SECTION();
+ PyMutex_Lock(&self->lock);
+ ret = stream_decompress_lock_held(self, data, max_length);
+ PyMutex_Unlock(&self->lock);
return ret;
}
diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c
index 7df187a6fa69d7..39828c9b36b5c2 100644
--- a/Modules/_zstd/zstddict.c
+++ b/Modules/_zstd/zstddict.c
@@ -17,6 +17,7 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec"
#include "_zstdmodule.h"
#include "zstddict.h"
#include "clinic/zstddict.c.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <zstd.h> // ZSTD_freeDDict(),
ZSTD_getDictID_fromDict()
@@ -53,6 +54,7 @@ _zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject
*dict_content,
self->dict_content = NULL;
self->d_dict = NULL;
self->dict_id = 0;
+ self->lock = (PyMutex){0};
/* ZSTD_CDict dict */
self->c_dicts = PyDict_New();
@@ -109,6 +111,8 @@ ZstdDict_dealloc(PyObject *ob)
ZSTD_freeDDict(self->d_dict);
}
+ assert(!PyMutex_IsLocked(&self->lock));
+
/* Release dict_content after Free ZSTD_CDict/ZSTD_DDict instances */
Py_CLEAR(self->dict_content);
Py_CLEAR(self->c_dicts);
@@ -143,7 +147,6 @@ static PyMemberDef ZstdDict_members[] = {
};
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDict.as_digested_dict
@@ -160,13 +163,12 @@ Pass this attribute as zstd_dict argument: compress(dat,
zstd_dict=zd.as_digeste
static PyObject *
_zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=09b086e7a7320dbb input=585448c79f31f74a]*/
+/*[clinic end generated code: output=09b086e7a7320dbb input=10cd2b6165931b77]*/
{
return Py_BuildValue("Oi", self, DICT_TYPE_DIGESTED);
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDict.as_undigested_dict
@@ -181,13 +183,12 @@ Pass this attribute as zstd_dict argument: compress(dat,
zstd_dict=zd.as_undiges
static PyObject *
_zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=43c7a989e6d4253a input=022b0829ffb1c220]*/
+/*[clinic end generated code: output=43c7a989e6d4253a input=11e5f5df690a85b4]*/
{
return Py_BuildValue("Oi", self, DICT_TYPE_UNDIGESTED);
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDict.as_prefix
@@ -202,7 +203,7 @@ Pass this attribute as zstd_dict argument: compress(dat,
zstd_dict=zd.as_prefix)
static PyObject *
_zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=6f7130c356595a16 input=09fb82a6a5407e87]*/
+/*[clinic end generated code: output=6f7130c356595a16 input=b028e0ae6ec4292b]*/
{
return Py_BuildValue("Oi", self, DICT_TYPE_PREFIX);
}
diff --git a/Modules/_zstd/zstddict.h b/Modules/_zstd/zstddict.h
index e8a55a3670b869..dcba0f21852087 100644
--- a/Modules/_zstd/zstddict.h
+++ b/Modules/_zstd/zstddict.h
@@ -19,6 +19,9 @@ typedef struct {
PyObject *dict_content;
/* Dictionary id */
uint32_t dict_id;
+
+ /* Lock to protect the digested dictionaries */
+ PyMutex lock;
} ZstdDict;
#endif // !ZSTD_DICT_H
_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]