gopidesupavan commented on code in PR #48819:
URL: https://github.com/apache/airflow/pull/48819#discussion_r2035734958


##########
providers/standard/src/airflow/providers/standard/triggers/external_task.py:
##########
@@ -89,38 +89,69 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             "allowed_states": self.allowed_states,
             "poke_interval": self.poke_interval,
             "soft_fail": self.soft_fail,
+            "execution_dates": self.execution_dates,
         }
         if AIRFLOW_V_3_0_PLUS:
             data["run_ids"] = self.run_ids
-        else:
-            data["execution_dates"] = self.execution_dates
 
         return 
"airflow.providers.standard.triggers.external_task.WorkflowTrigger", data
 
     async def run(self) -> typing.AsyncIterator[TriggerEvent]:
         """Check periodically tasks, task group or dag status."""
+        get_count_func = self._get_count_af_3 if AIRFLOW_V_3_0_PLUS else 
self._get_count
+        run_id_or_dates = self.run_ids or self.execution_dates or []
+
         while True:
             if self.failed_states:
-                failed_count = await self._get_count(self.failed_states)
+                failed_count = await get_count_func(self.failed_states)
                 if failed_count > 0:
                     yield TriggerEvent({"status": "failed"})
                     return
                 else:
                     yield TriggerEvent({"status": "success"})
                     return
             if self.skipped_states:
-                skipped_count = await self._get_count(self.skipped_states)
+                skipped_count = await get_count_func(self.skipped_states)
                 if skipped_count > 0:
                     yield TriggerEvent({"status": "skipped"})
                     return
-            allowed_count = await self._get_count(self.allowed_states)
-            _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else 
self.execution_dates
-            if allowed_count == len(_dates):  # type: ignore[arg-type]
+            allowed_count = await get_count_func(self.allowed_states)
+
+            if allowed_count == len(run_id_or_dates):  # type: ignore[arg-type]
                 yield TriggerEvent({"status": "success"})
                 return
             self.log.info("Sleeping for %s seconds", self.poke_interval)
             await asyncio.sleep(self.poke_interval)
 
+    async def _get_count_af_3(self, states):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        run_id_or_dates = self.run_ids or self.execution_dates or []
+
+        if self.external_task_ids or self.external_task_group_id:
+            count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
+                dag_id=self.external_dag_id,
+                task_ids=self.external_task_ids,
+                task_group_id=self.external_task_group_id,
+                logical_dates=self.execution_dates,
+                run_ids=self.run_ids,
+                states=states,
+            )
+        else:
+            count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+                dag_id=self.external_dag_id,
+                logical_dates=self.execution_dates,
+                run_ids=self.run_ids,
+                states=states,
+            )
+
+        if self.external_task_ids:
+            return count / len(self.external_task_ids)
+        elif self.external_task_group_id:
+            return count / len(run_id_or_dates)
+        else:
+            return count

Review Comment:
   Yes, I agree with you. Thinking about this once more, the task_group condition
will never be met in the current logic. I think we would have to return the count
of succeeded task groups instead.
    
   For example, if a task group contains 3 tasks and the user provides 2
execution dates, `get_ti_count` returns 6, assuming all tasks succeed. However,
we can't validate that total directly against the number of dates provided (2).
Even if we divide the total count (6) by the number of dates (2), we get 3. That
is the number of tasks per group, not the number of dates, so it still can't be
reliably compared against the dates alone.
   
   So for the task_group_id case, what would the correct validation be? I'm not
sure anymore; I'm starting to have doubts 🤔
   
   
   
   
   
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to