https://github.com/python/cpython/commit/ee4e14aa4c1150438b18d828770a967ca2019d43
commit: ee4e14aa4c1150438b18d828770a967ca2019d43
branch: main
author: Sam Gross <[email protected]>
committer: colesbury <[email protected]>
date: 2026-01-22T14:02:48-05:00
summary:

gh-143756: Avoid borrowed reference in SSL code (gh-143816)

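GET_SOCKET() returned a borrowed reference, which was potentially
unsafe. Replace it with get_socket(), which returns a strong reference,
and refactor the duplicated "socket connection gone" check and
non-blocking BIO setup into that helper.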

files:
M Modules/_ssl.c

diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index e240b889d86a2d..22865bdfc3f727 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -423,26 +423,6 @@ typedef enum {
 #define ERRSTR1(x,y,z) (x ":" y ": " z)
 #define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x)
 
-// Get the socket from a PySSLSocket, if it has one.
-// Return a borrowed reference.
-static inline PySocketSockObject* GET_SOCKET(PySSLSocket *obj) {
-    if (obj->Socket) {
-        PyObject *sock;
-        if (PyWeakref_GetRef(obj->Socket, &sock)) {
-            // GET_SOCKET() returns a borrowed reference
-            Py_DECREF(sock);
-        }
-        else {
-            // dead weak reference
-            sock = Py_None;
-        }
-        return (PySocketSockObject *)sock;  // borrowed reference
-    }
-    else {
-        return NULL;
-    }
-}
-
 /* If sock is NULL, use a timeout of 0 second */
 #define GET_SOCKET_TIMEOUT(sock) \
     ((sock != NULL) ? (sock)->sock_timeout : 0)
@@ -794,6 +774,35 @@ _ssl_deprecated(const char* msg, int stacklevel) {
 #define PY_SSL_DEPRECATED(name, stacklevel, ret) \
     if (_ssl_deprecated((name), (stacklevel)) == -1) return (ret)
 
+// Get the socket from a PySSLSocket. Return 1 and store a strong reference in
+// *out_sock, 0 (*out_sock = NULL) if there is none, or -1 with an error set.
+static int
+get_socket(PySSLSocket *obj, PySocketSockObject **out_sock,
+           const char *filename, int lineno)
+{
+    *out_sock = NULL;
+    if (!obj->Socket) {
+        return 0;
+    }
+    PySocketSockObject *sock;
+    if (PyWeakref_GetRef(obj->Socket, (PyObject **)&sock) < 0) {
+        return -1;  // lookup failed: sock is NULL and an exception is set
+    }
+    if (sock == NULL || sock->sock_fd == INVALID_SOCKET) {
+        Py_XDECREF(sock);  // drop our reference to a closed socket, if any
+        _setSSLError(get_state_sock(obj),
+                     "Underlying socket connection gone",
+                     PY_SSL_ERROR_NO_SOCKET, filename, lineno);
+        return -1;
+    }
+    /* just in case the blocking state of the socket has been changed */
+    int nonblocking = (sock->sock_timeout >= 0);
+    BIO_set_nbio(SSL_get_rbio(obj->ssl), nonblocking);
+    BIO_set_nbio(SSL_get_wbio(obj->ssl), nonblocking);
+    *out_sock = sock;
+    return 1;
+}
+
 /*
  * SSL objects
  */
@@ -1021,24 +1030,13 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
     int ret;
     _PySSLError err;
     PyObject *exc = NULL;
-    int sockstate, nonblocking;
-    PySocketSockObject *sock = GET_SOCKET(self);
+    int sockstate;
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock) {
-        if (((PyObject*)sock) == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-
-        /* just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);
@@ -2610,22 +2608,12 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset,
     int sockstate;
     _PySSLError err;
     PyObject *exc = NULL;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock != NULL) {
-        if ((PyObject *)sock == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-        /* just in case the blocking state of the socket has been changed */
-        int nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);
@@ -2747,26 +2735,12 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
     int sockstate;
     _PySSLError err;
     PyObject *exc = NULL;
-    int nonblocking;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock != NULL) {
-        if (((PyObject*)sock) == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-    }
-
-    if (sock != NULL) {
-        /* just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);
@@ -2896,8 +2870,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
     int sockstate;
     _PySSLError err;
     PyObject *exc = NULL;
-    int nonblocking;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
@@ -2906,14 +2878,9 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
         return NULL;
     }
 
-    if (sock != NULL) {
-        if (((PyObject*)sock) == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     if (!group_right_1) {
@@ -2944,13 +2911,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
         }
     }
 
-    if (sock != NULL) {
-        /* just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
-    }
-
     timeout = GET_SOCKET_TIMEOUT(sock);
     has_timeout = (timeout > 0);
     if (has_timeout)
@@ -3041,26 +3001,14 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
 {
     _PySSLError err;
     PyObject *exc = NULL;
-    int sockstate, nonblocking, ret;
+    int sockstate, ret;
     int zeros = 0;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock != NULL) {
-        /* Guard against closed socket */
-        if ((((PyObject*)sock) == Py_None) || (sock->sock_fd == INVALID_SOCKET)) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-
-        /* Just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to