This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 626814aec802586a366403e7d8e95c07eaaf94bc
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Jul 2 08:06:27 2022 +0800

    Refactor DR.task_instance_scheduling_decisions (#24774)
    
    (cherry picked from commit 5d5d62e41e93fe9845c96ab894047422761023d8)
---
 airflow/models/dagrun.py | 63 +++++++++++++++++++++++++++---------------------
 1 file changed, 35 insertions(+), 28 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 968f360de7..0e7d4e1374 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -638,14 +638,23 @@ class DagRun(Base, LoggingMixin):
 
     @provide_session
     def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
+        tis = self.get_task_instances(session=session, state=State.task_states)
+        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
 
-        schedulable_tis: List[TI] = []
-        changed_tis = False
+        def _filter_tis_and_exclude_removed(dag: "DAG", tis: List[TI]) -> Iterable[TI]:
+            """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
+            for ti in tis:
+                try:
+                    ti.task = dag.get_task(ti.task_id)
+                except TaskNotFound:
+                    self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
+                    ti.state = State.REMOVED
+                    session.flush()
+                else:
+                    yield ti
 
-        tis = list(self.get_task_instances(session=session, state=State.task_states))
-        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
-        dag = self.get_dag()
-        missing_indexes = self._find_missing_task_indexes(dag, tis, session=session)
+        tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
+        missing_indexes = self._find_missing_task_indexes(tis, session=session)
         if missing_indexes:
             self.verify_integrity(missing_indexes=missing_indexes, session=session)
 
@@ -666,6 +675,9 @@ class DagRun(Base, LoggingMixin):
                 new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
                 finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
                 unfinished_tis = new_unfinished_tis
+        else:
+            schedulable_tis = []
+            changed_tis = False
 
         return TISchedulingDecision(
             tis=tis,
@@ -1068,38 +1080,33 @@ class DagRun(Base, LoggingMixin):
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
-    def _find_missing_task_indexes(self, dag, tis, *, session) -> Dict["MappedOperator", Sequence[int]]:
-        """
-        Here we check if the length of the mapped task instances changed
-        at runtime. If so, we find the missing indexes.
-
-        This function also marks task instances with missing tasks as REMOVED.
+    def _find_missing_task_indexes(
+        self,
+        tis: Iterable[TI],
+        *,
+        session: Session,
+    ) -> Dict["MappedOperator", Sequence[int]]:
+        """Check if the length of the mapped task instances changed at runtime and find the missing indexes.
 
-        :param dag: DAG object corresponding to the dagrun
-        :param tis: task instances to check
-        :param session: the session to use
+        :param tis: Task instances to check
+        :param session: The session to use
         """
-        existing_indexes: Dict["MappedOperator", list] = defaultdict(list)
-        new_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
-        for ti in tis:
-            try:
-                task = ti.task = dag.get_task(ti.task_id)
-            except TaskNotFound:
-                self.log.error("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id)
+        from airflow.models.mappedoperator import MappedOperator
 
-                ti.state = State.REMOVED
-                session.flush()
-                continue
-            if not task.is_mapped:
+        existing_indexes: Dict[MappedOperator, List[int]] = defaultdict(list)
+        new_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
+        for ti in tis:
+            task = ti.task
+            if not isinstance(task, MappedOperator):
                 continue
             # skip unexpanded tasks and also tasks that expands with literal arguments
             if ti.map_index < 0 or task.parse_time_mapped_ti_count:
                 continue
             existing_indexes[task].append(ti.map_index)
-            task.run_time_mapped_ti_count.cache_clear()
+            task.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
             new_length = task.run_time_mapped_ti_count(self.run_id, 
session=session) or 0
             new_indexes[task] = range(new_length)
-        missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
         for k, v in existing_indexes.items():
             missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
         return missing_indexes

Reply via email to