This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7bea27ca90d Remove direct access to DB for safe_to_cancel() method for Dataproc and BigQuery triggers (#49711)
7bea27ca90d is described below
commit 7bea27ca90de908621ee5f0b9095f4ab4621e526
Author: Maksim <[email protected]>
AuthorDate: Thu May 29 05:06:56 2025 -0700
Remove direct access to DB for safe_to_cancel() method for Dataproc and BigQuery triggers (#49711)
---
.../providers/google/cloud/triggers/bigquery.py | 37 +++++++++--
.../providers/google/cloud/triggers/dataproc.py | 72 +++++++++++++++++++---
2 files changed, 94 insertions(+), 15 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index 13d9f55bacb..2cac99543a6 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -22,10 +22,12 @@ from typing import TYPE_CHECKING, Any, SupportsAbs
from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
+from asgiref.sync import sync_to_async
from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook,
BigQueryTableAsyncHook
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState
@@ -116,16 +118,41 @@ class BigQueryInsertJobTrigger(BaseTrigger):
)
return task_instance
- def safe_to_cancel(self) -> bool:
+ async def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task
instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task
instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
async def run(self) -> AsyncIterator[TriggerEvent]: # type:
ignore[override]
"""Get current job execution status and yields a TriggerEvent."""
@@ -155,7 +182,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
- if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+ if self.job_id and self.cancel_on_kill and await
self.safe_to_cancel():
self.log.info(
"The job is safe to cancel the as airflow TaskInstance is
not in deferred state."
)
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 80afd381cc3..2f44d370432 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,6 +25,7 @@ import time
from collections.abc import AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any
+from asgiref.sync import sync_to_async
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
@@ -33,6 +34,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook,
DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState
@@ -141,16 +143,41 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
)
return task_instance
- def safe_to_cancel(self) -> bool:
+ async def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task
instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task
instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
async def run(self):
try:
@@ -167,7 +194,7 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
except asyncio.CancelledError:
self.log.info("Task got cancelled.")
try:
- if self.job_id and self.cancel_on_kill and
self.safe_to_cancel():
+ if self.job_id and self.cancel_on_kill and await
self.safe_to_cancel():
self.log.info(
"Cancelling the job as it is safe to do so. Note that
the airflow TaskInstance is not"
" in deferred state."
@@ -243,16 +270,41 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
)
return task_instance
- def safe_to_cancel(self) -> bool:
+ async def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task
instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task
instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
async def run(self) -> AsyncIterator[TriggerEvent]:
try:
@@ -283,7 +335,7 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
try:
- if self.delete_on_error and self.safe_to_cancel():
+ if self.delete_on_error and await self.safe_to_cancel():
self.log.info(
"Deleting the cluster as it is safe to delete as the
airflow TaskInstance is not in "
"deferred state."