https://github.com/python/cpython/commit/7f6c16a956d598663d8c67071c492f197045d967
commit: 7f6c16a956d598663d8c67071c492f197045d967
branch: main
author: Bénédikt Tran <[email protected]>
committer: picnixz <[email protected]>
date: 2026-01-01T11:55:05+01:00
summary:

gh-142830: prevent some crashes when mutating `sqlite3` callbacks (#143245)

files:
A Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst
M Lib/test/test_sqlite3/test_hooks.py
M Modules/_sqlite/connection.c
M Modules/_sqlite/connection.h

diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py
index 2b907e35131d06..495ef97fa3c61c 100644
--- a/Lib/test/test_sqlite3/test_hooks.py
+++ b/Lib/test/test_sqlite3/test_hooks.py
@@ -24,11 +24,15 @@
 import sqlite3 as sqlite
 import unittest
 
+from test.support import import_helper
 from test.support.os_helper import TESTFN, unlink
 
 from .util import memory_database, cx_limit, with_tracebacks
 from .util import MemoryDatabaseMixin
 
+# TODO(picnixz): increase test coverage for other callbacks
+# such as 'func', 'step', 'finalize', 'value', 'inverse', and 'collation'.
+
 
 class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
 
@@ -129,8 +133,55 @@ def test_deregister_collation(self):
         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
 
 
+class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase):
+
+    def assert_not_authorized(self, func, /, *args, **kwargs):
+        with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
+            func(*args, **kwargs)
+
+    # When a handler has an invalid signature, the exception raised is
+    # the same as the one raised if the handler "negatively" replied.
+
+    def test_authorizer_invalid_signature(self):
+        self.cx.execute("create table if not exists test(a number)")
+        self.cx.set_authorizer(lambda: None)
+        self.assert_not_authorized(self.cx.execute, "select * from test")
+
+    # Tests for checking that callback context mutations do not crash.
+    # Regression tests for https://github.com/python/cpython/issues/142830.
+
+    @with_tracebacks(ZeroDivisionError, regex="hello world")
+    def test_authorizer_concurrent_mutation_in_call(self):
+        self.cx.execute("create table if not exists test(a number)")
+
+        def handler(*a, **kw):
+            self.cx.set_authorizer(None)
+            raise ZeroDivisionError("hello world")
+
+        self.cx.set_authorizer(handler)
+        self.assert_not_authorized(self.cx.execute, "select * from test")
+
+    @with_tracebacks(OverflowError)
+    def test_authorizer_concurrent_mutation_with_overflown_value(self):
+        _testcapi = import_helper.import_module("_testcapi")
+        self.cx.execute("create table if not exists test(a number)")
+
+        def handler(*a, **kw):
+            self.cx.set_authorizer(None)
+            # We expect 'int' at the C level, so this one will raise
+            # when converting via PyLong_AsInt().
+            return _testcapi.INT_MAX + 1
+
+        self.cx.set_authorizer(handler)
+        self.assert_not_authorized(self.cx.execute, "select * from test")
+
+
 class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
 
+    def assert_interrupted(self, func, /, *args, **kwargs):
+        with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
+            func(*args, **kwargs)
+
     def test_progress_handler_used(self):
         """
         Test that the progress handler is invoked once it is set.
@@ -219,11 +270,48 @@ def bad_progress():
                 create table foo(a, b)
                 """)
 
-    def test_progress_handler_keyword_args(self):
+    def test_set_progress_handler_keyword_args(self):
         with self.assertRaisesRegex(TypeError,
                 'takes at least 1 positional argument'):
             self.con.set_progress_handler(progress_handler=lambda: None, n=1)
 
+    # When a handler has an invalid signature, the exception raised is
+    # the same as the one raised if the handler "negatively" replied.
+
+    def test_progress_handler_invalid_signature(self):
+        self.cx.execute("create table if not exists test(a number)")
+        self.cx.set_progress_handler(lambda x: None, 1)
+        self.assert_interrupted(self.cx.execute, "select * from test")
+
+    # Tests for checking that callback context mutations do not crash.
+    # Regression tests for https://github.com/python/cpython/issues/142830.
+
+    @with_tracebacks(ZeroDivisionError, regex="hello world")
+    def test_progress_handler_concurrent_mutation_in_call(self):
+        self.cx.execute("create table if not exists test(a number)")
+
+        def handler(*a, **kw):
+            self.cx.set_progress_handler(None, 1)
+            raise ZeroDivisionError("hello world")
+
+        self.cx.set_progress_handler(handler, 1)
+        self.assert_interrupted(self.cx.execute, "select * from test")
+
+    # Not decorated with @with_tracebacks: running with tracebacks makes
+    # the second execution of this test raise another exception because
+    # of a database change.
+    def test_progress_handler_concurrent_mutation_in_conversion(self):
+        self.cx.execute("create table if not exists test(a number)")
+
+        class Handler:
+            def __bool__(_):
+                # clear the progress handler
+                self.cx.set_progress_handler(None, 1)
+                raise ValueError  # force PyObject_IsTrue() to fail
+
+        self.cx.set_progress_handler(Handler, 1)
+        self.assert_interrupted(self.cx.execute, "select * from test")
+
 
 class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
 
@@ -345,11 +433,40 @@ def test_trace_bad_handler(self):
             cx.set_trace_callback(lambda stmt: 5/0)
             cx.execute("select 1")
 
-    def test_trace_keyword_args(self):
+    def test_set_trace_callback_keyword_args(self):
         with self.assertRaisesRegex(TypeError,
                 'takes exactly 1 positional argument'):
             self.con.set_trace_callback(trace_callback=lambda: None)
 
+    # When a handler has an invalid signature, the exception raised is
+    # the same as the one raised if the handler "negatively" replied,
+    # but for the trace handler, exceptions are never re-raised (only
+    # printed when needed).
+
+    @with_tracebacks(
+        TypeError,
+        regex=r".*<lambda>\(\) missing 6 required positional arguments",
+    )
+    def test_trace_handler_invalid_signature(self):
+        self.cx.execute("create table if not exists test(a number)")
+        self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None)
+        self.cx.execute("select * from test")
+
+    # Tests for checking that callback context mutations do not crash.
+    # Regression tests for https://github.com/python/cpython/issues/142830.
+
+    @with_tracebacks(ZeroDivisionError, regex="hello world")
+    def test_trace_callback_concurrent_mutation_in_call(self):
+        self.cx.execute("create table if not exists test(a number)")
+
+        def handler(statement):
+            # clear the trace callback
+            self.cx.set_trace_callback(None)
+            raise ZeroDivisionError("hello world")
+
+        self.cx.set_trace_callback(handler)
+        self.cx.execute("select * from test")
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst
new file mode 100644
index 00000000000000..246979e91d76b5
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst
@@ -0,0 +1,2 @@
+:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks
+are mutated during a callback execution. Patch by Bénédikt Tran.
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 83ff8e60557c07..cde06c965ad4e3 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -145,7 +145,8 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp
 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/
 
 static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
-static void free_callback_context(callback_context *ctx);
+static void incref_callback_context(callback_context *ctx);
+static void decref_callback_context(callback_context *ctx);
 static void set_callback_context(callback_context **ctx_pp,
                                  callback_context *ctx);
 static int connection_close(pysqlite_Connection *self);
@@ -936,8 +937,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
     args = _pysqlite_build_py_params(context, argc, argv);
     if (args) {
         callback_context *ctx = (callback_context *)sqlite3_user_data(context);
-        assert(ctx != NULL);
+        incref_callback_context(ctx);
         py_retval = PyObject_CallObject(ctx->callable, args);
+        decref_callback_context(ctx);
         Py_DECREF(args);
     }
 
@@ -964,7 +966,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
     PyObject* stepmethod = NULL;
 
     callback_context *ctx = (callback_context *)sqlite3_user_data(context);
-    assert(ctx != NULL);
+    incref_callback_context(ctx);
 
     aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
     if (aggregate_instance == NULL) {
@@ -1002,6 +1004,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
     }
 
 error:
+    decref_callback_context(ctx);
     Py_XDECREF(stepmethod);
     Py_XDECREF(function_result);
 
@@ -1033,9 +1036,10 @@ final_callback(sqlite3_context *context)
     PyObject *exc = PyErr_GetRaisedException();
 
     callback_context *ctx = (callback_context *)sqlite3_user_data(context);
-    assert(ctx != NULL);
+    incref_callback_context(ctx);
     function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
                                                 ctx->state->str_finalize);
+    decref_callback_context(ctx);
     Py_DECREF(*aggregate_instance);
 
     ok = 0;
@@ -1107,6 +1111,7 @@ create_callback_context(PyTypeObject *cls, PyObject *callable)
     callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
     if (ctx != NULL) {
         PyObject *module = PyType_GetModule(cls);
+        ctx->refcount = 1;
         ctx->callable = Py_NewRef(callable);
         ctx->module = Py_NewRef(module);
         ctx->state = pysqlite_get_state(module);
@@ -1118,11 +1123,33 @@ static void
 free_callback_context(callback_context *ctx)
 {
     assert(ctx != NULL);
+    assert(ctx->refcount == 0);
     Py_XDECREF(ctx->callable);
     Py_XDECREF(ctx->module);
     PyMem_Free(ctx);
 }
 
+static inline void
+incref_callback_context(callback_context *ctx)
+{
+    assert(PyGILState_Check());
+    assert(ctx != NULL);
+    assert(ctx->refcount > 0);
+    ctx->refcount++;
+}
+
+static inline void
+decref_callback_context(callback_context *ctx)
+{
+    assert(PyGILState_Check());
+    assert(ctx != NULL);
+    assert(ctx->refcount > 0);
+    ctx->refcount--;
+    if (ctx->refcount == 0) {
+        free_callback_context(ctx);
+    }
+}
+
 static void
 set_callback_context(callback_context **ctx_pp, callback_context *ctx)
 {
@@ -1130,7 +1157,7 @@ set_callback_context(callback_context **ctx_pp, callback_context *ctx)
     callback_context *tmp = *ctx_pp;
     *ctx_pp = ctx;
     if (tmp != NULL) {
-        free_callback_context(tmp);
+        decref_callback_context(tmp);
     }
 }
 
@@ -1141,7 +1168,7 @@ destructor_callback(void *ctx)
         // This function may be called without the GIL held, so we need to
         // ensure that we destroy 'ctx' with the GIL held.
         PyGILState_STATE gstate = PyGILState_Ensure();
-        free_callback_context((callback_context *)ctx);
+        decref_callback_context((callback_context *)ctx);
         PyGILState_Release(gstate);
     }
 }
@@ -1202,7 +1229,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
                                     func_callback,
                                     NULL,
                                     NULL,
-                                    &destructor_callback);  // will decref func
+                                    &destructor_callback);  // will decref 'ctx'
 
     if (rc != SQLITE_OK) {
         /* Workaround for SQLite bug: no error code or string is available here */
@@ -1226,7 +1253,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
     PyGILState_STATE gilstate = PyGILState_Ensure();
 
     callback_context *ctx = (callback_context *)sqlite3_user_data(context);
-    assert(ctx != NULL);
+    incref_callback_context(ctx);
 
     int size = sizeof(PyObject *);
     PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
@@ -1258,6 +1285,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
     Py_DECREF(res);
 
 exit:
+    decref_callback_context(ctx);
     Py_XDECREF(method);
     PyGILState_Release(gilstate);
 }
@@ -1274,7 +1302,7 @@ value_callback(sqlite3_context *context)
     PyGILState_STATE gilstate = PyGILState_Ensure();
 
     callback_context *ctx = (callback_context *)sqlite3_user_data(context);
-    assert(ctx != NULL);
+    incref_callback_context(ctx);
 
     int size = sizeof(PyObject *);
     PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
@@ -1282,6 +1310,8 @@ value_callback(sqlite3_context *context)
     assert(*cls != NULL);
 
     PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
+    decref_callback_context(ctx);
+
     if (res == NULL) {
         int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
         set_sqlite_error(context, attr_err
@@ -1403,7 +1433,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
                                     0,
                                     &step_callback,
                                     &final_callback,
-                                    &destructor_callback); // will decref func
+                                    &destructor_callback); // will decref 'ctx'
     if (rc != SQLITE_OK) {
         /* Workaround for SQLite bug: no error code or string is available here */
         PyErr_SetString(self->OperationalError, "Error creating aggregate");
@@ -1413,7 +1443,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
 }
 
 static int
-authorizer_callback(void *ctx, int action, const char *arg1,
+authorizer_callback(void *ctx_vp, int action, const char *arg1,
                     const char *arg2 , const char *dbname,
                     const char *access_attempt_source)
 {
@@ -1422,8 +1452,9 @@ authorizer_callback(void *ctx, int action, const char *arg1,
     PyObject *ret;
     int rc = SQLITE_DENY;
 
-    assert(ctx != NULL);
-    PyObject *callable = ((callback_context *)ctx)->callable;
+    callback_context *ctx = (callback_context *)ctx_vp;
+    incref_callback_context(ctx);
+    PyObject *callable = ctx->callable;
     ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
                                 access_attempt_source);
 
@@ -1445,21 +1476,23 @@ authorizer_callback(void *ctx, int action, const char *arg1,
         Py_DECREF(ret);
     }
 
+    decref_callback_context(ctx);
     PyGILState_Release(gilstate);
     return rc;
 }
 
 static int
-progress_callback(void *ctx)
+progress_callback(void *ctx_vp)
 {
     PyGILState_STATE gilstate = PyGILState_Ensure();
 
     int rc;
     PyObject *ret;
 
-    assert(ctx != NULL);
-    PyObject *callable = ((callback_context *)ctx)->callable;
-    ret = PyObject_CallNoArgs(callable);
+    callback_context *ctx = (callback_context *)ctx_vp;
+    incref_callback_context(ctx);
+
+    ret = PyObject_CallNoArgs(ctx->callable);
     if (!ret) {
         /* abort query if error occurred */
         rc = -1;
@@ -1472,6 +1505,7 @@ progress_callback(void *ctx)
         print_or_clear_traceback(ctx);
     }
 
+    decref_callback_context(ctx);
     PyGILState_Release(gilstate);
     return rc;
 }
@@ -1483,7 +1517,7 @@ progress_callback(void *ctx)
  * to ensure future compatibility.
  */
 static int
-trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
+trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
 {
     if (type != SQLITE_TRACE_STMT) {
         return 0;
@@ -1491,8 +1525,9 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
 
     PyGILState_STATE gilstate = PyGILState_Ensure();
 
-    assert(ctx != NULL);
-    pysqlite_state *state = ((callback_context *)ctx)->state;
+    callback_context *ctx = (callback_context *)ctx_vp;
+    incref_callback_context(ctx);
+    pysqlite_state *state = ctx->state;
     assert(state != NULL);
 
     PyObject *py_statement = NULL;
@@ -1506,7 +1541,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
 
         PyErr_SetString(state->DataError,
                 "Expanded SQL string exceeds the maximum string length");
-        print_or_clear_traceback((callback_context *)ctx);
+        print_or_clear_traceback(ctx);
 
         // Fall back to unexpanded sql
         py_statement = PyUnicode_FromString((const char *)sql);
@@ -1516,16 +1551,16 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
         sqlite3_free((void *)expanded_sql);
     }
     if (py_statement) {
-        PyObject *callable = ((callback_context *)ctx)->callable;
-        PyObject *ret = PyObject_CallOneArg(callable, py_statement);
+        PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
         Py_DECREF(py_statement);
         Py_XDECREF(ret);
     }
     if (PyErr_Occurred()) {
-        print_or_clear_traceback((callback_context *)ctx);
+        print_or_clear_traceback(ctx);
     }
 
 exit:
+    decref_callback_context(ctx);
     PyGILState_Release(gilstate);
     return 0;
 }
@@ -1950,6 +1985,8 @@ collation_callback(void *context, int text1_length, const void *text1_data,
     PyObject* retval = NULL;
     long longval;
     int result = 0;
+    callback_context *ctx = (callback_context *)context;
+    incref_callback_context(ctx);
 
     /* This callback may be executed multiple times per sqlite3_step(). Bail if
      * the previous call failed */
@@ -1966,8 +2003,6 @@ collation_callback(void *context, int text1_length, const void *text1_data,
         goto finally;
     }
 
-    callback_context *ctx = (callback_context *)context;
-    assert(ctx != NULL);
     PyObject *args[] = { NULL, string1, string2 };  // Borrowed refs.
     size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
     retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL);
@@ -1989,6 +2024,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
     }
 
 finally:
+    decref_callback_context(ctx);
     Py_XDECREF(string1);
     Py_XDECREF(string2);
     Py_XDECREF(retval);
diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h
index 7a748ee3ea0c58..703396a0c8db53 100644
--- a/Modules/_sqlite/connection.h
+++ b/Modules/_sqlite/connection.h
@@ -36,6 +36,7 @@ typedef struct _callback_context
     PyObject *callable;
     PyObject *module;
     pysqlite_state *state;
+    Py_ssize_t refcount;  // starts at 1; the context is freed when it drops to 0
 } callback_context;
 
 enum autocommit_mode {

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to