kaxil commented on code in PR #48819:
URL: https://github.com/apache/airflow/pull/48819#discussion_r2033663680


##########
providers/standard/src/airflow/providers/standard/triggers/external_task.py:
##########
@@ -89,38 +89,69 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             "allowed_states": self.allowed_states,
             "poke_interval": self.poke_interval,
             "soft_fail": self.soft_fail,
+            "execution_dates": self.execution_dates,
         }
         if AIRFLOW_V_3_0_PLUS:
             data["run_ids"] = self.run_ids
-        else:
-            data["execution_dates"] = self.execution_dates
 
         return 
"airflow.providers.standard.triggers.external_task.WorkflowTrigger", data
 
     async def run(self) -> typing.AsyncIterator[TriggerEvent]:
         """Check periodically tasks, task group or dag status."""
+        # Pick the counting helper that matches the running Airflow major version.
+        get_count_func = self._get_count_af_3 if AIRFLOW_V_3_0_PLUS else self._get_count
+        # Runs being waited on; falls back to [] so len() below is always valid.
+        run_id_or_dates = self.run_ids or self.execution_dates or []
+
         while True:
             if self.failed_states:
-                failed_count = await self._get_count(self.failed_states)
+                failed_count = await get_count_func(self.failed_states)
                 if failed_count > 0:
                     yield TriggerEvent({"status": "failed"})
                     return
-                else:
-                    yield TriggerEvent({"status": "success"})
-                    return
+                # No failures yet: fall through to the skipped/allowed checks instead of
+                # reporting success before the external work has actually completed.
             if self.skipped_states:
-                skipped_count = await self._get_count(self.skipped_states)
+                skipped_count = await get_count_func(self.skipped_states)
                 if skipped_count > 0:
                     yield TriggerEvent({"status": "skipped"})
                     return
-            allowed_count = await self._get_count(self.allowed_states)
-            _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates
-            if allowed_count == len(_dates):  # type: ignore[arg-type]
+            allowed_count = await get_count_func(self.allowed_states)
+
+            # Success once every requested run has reached an allowed state.
+            # run_id_or_dates is always a list here, so no type: ignore is needed.
+            if allowed_count == len(run_id_or_dates):
                 yield TriggerEvent({"status": "success"})
                 return
             self.log.info("Sleeping for %s seconds", self.poke_interval)
             await asyncio.sleep(self.poke_interval)
 
+    async def _get_count_af_3(self, states):
+        """
+        Count external task instances or DAG runs in ``states`` via the Airflow 3 task SDK.
+
+        The result is normalized so that ``run()`` can compare it against the number
+        of run_ids / execution_dates being waited on.
+
+        :param states: States to count.
+        :return: 0 when there are no run_ids or execution_dates to check; otherwise the
+            (possibly fractional) normalized count.
+        """
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        run_id_or_dates = self.run_ids or self.execution_dates or []
+        if not run_id_or_dates:
+            # Nothing to wait on. Returning early also prevents a ZeroDivisionError
+            # in the task-group normalization below.
+            return 0
+
+        if self.external_task_ids or self.external_task_group_id:
+            # sync_to_async keeps the synchronous RuntimeTaskInstance call off the event loop.
+            count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
+                dag_id=self.external_dag_id,
+                task_ids=self.external_task_ids,
+                task_group_id=self.external_task_group_id,
+                logical_dates=self.execution_dates,
+                run_ids=self.run_ids,
+                states=states,
+            )
+        else:
+            count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+                dag_id=self.external_dag_id,
+                logical_dates=self.execution_dates,
+                run_ids=self.run_ids,
+                states=states,
+            )
+
+        if self.external_task_ids:
+            # Presumably one TI per task per run (mapped tasks would break this);
+            # dividing by the task count yields a per-run figure.
+            return count / len(self.external_task_ids)
+        elif self.external_task_group_id:
+            # NOTE(review): this divides by the number of runs, while run() compares the
+            # result with len(run_id_or_dates). If get_ti_count returns a raw TI count,
+            # success only fires when the group's task count equals the number of runs;
+            # the divisor should presumably be the number of tasks in the group.
+            # TODO confirm what get_ti_count returns when task_group_id is given.
+            return count / len(run_id_or_dates)
+        else:
+            return count

Review Comment:
   Does this work?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to