This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new af80c491a8c Refactor timeout handling in DatabricksSqlHook to use explicit signaling (#62623)
af80c491a8c is described below
commit af80c491a8cabbc9d3697b7ebf243cecb48816a7
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Mar 10 18:35:13 2026 +0000
Refactor timeout handling in DatabricksSqlHook to use explicit signaling (#62623)
Replace implicit timeout detection based on Timer.is_alive() with explicit
timeout signaling via threading.Event. Timeout classification now checks an
explicit signal set by the timeout callback instead of inferring state from
thread lifecycle behavior.
Preserves existing cancellation semantics and exception types. Unit tests
have been adjusted accordingly.
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../providers/databricks/hooks/databricks_sql.py | 48 ++++++++++++++--------
.../unit/databricks/hooks/test_databricks_sql.py | 39 +++++++++++-------
2 files changed, 55 insertions(+), 32 deletions(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 127f6b71c70..2c2164bd9c7 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -52,14 +52,23 @@ if TYPE_CHECKING:
T = TypeVar("T")
-def create_timeout_thread(cur, execution_timeout: timedelta | None) -> threading.Timer | None:
- if execution_timeout is not None:
- seconds_to_timeout = execution_timeout.total_seconds()
- t = threading.Timer(seconds_to_timeout, cur.connection.cancel)
- else:
- t = None
+def create_timeout_thread(
+ cur, execution_timeout: timedelta | None
+) -> tuple[threading.Timer | None, threading.Event | None]:
+ """Create a timeout timer that cancels the connection and sets a timeout flag."""
+ if not execution_timeout:
+ return None, None
- return t
+ timeout_event = threading.Event()
+
+ def _cancel():
+ timeout_event.set()
+ cur.connection.cancel()
+
+ timer = threading.Timer(execution_timeout.total_seconds(), _cancel)
+ timer.start()
+
+ return timer, timeout_event
class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
@@ -290,22 +299,25 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
self.set_autocommit(conn, autocommit)
with closing(conn.cursor()) as cur:
- t = create_timeout_thread(cur, execution_timeout)
+ timer, timeout_event = create_timeout_thread(cur, execution_timeout)
- # TODO: adjust this to make testing easier
try:
self._run_command(cur, sql_statement, parameters)
+
except Exception as e:
- if t is None or t.is_alive():
- raise DatabricksSqlExecutionError(
- f"Error running SQL statement: {sql_statement}. {str(e)}"
- )
- raise DatabricksSqlExecutionTimeout(
- f"Timeout threshold exceeded for SQL statement: {sql_statement} was cancelled."
- )
+ if timeout_event and timeout_event.is_set():
+ raise DatabricksSqlExecutionTimeout(
+ f"Timeout threshold exceeded for SQL statement: "
+ f"{sql_statement} was cancelled."
+ ) from e
+
+ raise DatabricksSqlExecutionError(
+ f"Error running SQL statement: {sql_statement}. {str(e)}"
+ ) from e
+
finally:
- if t is not None:
- t.cancel()
+ if timer:
+ timer.cancel()
if query_id := cur.query_id:
self.log.info("Databricks query id: %s", query_id)
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index 98ea7e1d347..d661f5b0714 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -18,7 +18,6 @@
#
from __future__ import annotations
-import threading
from collections import namedtuple
from datetime import timedelta
from unittest import mock
@@ -509,8 +508,12 @@ def test_execution_timeout_exceeded(
description=get_cursor_descriptions(cursor_descriptions),
)
- # Simulate a timeout
- mock_create_timeout_thread.return_value = threading.Timer(cur, execution_timeout)
+ mock_event = mock.MagicMock()
+ mock_event.is_set.return_value = True # simulate timeout
+
+ mock_timer = mock.MagicMock()
+
+ mock_create_timeout_thread.return_value = (mock_timer, mock_event)
mock_run_command.side_effect = Exception("Mocked exception")
@@ -532,20 +535,22 @@ def test_execution_timeout_exceeded(
"cursor_descriptions",
[(("id", "value"),)],
)
-def test_create_timeout_thread(
- mock_get_conn,
- mock_get_requests,
- mock_timer,
- cursor_descriptions,
-):
+def test_create_timeout_thread(mock_get_conn, mock_get_requests, cursor_descriptions):
+
cur = mock.MagicMock(
rowcount=1,
description=get_cursor_descriptions(cursor_descriptions),
)
+
timeout = timedelta(seconds=1)
- thread = create_timeout_thread(cur=cur, execution_timeout=timeout)
- mock_timer.assert_called_once_with(timeout.total_seconds(), cur.connection.cancel)
- assert thread is not None
+
+ timer, event = create_timeout_thread(cur=cur, execution_timeout=timeout)
+
+ assert timer is not None
+ assert event is not None
+ assert not event.is_set()
+
+ timer.cancel()
@pytest.mark.parametrize(
@@ -562,9 +567,15 @@ def test_create_timeout_thread_no_timeout(
rowcount=1,
description=get_cursor_descriptions(cursor_descriptions),
)
- thread = create_timeout_thread(cur=cur, execution_timeout=None)
+
+ timer, timeout_event = create_timeout_thread(
+ cur=cur,
+ execution_timeout=None,
+ )
+
mock_timer.assert_not_called()
- assert thread is None
+ assert timer is None
+ assert timeout_event is None
def test_get_openlineage_default_schema_with_no_schema_set():