This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e7aa4d2289 Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for BigQueryInsertJobOperator (#39442)
e7aa4d2289 is described below
commit e7aa4d2289cd4207f11b697729466717889fda38
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Wed May 8 14:40:29 2024 +0545
    Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for BigQueryInsertJobOperator (#39442)
---
.../providers/google/cloud/triggers/bigquery.py | 60 ++++++++++++++++++++--
.../google/cloud/triggers/test_bigquery.py | 40 ++++++++++++++-
2 files changed, 95 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index e2e0e82f6b..fc19db9881 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -17,13 +17,20 @@
from __future__ import annotations
import asyncio
-from typing import Any, AsyncIterator, Sequence, SupportsAbs
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs
from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
+from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
class BigQueryInsertJobTrigger(BaseTrigger):
@@ -89,6 +96,36 @@ class BigQueryInsertJobTrigger(BaseTrigger):
},
)
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
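+        # The task_instance carried by the trigger is a snapshot taken when the
+        # task deferred; query the metadata database for the live row so its
+        # current state can be read.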
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case where `asyncio.CancelledError` is raised because the trigger
+        itself is being stopped; in that case, we should NOT cancel the external job.
+        """
+        # Database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
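+        # If the TaskInstance is still DEFERRED it is waiting on this very
+        # trigger, so a CancelledError means the triggerer is stopping rather
+        # than the user cancelling the task; cancelling the job is then unsafe.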
+        return task_instance.state != TaskInstanceState.DEFERRED
+
    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
"""Get current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
@@ -117,13 +154,27 @@ class BigQueryInsertJobTrigger(BaseTrigger):
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
- self.log.info("Task was killed.")
- if self.job_id and self.cancel_on_kill:
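+            # Distinguish a user-initiated task cancellation from the triggerer
+            # itself shutting down: only the former should cancel the external job.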
+            if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                self.log.info(
+                    "The job is safe to cancel as the Airflow TaskInstance is not in a deferred state."
+                )
+                self.log.info(
+                    "Cancelling job. Project ID: %s, Location: %s, Job ID: %s",
+                    self.project_id,
+                    self.location,
+                    self.job_id,
+                )
                 await hook.cancel_job(  # type: ignore[union-attr]
                     job_id=self.job_id, project_id=self.project_id, location=self.location
                 )
else:
- self.log.info("Skipping to cancel job: %s:%s.%s",
self.project_id, self.location, self.job_id)
+ self.log.info(
+ "Trigger may have shutdown. Skipping to cancel job because
the airflow "
+ "task is not cancelled yet: Project ID: %s, Location:%s,
Job ID:%s",
+ self.project_id,
+ self.location,
+ self.job_id,
+ )
except Exception as e:
            self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
@@ -148,6 +199,7 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
+ "cancel_on_kill": self.cancel_on_kill,
},
)
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index 436872903e..bbb1a50356 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -239,13 +239,15 @@ class TestBigQueryInsertJobTrigger:
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
async def test_bigquery_insert_job_trigger_cancellation(
- self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+        self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
):
"""
Test that BigQueryInsertJobTrigger handles cancellation correctly,
logs the appropriate message,
and conditionally cancels the job based on the `cancel_on_kill`
attribute.
"""
+        mock_safe_to_cancel.return_value = True
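+        # safe_to_cancel is patched to return True: the TaskInstance is treated as
+        # no longer deferred, so the trigger is expected to cancel the BigQuery job.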
insert_job_trigger.cancel_on_kill = True
insert_job_trigger.job_id = "1234"
@@ -271,6 +273,41 @@ class TestBigQueryInsertJobTrigger:
), "Expected messages about task status or cancellation not found in
log."
mock_cancel_job.assert_awaited_once()
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
+    async def test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation(
+        self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+    ):
+ """
+ Test that BigQueryInsertJobTrigger logs the appropriate message and
does not cancel the job
+ if safe_to_cancel returns False even when the task is cancelled.
+ """
+        mock_safe_to_cancel.return_value = False
+        insert_job_trigger.cancel_on_kill = True
+        insert_job_trigger.job_id = "1234"
+
+        # Simulate the initial job status as running
+        mock_get_job_status.side_effect = [
+            {"status": "running", "message": "Job is still running"},
+            asyncio.CancelledError(),
+            {"status": "running", "message": "Job is still running after cancellation"},
+        ]
+
+        caplog.set_level(logging.INFO)
+
+        try:
+            async for _ in insert_job_trigger.run():
+                pass
+        except asyncio.CancelledError:
+            pass
+
+        assert (
+            "Skipping to cancel job" in caplog.text
+        ), "Expected message about skipping cancellation not found in log."
+        assert mock_get_job_status.call_count == 2, "Job status should be checked multiple times"
+
class TestBigQueryGetDataTrigger:
def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
@@ -447,6 +484,7 @@ class TestBigQueryCheckTrigger:
"table_id": TEST_TABLE_ID,
"location": None,
"poll_interval": POLLING_PERIOD_SECONDS,
+ "cancel_on_kill": True,
}
@pytest.mark.asyncio