pankajkoti commented on code in PR #49711:
URL: https://github.com/apache/airflow/pull/49711#discussion_r2104404613
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -243,16 +267,39 @@ def get_task_instance(self, session: Session) ->
TaskInstance:
)
return task_instance
+ def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = RuntimeTaskInstance.get_task_states(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s is not
found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ )
+ return task_state
+
def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task
instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task
instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
Review Comment:
Yes, the suggestion appears to be outside the scope of this PR. The focus
here is to ensure compatibility with Airflow 3, aligning it with the existing
behaviour in Airflow 2.
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -243,16 +270,41 @@ def get_task_instance(self, session: Session) ->
TaskInstance:
)
return task_instance
- def safe_to_cancel(self) -> bool:
+ async def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task
instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task
instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
Review Comment:
We should check whether we can make `get_task_instance()` async as well, since
we intend to make `safe_to_cancel` async permanently. That work can be scoped
for a later PR.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]