https://github.com/python/cpython/commit/8b31d08e62b9714cf8dd1d8b19afa5ecbad2414a
commit: 8b31d08e62b9714cf8dd1d8b19afa5ecbad2414a
branch: main
author: Kirill Ignatev <[email protected]>
committer: encukou <[email protected]>
date: 2026-05-19T20:36:46+02:00
summary:

gh-149816: Fix SNI callback callable race (GH-150018)

files:
A Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst
M Lib/test/test_ssl.py
M Modules/_ssl.c

diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index f1f7a07701de16..7f237276617152 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1606,6 +1606,59 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
         gc.collect()
         self.assertIs(wr(), None)
 
+    @unittest.skipUnless(support.Py_GIL_DISABLED,
+                         "test is only useful if the GIL is disabled")
+    @threading_helper.requires_working_threading()
+    def test_sni_callback_race(self):
+        # Replacing sni_callback while handshakes are in-flight must not
+        # crash (use-after-free on the callback in free-threaded builds).
+        client_ctx, server_ctx, hostname = testing_context()
+
+        server_ctx.sni_callback = lambda *a: None
+        done = threading.Event()
+
+        def do_handshakes():
+            while not done.is_set():
+                c_in = ssl.MemoryBIO()
+                c_out = ssl.MemoryBIO()
+                s_in = ssl.MemoryBIO()
+                s_out = ssl.MemoryBIO()
+                client = client_ctx.wrap_bio(
+                    c_in, c_out, server_hostname=hostname)
+                server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
+                for _ in range(50):
+                    try:
+                        client.do_handshake()
+                    except ssl.SSLWantReadError:
+                        pass
+                    except ssl.SSLError:
+                        break
+                    if c_out.pending:
+                        s_in.write(c_out.read())
+                    try:
+                        server.do_handshake()
+                    except ssl.SSLWantReadError:
+                        pass
+                    except ssl.SSLError:
+                        break
+                    if s_out.pending:
+                        c_in.write(s_out.read())
+
+        def toggle_callback():
+            while not done.is_set():
+                server_ctx.sni_callback = lambda *a: None
+                server_ctx.sni_callback = None
+
+        workers = max(4, (os.cpu_count() or 4) * 2)
+        threads = [threading.Thread(target=do_handshakes)
+                   for _ in range(workers)]
+        threads.append(threading.Thread(target=toggle_callback))
+
+        with threading_helper.catch_threading_exception() as cm:
+            with threading_helper.start_threads(threads):
+                done.set()
+            self.assertIsNone(cm.exc_value)
+
     def test_cert_store_stats(self):
         ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         self.assertEqual(ctx.cert_store_stats(),
diff --git a/Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst b/Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst
new file mode 100644
index 00000000000000..9996cc7ec0e866
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst
@@ -0,0 +1 @@
+Fix race condition in :attr:`ssl.SSLContext.sni_callback`
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index 3224ca7d0f93b9..35754e566a1528 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -26,6 +26,7 @@
 #define OPENSSL_NO_DEPRECATED 1
 
 #include "Python.h"
+#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
 #include "pycore_fileutils.h"     // _PyIsSelectable_fd()
 #include "pycore_long.h"          // _PyLong_UnsignedLongLong_Converter()
 #include "pycore_pyerrors.h"      // _PyErr_ChainExceptions1()
@@ -5153,12 +5154,15 @@ _servername_callback(SSL *s, int *al, void *args)
     PyObject *result;
     /* The high-level ssl.SSLSocket object */
     PyObject *ssl_socket;
+    PyObject *sni_cb;
     const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
     PyGILState_STATE gstate = PyGILState_Ensure();
 
-    if (sslctx->set_sni_cb == NULL) {
-        /* remove race condition in this the call back while if removing the
-         * callback is in progress */
+    Py_BEGIN_CRITICAL_SECTION(sslctx);
+    sni_cb = Py_XNewRef(sslctx->set_sni_cb);
+    Py_END_CRITICAL_SECTION();
+
+    if (sni_cb == NULL) {
         PyGILState_Release(gstate);
         return SSL_TLSEXT_ERR_OK;
     }
@@ -5185,7 +5189,7 @@ _servername_callback(SSL *s, int *al, void *args)
         goto error;
 
     if (servername == NULL) {
-        result = PyObject_CallFunctionObjArgs(sslctx->set_sni_cb, ssl_socket,
+        result = PyObject_CallFunctionObjArgs(sni_cb, ssl_socket,
                                               Py_None, sslctx, NULL);
     }
     else {
@@ -5212,7 +5216,7 @@ _servername_callback(SSL *s, int *al, void *args)
         }
         Py_DECREF(servername_bytes);
         result = PyObject_CallFunctionObjArgs(
-            sslctx->set_sni_cb, ssl_socket, servername_str,
+            sni_cb, ssl_socket, servername_str,
             sslctx, NULL);
         Py_DECREF(servername_str);
     }
@@ -5222,7 +5226,7 @@ _servername_callback(SSL *s, int *al, void *args)
         PyErr_FormatUnraisable("Exception ignored "
                                "in ssl servername callback "
                                "while calling set SNI callback %R",
-                               sslctx->set_sni_cb);
+                               sni_cb);
         *al = SSL_AD_HANDSHAKE_FAILURE;
         ret = SSL_TLSEXT_ERR_ALERT_FATAL;
     }
@@ -5247,11 +5251,13 @@ _servername_callback(SSL *s, int *al, void *args)
         Py_DECREF(result);
     }
 
+    Py_DECREF(sni_cb);
     PyGILState_Release(gstate);
     return ret;
 
 error:
     Py_XDECREF(ssl_socket);
+    Py_XDECREF(sni_cb);
     *al = SSL_AD_INTERNAL_ERROR;
     ret = SSL_TLSEXT_ERR_ALERT_FATAL;
     PyGILState_Release(gstate);
@@ -5301,20 +5307,18 @@ _ssl__SSLContext_sni_callback_set_impl(PySSLContext *self, PyObject *value)
                         "sni_callback cannot be set on TLS_CLIENT context");
         return -1;
     }
-    Py_CLEAR(self->set_sni_cb);
-    if (value == Py_None) {
+    if (!PyCallable_Check(value)) {
         SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
-    }
-    else {
-        if (!PyCallable_Check(value)) {
-            SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
-            PyErr_SetString(PyExc_TypeError,
-                            "not a callable object");
+        Py_CLEAR(self->set_sni_cb);
+        if (value != Py_None) {
+            PyErr_SetString(PyExc_TypeError, "not a callable object");
             return -1;
         }
-        self->set_sni_cb = Py_NewRef(value);
-        SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
+    }
+    else {
+        Py_XSETREF(self->set_sni_cb, Py_NewRef(value));
         SSL_CTX_set_tlsext_servername_arg(self->ctx, self);
+        SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
     }
     return 0;
 }

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to