This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e7aa4d2289 Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for BigQueryInsertJobOperator (#39442)
e7aa4d2289 is described below

commit e7aa4d2289cd4207f11b697729466717889fda38
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Wed May 8 14:40:29 2024 +0545

    Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for BigQueryInsertJobOperator (#39442)
---
 .../providers/google/cloud/triggers/bigquery.py    | 60 ++++++++++++++++++++--
 .../google/cloud/triggers/test_bigquery.py         | 40 ++++++++++++++-
 2 files changed, 95 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index e2e0e82f6b..fc19db9881 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -17,13 +17,20 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any, AsyncIterator, Sequence, SupportsAbs
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs
 
 from aiohttp import ClientSession
 from aiohttp.client_exceptions import ClientResponseError
 
+from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, 
BigQueryTableAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
 
 
 class BigQueryInsertJobTrigger(BaseTrigger):
@@ -89,6 +96,36 @@ class BigQueryInsertJobTrigger(BaseTrigger):
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed 
by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is called 
because the trigger itself is stopped.
+        Because in those cases, we should NOT cancel the external job.
+        """
+        # Database query is needed to get the latest state of the task 
instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: 
ignore[override]
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
@@ -117,13 +154,27 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                     )
                     await asyncio.sleep(self.poll_interval)
         except asyncio.CancelledError:
-            self.log.info("Task was killed.")
-            if self.job_id and self.cancel_on_kill:
+            if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                self.log.info(
+                    "The job is safe to cancel the as airflow TaskInstance is 
not in deferred state."
+                )
+                self.log.info(
+                    "Cancelling job. Project ID: %s, Location: %s, Job ID: %s",
+                    self.project_id,
+                    self.location,
+                    self.job_id,
+                )
                 await hook.cancel_job(  # type: ignore[union-attr]
                     job_id=self.job_id, project_id=self.project_id, 
location=self.location
                 )
             else:
-                self.log.info("Skipping to cancel job: %s:%s.%s", 
self.project_id, self.location, self.job_id)
+                self.log.info(
+                    "Trigger may have shutdown. Skipping to cancel job because 
the airflow "
+                    "task is not cancelled yet: Project ID: %s, Location:%s, 
Job ID:%s",
+                    self.project_id,
+                    self.location,
+                    self.job_id,
+                )
         except Exception as e:
             self.log.exception("Exception occurred while checking for query 
completion")
             yield TriggerEvent({"status": "error", "message": str(e)})
@@ -148,6 +199,7 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
+                "cancel_on_kill": self.cancel_on_kill,
             },
         )
 
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index 436872903e..bbb1a50356 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -239,13 +239,15 @@ class TestBigQueryInsertJobTrigger:
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
     
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
     async def test_bigquery_insert_job_trigger_cancellation(
-        self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+        self, mock_get_task_instance, mock_get_job_status, mock_cancel_job, 
caplog, insert_job_trigger
     ):
         """
         Test that BigQueryInsertJobTrigger handles cancellation correctly, 
logs the appropriate message,
         and conditionally cancels the job based on the `cancel_on_kill` 
attribute.
         """
+        mock_get_task_instance.return_value = True
         insert_job_trigger.cancel_on_kill = True
         insert_job_trigger.job_id = "1234"
 
@@ -271,6 +273,41 @@ class TestBigQueryInsertJobTrigger:
         ), "Expected messages about task status or cancellation not found in 
log."
         mock_cancel_job.assert_awaited_once()
 
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
+    
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
+    async def 
test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation(
+        self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, 
caplog, insert_job_trigger
+    ):
+        """
+        Test that BigQueryInsertJobTrigger logs the appropriate message and 
does not cancel the job
+        if safe_to_cancel returns False even when the task is cancelled.
+        """
+        mock_safe_to_cancel.return_value = False
+        insert_job_trigger.cancel_on_kill = True
+        insert_job_trigger.job_id = "1234"
+
+        # Simulate the initial job status as running
+        mock_get_job_status.side_effect = [
+            {"status": "running", "message": "Job is still running"},
+            asyncio.CancelledError(),
+            {"status": "running", "message": "Job is still running after 
cancellation"},
+        ]
+
+        caplog.set_level(logging.INFO)
+
+        try:
+            async for _ in insert_job_trigger.run():
+                pass
+        except asyncio.CancelledError:
+            pass
+
+        assert (
+            "Skipping to cancel job" in caplog.text
+        ), "Expected message about skipping cancellation not found in log."
+        assert mock_get_job_status.call_count == 2, "Job status should be 
checked multiple times"
+
 
 class TestBigQueryGetDataTrigger:
     def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
@@ -447,6 +484,7 @@ class TestBigQueryCheckTrigger:
             "table_id": TEST_TABLE_ID,
             "location": None,
             "poll_interval": POLLING_PERIOD_SECONDS,
+            "cancel_on_kill": True,
         }
 
     @pytest.mark.asyncio

Reply via email to