gopidesupavan commented on code in PR #48819:
URL: https://github.com/apache/airflow/pull/48819#discussion_r2035734958
##########
providers/standard/src/airflow/providers/standard/triggers/external_task.py:
##########
@@ -89,38 +89,69 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"allowed_states": self.allowed_states,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
+ "execution_dates": self.execution_dates,
}
if AIRFLOW_V_3_0_PLUS:
data["run_ids"] = self.run_ids
- else:
- data["execution_dates"] = self.execution_dates
return
"airflow.providers.standard.triggers.external_task.WorkflowTrigger", data
async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Check periodically tasks, task group or dag status."""
+ get_count_func = self._get_count_af_3 if AIRFLOW_V_3_0_PLUS else
self._get_count
+ run_id_or_dates = self.run_ids or self.execution_dates or []
+
while True:
if self.failed_states:
- failed_count = await self._get_count(self.failed_states)
+ failed_count = await get_count_func(self.failed_states)
if failed_count > 0:
yield TriggerEvent({"status": "failed"})
return
else:
yield TriggerEvent({"status": "success"})
return
if self.skipped_states:
- skipped_count = await self._get_count(self.skipped_states)
+ skipped_count = await get_count_func(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
- allowed_count = await self._get_count(self.allowed_states)
- _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else
self.execution_dates
- if allowed_count == len(_dates): # type: ignore[arg-type]
+ allowed_count = await get_count_func(self.allowed_states)
+
+ if allowed_count == len(run_id_or_dates): # type: ignore[arg-type]
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)
+ async def _get_count_af_3(self, states):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ run_id_or_dates = self.run_ids or self.execution_dates or []
+
+ if self.external_task_ids or self.external_task_group_id:
+ count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
+ dag_id=self.external_dag_id,
+ task_ids=self.external_task_ids,
+ task_group_id=self.external_task_group_id,
+ logical_dates=self.execution_dates,
+ run_ids=self.run_ids,
+ states=states,
+ )
+ else:
+ count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+ dag_id=self.external_dag_id,
+ logical_dates=self.execution_dates,
+ run_ids=self.run_ids,
+ states=states,
+ )
+
+ if self.external_task_ids:
+ return count / len(self.external_task_ids)
+ elif self.external_task_group_id:
+ return count / len(run_id_or_dates)
+ else:
+ return count
Review Comment:
Yes, I agree with you. Thinking about this once more, the task-group condition
will never be met in the current logic. I think we would instead have to return
the count of succeeded task groups.
For example, if a task group contains 3 tasks and the user provides 2
execution dates, `get_ti_count` returns 6, assuming all tasks succeed. We
can't validate that number directly against the number of dates provided (2).
Even if we divide the total count (6) by the number of dates (2), we get 3.
`run()` then compares 3 against `len(run_id_or_dates)` (2), so the success
check can never pass. The quotient alone can't reliably be validated against
the dates.
So I'm not sure what the correct validation for the `task_group_id` case
would be; now I'm having doubts 🤔
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]