This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 72554500bd5 feat(snowflake): add support for cancelling running queries via SQL API (#56164)
72554500bd5 is described below
commit 72554500bd57a8e0f4faa776ac09ca21d8788e3c
Author: Ranuga <[email protected]>
AuthorDate: Wed Nov 26 06:06:12 2025 +0530
feat(snowflake): add support for cancelling running queries via SQL API (#56164)
* feat(snowflake): support cancelling running queries
Extend `get_request_url_header_params` to accept optional `url_suffix` for
appending endpoint paths (e.g., /cancel).
Implement `_cancel_sql_api_query_execution` and `cancel_queries` in
`SnowflakeSqlApiHook` to cancel running queries.
Add `on_kill` method to `SnowflakeSqlApiOperator` to trigger query
cancellation on task kill.
Introduce a unit test verifying that the cancel endpoint is called for every query ID.
* Remove Tests
* chore: add tests
* chore: fix ci/cd
---
.../providers/snowflake/hooks/snowflake_sql_api.py | 17 +++++++++-
.../providers/snowflake/operators/snowflake.py | 7 +++++
.../unit/snowflake/hooks/test_snowflake_sql_api.py | 36 ++++++++++++++++++++++
.../unit/snowflake/operators/test_snowflake.py | 30 ++++++++++++++++++
4 files changed, 89 insertions(+), 1 deletion(-)
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 9f3e7c68be8..d0b3516c0e3 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -258,16 +258,21 @@ class SnowflakeSqlApiHook(SnowflakeHook):
conn_config=conn_config, token_endpoint=token_endpoint,
grant_type=grant_type
)
- def get_request_url_header_params(self, query_id: str) -> tuple[dict[str,
Any], dict[str, Any], str]:
+ def get_request_url_header_params(
+ self, query_id: str, url_suffix: str | None = None
+ ) -> tuple[dict[str, Any], dict[str, Any], str]:
"""
Build the request header Url with account name identifier and query id
from the connection params.
:param query_id: statement handles query ids for the individual
statements.
+ :param url_suffix: Optional path suffix to append to the URL. Must
start with '/', e.g. '/cancel' or '/result'.
"""
req_id = uuid.uuid4()
header = self.get_headers()
params = {"requestId": str(req_id)}
url =
f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
+ if url_suffix:
+ url += url_suffix
return header, params, url
def check_query_output(self, query_ids: list[str]) -> None:
@@ -413,6 +418,16 @@ class SnowflakeSqlApiHook(SnowflakeHook):
status_code, resp = await
self._make_api_call_with_retries_async("GET", url, header, params)
return self._process_response(status_code, resp)
+ def _cancel_sql_api_query_execution(self, query_id: str) -> dict[str, str
| list[str]]:
+ self.log.info("Cancelling query id %s", query_id)
+ header, params, url = self.get_request_url_header_params(query_id,
"/cancel")
+ status_code, resp = self._make_api_call_with_retries("POST", url,
header, params)
+ return self._process_response(status_code, resp)
+
+ def cancel_queries(self, query_ids: list[str]) -> None:
+ for query_id in query_ids:
+ self._cancel_sql_api_query_execution(query_id)
+
@staticmethod
def _should_retry_on_error(exception) -> bool:
"""
diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
index 0a9834679a7..b62ec9c7e28 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
@@ -513,3 +513,10 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
self._hook.query_ids = self.query_ids
else:
self.log.info("%s completed successfully.", self.task_id)
+
+ def on_kill(self) -> None:
+ """Cancel the running query."""
+ if self.query_ids:
+ self.log.info("Cancelling the query ids %s", self.query_ids)
+ self._hook.cancel_queries(self.query_ids)
+ self.log.info("Query ids %s cancelled successfully",
self.query_ids)
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
index efb2c587460..d7bcdd5277c 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
@@ -1433,3 +1433,39 @@ class TestSnowflakeSqlApiHook:
failed_response.raise_for_status.assert_called_once()
failed_response.json.assert_not_called()
+
+ @mock.patch(f"{HOOK_PATH}.get_request_url_header_params")
+ def test_cancel_sql_api_query_execution(self, mock_get_url_header_params,
mock_requests):
+ """Test _cancel_sql_api_query_execution makes POST request with
/cancel suffix."""
+ query_id = "test-query-id"
+ mock_get_url_header_params.return_value = (
+ HEADERS,
+ {"requestId": "uuid"},
+ f"{API_URL}/{query_id}/cancel",
+ )
+ mock_requests.request.return_value = create_successful_response_mock(
+ {"status": "success", "message": "Statement cancelled."}
+ )
+
+ hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+ hook._cancel_sql_api_query_execution(query_id)
+
+ mock_get_url_header_params.assert_called_once_with(query_id, "/cancel")
+ mock_requests.request.assert_called_once_with(
+ method="post",
+ url=f"{API_URL}/{query_id}/cancel",
+ headers=HEADERS,
+ params={"requestId": "uuid"},
+ json=None,
+ )
+
+ @mock.patch(f"{HOOK_PATH}._cancel_sql_api_query_execution")
+ def test_cancel_queries(self, mock_cancel_execution):
+ """Test cancel_queries calls _cancel_sql_api_query_execution for each
query id."""
+ query_ids = ["query-1", "query-2", "query-3"]
+
+ hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+ hook.cancel_queries(query_ids)
+
+ assert mock_cancel_execution.call_count == 3
+ mock_cancel_execution.assert_has_calls([call("query-1"),
call("query-2"), call("query-3")])
diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index 15e763234e8..0d1bae0cd40 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -574,3 +574,33 @@ class TestSnowflakeSqlApiOperator:
with pytest.raises(AirflowException):
operator.execute(context=None)
mock_check_query_output.assert_not_called()
+
+
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.cancel_queries")
+ def test_snowflake_sql_api_on_kill_cancels_queries(self,
mock_cancel_queries):
+ """Test that on_kill cancels running queries."""
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ )
+ operator.query_ids = ["uuid1", "uuid2"]
+
+ operator.on_kill()
+
+ mock_cancel_queries.assert_called_once_with(["uuid1", "uuid2"])
+
+
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.cancel_queries")
+ def test_snowflake_sql_api_on_kill_no_queries(self, mock_cancel_queries):
+ """Test that on_kill does nothing when no query ids exist."""
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ )
+ operator.query_ids = []
+
+ operator.on_kill()
+
+ mock_cancel_queries.assert_not_called()