https://github.com/python/cpython/commit/49ff8b6cc017bb6e9bdb3bf4918d65f32f7aaed8
commit: 49ff8b6cc017bb6e9bdb3bf4918d65f32f7aaed8
branch: main
author: Kumar Aditya <[email protected]>
committer: kumaraditya303 <[email protected]>
date: 2025-11-21T19:49:53+05:30
summary:

gh-140795: fetch thread state once on fast path for critical sections (#141406)

files:
M Include/internal/pycore_critical_section.h
M Python/critical_section.c

diff --git a/Include/internal/pycore_critical_section.h b/Include/internal/pycore_critical_section.h
index 2601de40737e85..60b6fc4a72e88f 100644
--- a/Include/internal/pycore_critical_section.h
+++ b/Include/internal/pycore_critical_section.h
@@ -32,7 +32,7 @@ extern "C" {
         const bool _should_lock_cs = PyList_CheckExact(_orig_seq);      \
         PyCriticalSection _cs;                                          \
         if (_should_lock_cs) {                                          \
-            _PyCriticalSection_Begin(&_cs, _orig_seq);                  \
+            PyCriticalSection_Begin(&_cs, _orig_seq);                  \
         }
 
 # define Py_END_CRITICAL_SECTION_SEQUENCE_FAST()                        \
@@ -77,10 +77,10 @@ _PyCriticalSection_Resume(PyThreadState *tstate);
 
 // (private) slow path for locking the mutex
 PyAPI_FUNC(void)
-_PyCriticalSection_BeginSlow(PyCriticalSection *c, PyMutex *m);
+_PyCriticalSection_BeginSlow(PyThreadState *tstate, PyCriticalSection *c, PyMutex *m);
 
 PyAPI_FUNC(void)
-_PyCriticalSection2_BeginSlow(PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2,
+_PyCriticalSection2_BeginSlow(PyThreadState *tstate, PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2,
                              int is_m1_locked);
 
 PyAPI_FUNC(void)
@@ -95,34 +95,30 @@ _PyCriticalSection_IsActive(uintptr_t tag)
 }
 
 static inline void
-_PyCriticalSection_BeginMutex(PyCriticalSection *c, PyMutex *m)
+_PyCriticalSection_BeginMutex(PyThreadState *tstate, PyCriticalSection *c, PyMutex *m)
 {
     if (PyMutex_LockFast(m)) {
-        PyThreadState *tstate = _PyThreadState_GET();
         c->_cs_mutex = m;
         c->_cs_prev = tstate->critical_section;
         tstate->critical_section = (uintptr_t)c;
     }
     else {
-        _PyCriticalSection_BeginSlow(c, m);
+        _PyCriticalSection_BeginSlow(tstate, c, m);
     }
 }
-#define PyCriticalSection_BeginMutex _PyCriticalSection_BeginMutex
 
 static inline void
-_PyCriticalSection_Begin(PyCriticalSection *c, PyObject *op)
+_PyCriticalSection_Begin(PyThreadState *tstate, PyCriticalSection *c, PyObject *op)
 {
-    _PyCriticalSection_BeginMutex(c, &op->ob_mutex);
+    _PyCriticalSection_BeginMutex(tstate, c, &op->ob_mutex);
 }
-#define PyCriticalSection_Begin _PyCriticalSection_Begin
 
 // Removes the top-most critical section from the thread's stack of critical
 // sections. If the new top-most critical section is inactive, then it is
 // resumed.
 static inline void
-_PyCriticalSection_Pop(PyCriticalSection *c)
+_PyCriticalSection_Pop(PyThreadState *tstate, PyCriticalSection *c)
 {
-    PyThreadState *tstate = _PyThreadState_GET();
     uintptr_t prev = c->_cs_prev;
     tstate->critical_section = prev;
 
@@ -132,7 +128,7 @@ _PyCriticalSection_Pop(PyCriticalSection *c)
 }
 
 static inline void
-_PyCriticalSection_End(PyCriticalSection *c)
+_PyCriticalSection_End(PyThreadState *tstate, PyCriticalSection *c)
 {
     // If the mutex is NULL, we used the fast path in
     // _PyCriticalSection_BeginSlow for locks already held in the top-most
@@ -141,18 +137,17 @@ _PyCriticalSection_End(PyCriticalSection *c)
         return;
     }
     PyMutex_Unlock(c->_cs_mutex);
-    _PyCriticalSection_Pop(c);
+    _PyCriticalSection_Pop(tstate, c);
 }
-#define PyCriticalSection_End _PyCriticalSection_End
 
 static inline void
-_PyCriticalSection2_BeginMutex(PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2)
+_PyCriticalSection2_BeginMutex(PyThreadState *tstate, PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2)
 {
     if (m1 == m2) {
         // If the two mutex arguments are the same, treat this as a critical
         // section with a single mutex.
         c->_cs_mutex2 = NULL;
-        _PyCriticalSection_BeginMutex(&c->_cs_base, m1);
+        _PyCriticalSection_BeginMutex(tstate, &c->_cs_base, m1);
         return;
     }
 
@@ -167,7 +162,6 @@ _PyCriticalSection2_BeginMutex(PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2)
 
     if (PyMutex_LockFast(m1)) {
         if (PyMutex_LockFast(m2)) {
-            PyThreadState *tstate = _PyThreadState_GET();
             c->_cs_base._cs_mutex = m1;
             c->_cs_mutex2 = m2;
             c->_cs_base._cs_prev = tstate->critical_section;
@@ -176,24 +170,22 @@ _PyCriticalSection2_BeginMutex(PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2)
             tstate->critical_section = p;
         }
         else {
-            _PyCriticalSection2_BeginSlow(c, m1, m2, 1);
+            _PyCriticalSection2_BeginSlow(tstate, c, m1, m2, 1);
         }
     }
     else {
-        _PyCriticalSection2_BeginSlow(c, m1, m2, 0);
+        _PyCriticalSection2_BeginSlow(tstate, c, m1, m2, 0);
     }
 }
-#define PyCriticalSection2_BeginMutex _PyCriticalSection2_BeginMutex
 
 static inline void
-_PyCriticalSection2_Begin(PyCriticalSection2 *c, PyObject *a, PyObject *b)
+_PyCriticalSection2_Begin(PyThreadState *tstate, PyCriticalSection2 *c, PyObject *a, PyObject *b)
 {
-    _PyCriticalSection2_BeginMutex(c, &a->ob_mutex, &b->ob_mutex);
+    _PyCriticalSection2_BeginMutex(tstate, c, &a->ob_mutex, &b->ob_mutex);
 }
-#define PyCriticalSection2_Begin _PyCriticalSection2_Begin
 
 static inline void
-_PyCriticalSection2_End(PyCriticalSection2 *c)
+_PyCriticalSection2_End(PyThreadState *tstate, PyCriticalSection2 *c)
 {
     // if mutex1 is NULL, we used the fast path in
     // _PyCriticalSection_BeginSlow for mutexes that are already held,
@@ -207,9 +199,8 @@ _PyCriticalSection2_End(PyCriticalSection2 *c)
         PyMutex_Unlock(c->_cs_mutex2);
     }
     PyMutex_Unlock(c->_cs_base._cs_mutex);
-    _PyCriticalSection_Pop(&c->_cs_base);
+    _PyCriticalSection_Pop(tstate, &c->_cs_base);
 }
-#define PyCriticalSection2_End _PyCriticalSection2_End
 
 static inline void
 _PyCriticalSection_AssertHeld(PyMutex *mutex)
@@ -251,6 +242,45 @@ _PyCriticalSection_AssertHeldObj(PyObject *op)
 
 #endif
 }
+
+#undef Py_BEGIN_CRITICAL_SECTION
+# define Py_BEGIN_CRITICAL_SECTION(op)                                  \
+    {                                                                   \
+        PyCriticalSection _py_cs;                                       \
+        PyThreadState *_cs_tstate = _PyThreadState_GET();               \
+        _PyCriticalSection_Begin(_cs_tstate, &_py_cs, _PyObject_CAST(op))
+
+#undef Py_BEGIN_CRITICAL_SECTION_MUTEX
+# define Py_BEGIN_CRITICAL_SECTION_MUTEX(mutex)                         \
+    {                                                                   \
+        PyCriticalSection _py_cs;                                       \
+        PyThreadState *_cs_tstate = _PyThreadState_GET();               \
+        _PyCriticalSection_BeginMutex(_cs_tstate, &_py_cs, mutex)
+
+#undef Py_END_CRITICAL_SECTION
+# define Py_END_CRITICAL_SECTION()                                      \
+        _PyCriticalSection_End(_cs_tstate, &_py_cs);                    \
+    }
+
+#undef Py_BEGIN_CRITICAL_SECTION2
+# define Py_BEGIN_CRITICAL_SECTION2(a, b)                               \
+    {                                                                   \
+        PyCriticalSection2 _py_cs2;                                     \
+        PyThreadState *_cs_tstate = _PyThreadState_GET();               \
+        _PyCriticalSection2_Begin(_cs_tstate, &_py_cs2, _PyObject_CAST(a), _PyObject_CAST(b))
+
+#undef Py_BEGIN_CRITICAL_SECTION2_MUTEX
+# define Py_BEGIN_CRITICAL_SECTION2_MUTEX(m1, m2)                       \
+    {                                                                   \
+        PyCriticalSection2 _py_cs2;                                     \
+        PyThreadState *_cs_tstate = _PyThreadState_GET();               \
+        _PyCriticalSection2_BeginMutex(_cs_tstate, &_py_cs2, m1, m2)
+
+#undef Py_END_CRITICAL_SECTION2
+# define Py_END_CRITICAL_SECTION2()                                     \
+        _PyCriticalSection2_End(_cs_tstate, &_py_cs2);                  \
+    }
+
 #endif /* Py_GIL_DISABLED */
 
 #ifdef __cplusplus
diff --git a/Python/critical_section.c b/Python/critical_section.c
index 218b580e95176d..2c2152f5de4716 100644
--- a/Python/critical_section.c
+++ b/Python/critical_section.c
@@ -17,10 +17,9 @@ untag_critical_section(uintptr_t tag)
 #endif
 
 void
-_PyCriticalSection_BeginSlow(PyCriticalSection *c, PyMutex *m)
+_PyCriticalSection_BeginSlow(PyThreadState *tstate, PyCriticalSection *c, PyMutex *m)
 {
 #ifdef Py_GIL_DISABLED
-    PyThreadState *tstate = _PyThreadState_GET();
     // As an optimisation for locking the same object recursively, skip
     // locking if the mutex is currently locked by the top-most critical
     // section.
@@ -53,11 +52,10 @@ _PyCriticalSection_BeginSlow(PyCriticalSection *c, PyMutex *m)
 }
 
 void
-_PyCriticalSection2_BeginSlow(PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2,
+_PyCriticalSection2_BeginSlow(PyThreadState *tstate, PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2,
                               int is_m1_locked)
 {
 #ifdef Py_GIL_DISABLED
-    PyThreadState *tstate = _PyThreadState_GET();
     c->_cs_base._cs_mutex = NULL;
     c->_cs_mutex2 = NULL;
     c->_cs_base._cs_prev = tstate->critical_section;
@@ -139,7 +137,7 @@ void
 PyCriticalSection_Begin(PyCriticalSection *c, PyObject *op)
 {
 #ifdef Py_GIL_DISABLED
-    _PyCriticalSection_Begin(c, op);
+    _PyCriticalSection_Begin(_PyThreadState_GET(), c, op);
 #endif
 }
 
@@ -148,7 +146,7 @@ void
 PyCriticalSection_BeginMutex(PyCriticalSection *c, PyMutex *m)
 {
 #ifdef Py_GIL_DISABLED
-    _PyCriticalSection_BeginMutex(c, m);
+    _PyCriticalSection_BeginMutex(_PyThreadState_GET(), c, m);
 #endif
 }
 
@@ -157,7 +155,7 @@ void
 PyCriticalSection_End(PyCriticalSection *c)
 {
 #ifdef Py_GIL_DISABLED
-    _PyCriticalSection_End(c);
+    _PyCriticalSection_End(_PyThreadState_GET(), c);
 #endif
 }
 
@@ -166,7 +164,7 @@ void
 PyCriticalSection2_Begin(PyCriticalSection2 *c, PyObject *a, PyObject *b)
 {
 #ifdef Py_GIL_DISABLED
-    _PyCriticalSection2_Begin(c, a, b);
+    _PyCriticalSection2_Begin(_PyThreadState_GET(), c, a, b);
 #endif
 }
 
@@ -175,7 +173,7 @@ void
 PyCriticalSection2_BeginMutex(PyCriticalSection2 *c, PyMutex *m1, PyMutex *m2)
 {
 #ifdef Py_GIL_DISABLED
-    _PyCriticalSection2_BeginMutex(c, m1, m2);
+    _PyCriticalSection2_BeginMutex(_PyThreadState_GET(), c, m1, m2);
 #endif
 }
 
@@ -184,6 +182,6 @@ void
 PyCriticalSection2_End(PyCriticalSection2 *c)
 {
 #ifdef Py_GIL_DISABLED
-    _PyCriticalSection2_End(c);
+    _PyCriticalSection2_End(_PyThreadState_GET(), c);
 #endif
 }

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to