gopidesupavan commented on code in PR #48819:
URL: https://github.com/apache/airflow/pull/48819#discussion_r2033788583
##########
providers/standard/src/airflow/providers/standard/triggers/external_task.py:
##########
@@ -89,38 +89,69 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"allowed_states": self.allowed_states,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
+ "execution_dates": self.execution_dates,
}
if AIRFLOW_V_3_0_PLUS:
data["run_ids"] = self.run_ids
- else:
- data["execution_dates"] = self.execution_dates
return
"airflow.providers.standard.triggers.external_task.WorkflowTrigger", data
async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Check periodically tasks, task group or dag status."""
+ get_count_func = self._get_count_af_3 if AIRFLOW_V_3_0_PLUS else
self._get_count
+ run_id_or_dates = self.run_ids or self.execution_dates or []
+
while True:
if self.failed_states:
- failed_count = await self._get_count(self.failed_states)
+ failed_count = await get_count_func(self.failed_states)
if failed_count > 0:
yield TriggerEvent({"status": "failed"})
return
else:
yield TriggerEvent({"status": "success"})
return
if self.skipped_states:
- skipped_count = await self._get_count(self.skipped_states)
+ skipped_count = await get_count_func(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
- allowed_count = await self._get_count(self.allowed_states)
- _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else
self.execution_dates
- if allowed_count == len(_dates): # type: ignore[arg-type]
+ allowed_count = await get_count_func(self.allowed_states)
+
+ if allowed_count == len(run_id_or_dates): # type: ignore[arg-type]
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)
+ async def _get_count_af_3(self, states):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ run_id_or_dates = self.run_ids or self.execution_dates or []
+
+ if self.external_task_ids or self.external_task_group_id:
+ count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
+ dag_id=self.external_dag_id,
+ task_ids=self.external_task_ids,
+ task_group_id=self.external_task_group_id,
+ logical_dates=self.execution_dates,
+ run_ids=self.run_ids,
+ states=states,
+ )
+ else:
+ count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+ dag_id=self.external_dag_id,
+ logical_dates=self.execution_dates,
+ run_ids=self.run_ids,
+ states=states,
+ )
+
+ if self.external_task_ids:
+ return count / len(self.external_task_ids)
+ elif self.external_task_group_id:
+ return count / len(run_id_or_dates)
+ else:
+ return count
Review Comment:
Let me double-check.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]