sunank200 commented on code in PR #49711:
URL: https://github.com/apache/airflow/pull/49711#discussion_r2058734598
##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########
@@ -116,16 +117,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
)
return task_instance
+ def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = RuntimeTaskInstance.get_task_states(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ )
+ try:
+ task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ )
+ return task_state
+
def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
Review Comment:
For more context, please go through
https://github.com/apache/airflow/issues/36090#issuecomment-2164294237
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -243,16 +267,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
)
return task_instance
+ def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = RuntimeTaskInstance.get_task_states(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ )
+ try:
+ task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ )
+ return task_state
+
def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed
by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called
because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
- # Database query is needed to get the latest state of the task instance.
- task_instance = self.get_task_instance() # type: ignore[call-arg]
- return task_instance.state != TaskInstanceState.DEFERRED
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = self.get_task_state()
+ else:
+ # Database query is needed to get the latest state of the task instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
Review Comment:
For more context, please go through
https://github.com/apache/airflow/issues/36090#issuecomment-2164294237
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]