This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bea27ca90d Remove direct access to DB for safe_to_cancel() method for Dataproc and BigQuery triggers (#49711)
7bea27ca90d is described below

commit 7bea27ca90de908621ee5f0b9095f4ab4621e526
Author: Maksim <[email protected]>
AuthorDate: Thu May 29 05:06:56 2025 -0700

    Remove direct access to DB for safe_to_cancel() method for Dataproc and BigQuery triggers (#49711)
---
 .../providers/google/cloud/triggers/bigquery.py    | 37 +++++++++--
 .../providers/google/cloud/triggers/dataproc.py    | 72 +++++++++++++++++++---
 2 files changed, 94 insertions(+), 15 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index 13d9f55bacb..2cac99543a6 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -22,10 +22,12 @@ from typing import TYPE_CHECKING, Any, SupportsAbs
 
 from aiohttp import ClientSession
 from aiohttp.client_exceptions import ClientResponseError
+from asgiref.sync import sync_to_async
 
 from airflow.exceptions import AirflowException
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
@@ -116,16 +118,41 @@ class BigQueryInsertJobTrigger(BaseTrigger):
             )
         return task_instance
 
-    def safe_to_cancel(self) -> bool:
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise AirflowException(
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed 
by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called 
because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task 
instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task 
instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: 
ignore[override]
         """Get current job execution status and yields a TriggerEvent."""
@@ -155,7 +182,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                     )
                     await asyncio.sleep(self.poll_interval)
         except asyncio.CancelledError:
-            if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+            if self.job_id and self.cancel_on_kill and await 
self.safe_to_cancel():
                 self.log.info(
                     "The job is safe to cancel the as airflow TaskInstance is 
not in deferred state."
                 )
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 80afd381cc3..2f44d370432 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,6 +25,7 @@ import time
 from collections.abc import AsyncIterator, Sequence
 from typing import TYPE_CHECKING, Any
 
+from asgiref.sync import sync_to_async
 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
@@ -33,6 +34,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
@@ -141,16 +143,41 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
             )
         return task_instance
 
-    def safe_to_cancel(self) -> bool:
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise AirflowException(
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed 
by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called 
because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task 
instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task 
instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self):
         try:
@@ -167,7 +194,7 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill and 
self.safe_to_cancel():
+                if self.job_id and self.cancel_on_kill and await 
self.safe_to_cancel():
                     self.log.info(
                         "Cancelling the job as it is safe to do so. Note that 
the airflow TaskInstance is not"
                         " in deferred state."
@@ -243,16 +270,41 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
             )
         return task_instance
 
-    def safe_to_cancel(self) -> bool:
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise AirflowException(
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed 
by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called 
because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task 
instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task 
instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
@@ -283,7 +335,7 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                 await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error and self.safe_to_cancel():
+                if self.delete_on_error and await self.safe_to_cancel():
                     self.log.info(
                         "Deleting the cluster as it is safe to delete as the 
airflow TaskInstance is not in "
                         "deferred state."

Reply via email to