This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 944e1311416 Migrate BigQueryInsertJobTrigger to on_kill() for user-initiated kills (#66704)
944e1311416 is described below
commit 944e1311416f523a912abaf602c0ef47f7dd5845
Author: Yunhui Chae <[email protected]>
AuthorDate: Tue May 12 23:54:13 2026 +0900
Migrate BigQueryInsertJobTrigger to on_kill() for user-initiated kills (#66704)
---
.../providers/google/cloud/triggers/bigquery.py | 121 ++++++++++++---------
.../unit/google/cloud/triggers/test_bigquery.py | 117 +++++++++++---------
2 files changed, 132 insertions(+), 106 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index 4a6d1a7780c..c9f8acde0a5 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -26,14 +26,14 @@ from asgiref.sync import sync_to_async
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
-if not AIRFLOW_V_3_0_PLUS:
+if not AIRFLOW_V_3_3_PLUS:
from sqlalchemy import select
from airflow.models.taskinstance import TaskInstance
@@ -103,7 +103,20 @@ class BigQueryInsertJobTrigger(BaseTrigger):
},
)
- if not AIRFLOW_V_3_0_PLUS:
+ async def on_kill(self) -> None:
+ """Cancel the BigQuery job when the task is killed by a user action."""
+ if self.job_id and self.cancel_on_kill:
+ self.log.info(
+ "Cancelling BigQuery job. Project ID: %s, Location: %s, Job ID: %s",
+ self.project_id,
+ self.location,
+ self.job_id,
+ )
+ hook = self._get_async_hook()
+ await hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location)
+ self.log.info("BigQuery job %s cancelled successfully.", self.job_id)
+
+ if not AIRFLOW_V_3_3_PLUS:
@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
@@ -125,41 +138,41 @@ class BigQueryInsertJobTrigger(BaseTrigger):
)
return task_instance
- async def get_task_state(self):
- from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+ async def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
- task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
- dag_id=self.task_instance.dag_id,
- task_ids=[self.task_instance.task_id],
- run_ids=[self.task_instance.run_id],
- map_index=self.task_instance.map_index,
- )
- try:
- task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
- except Exception:
- raise AirflowException(
- "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
)
- return task_state
+ try:
+ task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_state
- async def safe_to_cancel(self) -> bool:
- """
- Whether it is safe to cancel the external job which is being executed by this trigger.
+ async def safe_to_cancel(self) -> bool:
+ """
+ Whether it is safe to cancel the external job which is being executed by this trigger.
- This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
- Because in those cases, we should NOT cancel the external job.
- """
- if AIRFLOW_V_3_0_PLUS:
- task_state = await self.get_task_state()
- else:
- # Database query is needed to get the latest state of the task instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- task_state = task_instance.state
- return task_state != TaskInstanceState.DEFERRED
+ This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
+ Because in those cases, we should NOT cancel the external job.
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Get current job execution status and yields a TriggerEvent."""
@@ -196,25 +209,27 @@ class BigQueryInsertJobTrigger(BaseTrigger):
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
- if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
- self.log.info(
- "The job is safe to cancel the as airflow TaskInstance is not in deferred state."
- )
- self.log.info(
- "Cancelling job. Project ID: %s, Location: %s, Job ID: %s",
- self.project_id,
- self.location,
- self.job_id,
- )
- await hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location)
- else:
- self.log.info(
- "Trigger may have shutdown. Skipping to cancel job because the airflow "
- "task is not cancelled yet: Project ID: %s, Location:%s, Job ID:%s",
- self.project_id,
- self.location,
- self.job_id,
- )
+ # Legacy path for Airflow < 3.3.0
+ # On Airflow 3.3.0+, on_kill() handles user-initiated kills
+ if not AIRFLOW_V_3_3_PLUS:
+ if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
+ self.log.info(
+ "Cancelling job (legacy path). Project ID: %s, Location: %s, Job ID: %s",
+ self.project_id,
+ self.location,
+ self.job_id,
+ )
+ await hook.cancel_job(
+ job_id=self.job_id, project_id=self.project_id, location=self.location
+ )
+ else:
+ self.log.info(
+ "Trigger may have shutdown. Skipping to cancel job because the airflow "
+ "task is not cancelled yet: Project ID: %s, Location:%s, Job ID:%s",
+ self.project_id,
+ self.location,
+ self.job_id,
+ )
raise
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
index 720a9a0d806..78448c064a0 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
@@ -40,6 +40,8 @@ from airflow.providers.google.cloud.triggers.bigquery import (
)
from airflow.triggers.base import TriggerEvent
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
TEST_CONN_ID = "bq_default"
TEST_JOB_ID = "1234"
TEST_GCP_PROJECT_ID = "test-project"
@@ -234,72 +236,81 @@ class TestBigQueryInsertJobTrigger:
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
@pytest.mark.asyncio
- @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
- @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+ @pytest.mark.skipif(AIRFLOW_V_3_3_PLUS, reason="on_kill() handles cancellation for Airflow 3.3.0+")
+ @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
+ @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
- async def test_bigquery_insert_job_trigger_cancellation(
- self, mock_get_task_instance, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+ async def test_insert_job_trigger_run_cancelled(
+ self, mock_safe_to_cancel, mock_get_async_hook, insert_job_trigger, is_safe_to_cancel
):
- """
- Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message,
- and conditionally cancels the job based on the `cancel_on_kill` attribute.
- """
- mock_get_task_instance.return_value = True
+ """Test CancelledError handling for Airflow < 3.3.0."""
+ mock_safe_to_cancel.return_value = is_safe_to_cancel
+ mock_hook = mock_get_async_hook.return_value
+ mock_hook.get_job_status = AsyncMock()
+ mock_hook.get_job_status.side_effect = asyncio.CancelledError
+ mock_hook.cancel_job = AsyncMock()
+
+ async_gen = insert_job_trigger.run()
+ try:
+ await async_gen.asend(None)
+ except (asyncio.CancelledError, StopAsyncIteration):
+ pass
+ except Exception as e:
+ pytest.fail(f"Unexpected exception raised: {e}")
+
+ if insert_job_trigger.cancel_on_kill and is_safe_to_cancel:
+ mock_hook.cancel_job.assert_awaited_once_with(
+ job_id=insert_job_trigger.job_id,
+ project_id=insert_job_trigger.project_id,
+ location=insert_job_trigger.location,
+ )
+ else:
+ mock_hook.cancel_job.assert_not_awaited()
+
+ await async_gen.aclose()
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
+ async def test_on_kill_cancels_job(self, mock_get_async_hook, insert_job_trigger):
+ """Test that on_kill cancels the BigQuery job."""
+ mock_hook = mock_get_async_hook.return_value
+ mock_hook.cancel_job = AsyncMock()
+ insert_job_trigger.job_id = TEST_JOB_ID
insert_job_trigger.cancel_on_kill = True
- insert_job_trigger.job_id = "1234"
- mock_get_job_status.side_effect = [
- {"status": "running", "message": "Job is still running"},
- asyncio.CancelledError(),
- ]
+ await insert_job_trigger.on_kill()
- mock_cancel_job.return_value = asyncio.Future()
- mock_cancel_job.return_value.set_result(None)
+ mock_hook.cancel_job.assert_awaited_once_with(
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ location=TEST_LOCATION,
+ )
- caplog.set_level(logging.INFO)
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
+ async def test_on_kill_respects_cancel_on_kill_false(self, mock_get_async_hook, insert_job_trigger):
+ """Test that on_kill does not cancel the job when cancel_on_kill is False."""
+ mock_hook = mock_get_async_hook.return_value
+ mock_hook.cancel_job = AsyncMock()
+ insert_job_trigger.job_id = TEST_JOB_ID
+ insert_job_trigger.cancel_on_kill = False
- with pytest.raises(asyncio.CancelledError):
- async for _ in insert_job_trigger.run():
- pass
+ await insert_job_trigger.on_kill()
- assert (
- "Task was killed" in caplog.text
- or "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text
- ), "Expected messages about task status or cancellation not found in log."
- mock_cancel_job.assert_awaited_once()
+ mock_hook.cancel_job.assert_not_called()
@pytest.mark.asyncio
- @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
- @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
- @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
- async def test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation(
- self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
- ):
- """
- Test that BigQueryInsertJobTrigger logs the appropriate message and does not cancel the job
- if safe_to_cancel returns False even when the task is cancelled.
- """
- mock_safe_to_cancel.return_value = False
+ @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
+ async def test_on_kill_no_job_id_does_not_cancel(self, mock_get_async_hook, insert_job_trigger):
+ """Test that on_kill does not attempt to cancel when job_id is not set."""
+ mock_hook = mock_get_async_hook.return_value
+ mock_hook.cancel_job = AsyncMock()
+ insert_job_trigger.job_id = None
insert_job_trigger.cancel_on_kill = True
- insert_job_trigger.job_id = "1234"
-
- # Simulate the initial job status as running
- mock_get_job_status.side_effect = [
- {"status": "running", "message": "Job is still running"},
- asyncio.CancelledError(),
- {"status": "running", "message": "Job is still running after cancellation"},
- ]
- caplog.set_level(logging.INFO)
+ await insert_job_trigger.on_kill()
- with pytest.raises(asyncio.CancelledError):
- async for _ in insert_job_trigger.run():
- pass
-
- assert "Skipping to cancel job" in caplog.text, (
- "Expected message about skipping cancellation not found in log."
- )
- assert mock_get_job_status.call_count == 2, "Job status should be checked multiple times"
+ mock_hook.cancel_job.assert_not_called()
class TestBigQueryGetDataTrigger: