This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 72554500bd5 feat(snowflake): add support for cancelling running queries via SQL API (#56164)
72554500bd5 is described below

commit 72554500bd57a8e0f4faa776ac09ca21d8788e3c
Author: Ranuga <[email protected]>
AuthorDate: Wed Nov 26 06:06:12 2025 +0530

    feat(snowflake): add support for cancelling running queries via SQL API (#56164)
    
    * feat(snowflake): support cancelling running queries
    
    Extend `get_request_url_header_params` to accept an optional `url_suffix` for appending endpoint paths (e.g., /cancel).
    Implement `_cancel_sql_api_query_execution` and `cancel_queries` in `SnowflakeSqlApiHook` to cancel running queries.
    Add an `on_kill` method to `SnowflakeSqlApiOperator` to trigger query cancellation on task kill.
    Introduce unit tests to verify cancel endpoint calls for all query IDs.
    
    * Remove Tests
    
    * chore: add tests
    
    * chore: fix ci/cd
---
 .../providers/snowflake/hooks/snowflake_sql_api.py | 17 +++++++++-
 .../providers/snowflake/operators/snowflake.py     |  7 +++++
 .../unit/snowflake/hooks/test_snowflake_sql_api.py | 36 ++++++++++++++++++++++
 .../unit/snowflake/operators/test_snowflake.py     | 30 ++++++++++++++++++
 4 files changed, 89 insertions(+), 1 deletion(-)

diff --git 
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
 
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 9f3e7c68be8..d0b3516c0e3 100644
--- 
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ 
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -258,16 +258,21 @@ class SnowflakeSqlApiHook(SnowflakeHook):
             conn_config=conn_config, token_endpoint=token_endpoint, 
grant_type=grant_type
         )
 
-    def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, 
Any], dict[str, Any], str]:
+    def get_request_url_header_params(
+        self, query_id: str, url_suffix: str | None = None
+    ) -> tuple[dict[str, Any], dict[str, Any], str]:
         """
         Build the request header Url with account name identifier and query id 
from the connection params.
 
         :param query_id: statement handles query ids for the individual 
statements.
+        :param url_suffix: Optional path suffix to append to the URL. Must 
start with '/', e.g. '/cancel' or '/result'.
         """
         req_id = uuid.uuid4()
         header = self.get_headers()
         params = {"requestId": str(req_id)}
         url = 
f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
+        if url_suffix:
+            url += url_suffix
         return header, params, url
 
     def check_query_output(self, query_ids: list[str]) -> None:
@@ -413,6 +418,16 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         status_code, resp = await 
self._make_api_call_with_retries_async("GET", url, header, params)
         return self._process_response(status_code, resp)
 
+    def _cancel_sql_api_query_execution(self, query_id: str) -> dict[str, str 
| list[str]]:
+        self.log.info("Cancelling query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id, 
"/cancel")
+        status_code, resp = self._make_api_call_with_retries("POST", url, 
header, params)
+        return self._process_response(status_code, resp)
+
+    def cancel_queries(self, query_ids: list[str]) -> None:
+        for query_id in query_ids:
+            self._cancel_sql_api_query_execution(query_id)
+
     @staticmethod
     def _should_retry_on_error(exception) -> bool:
         """
diff --git 
a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py 
b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
index 0a9834679a7..b62ec9c7e28 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
@@ -513,3 +513,10 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
                     self._hook.query_ids = self.query_ids
         else:
             self.log.info("%s completed successfully.", self.task_id)
+
+    def on_kill(self) -> None:
+        """Cancel the running query."""
+        if self.query_ids:
+            self.log.info("Cancelling the query ids %s", self.query_ids)
+            self._hook.cancel_queries(self.query_ids)
+            self.log.info("Query ids %s cancelled successfully", 
self.query_ids)
diff --git 
a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py 
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
index efb2c587460..d7bcdd5277c 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
@@ -1433,3 +1433,39 @@ class TestSnowflakeSqlApiHook:
 
         failed_response.raise_for_status.assert_called_once()
         failed_response.json.assert_not_called()
+
+    @mock.patch(f"{HOOK_PATH}.get_request_url_header_params")
+    def test_cancel_sql_api_query_execution(self, mock_get_url_header_params, 
mock_requests):
+        """Test _cancel_sql_api_query_execution makes POST request with 
/cancel suffix."""
+        query_id = "test-query-id"
+        mock_get_url_header_params.return_value = (
+            HEADERS,
+            {"requestId": "uuid"},
+            f"{API_URL}/{query_id}/cancel",
+        )
+        mock_requests.request.return_value = create_successful_response_mock(
+            {"status": "success", "message": "Statement cancelled."}
+        )
+
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+        hook._cancel_sql_api_query_execution(query_id)
+
+        mock_get_url_header_params.assert_called_once_with(query_id, "/cancel")
+        mock_requests.request.assert_called_once_with(
+            method="post",
+            url=f"{API_URL}/{query_id}/cancel",
+            headers=HEADERS,
+            params={"requestId": "uuid"},
+            json=None,
+        )
+
+    @mock.patch(f"{HOOK_PATH}._cancel_sql_api_query_execution")
+    def test_cancel_queries(self, mock_cancel_execution):
+        """Test cancel_queries calls _cancel_sql_api_query_execution for each 
query id."""
+        query_ids = ["query-1", "query-2", "query-3"]
+
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+        hook.cancel_queries(query_ids)
+
+        assert mock_cancel_execution.call_count == 3
+        mock_cancel_execution.assert_has_calls([call("query-1"), 
call("query-2"), call("query-3")])
diff --git 
a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py 
b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index 15e763234e8..0d1bae0cd40 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -574,3 +574,33 @@ class TestSnowflakeSqlApiOperator:
         with pytest.raises(AirflowException):
             operator.execute(context=None)
         mock_check_query_output.assert_not_called()
+
+    
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.cancel_queries")
+    def test_snowflake_sql_api_on_kill_cancels_queries(self, 
mock_cancel_queries):
+        """Test that on_kill cancels running queries."""
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id=CONN_ID,
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+        )
+        operator.query_ids = ["uuid1", "uuid2"]
+
+        operator.on_kill()
+
+        mock_cancel_queries.assert_called_once_with(["uuid1", "uuid2"])
+
+    
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.cancel_queries")
+    def test_snowflake_sql_api_on_kill_no_queries(self, mock_cancel_queries):
+        """Test that on_kill does nothing when no query ids exist."""
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id=CONN_ID,
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+        )
+        operator.query_ids = []
+
+        operator.on_kill()
+
+        mock_cancel_queries.assert_not_called()

Reply via email to