This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-3-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 626814aec802586a366403e7d8e95c07eaaf94bc Author: Tzu-ping Chung <[email protected]> AuthorDate: Sat Jul 2 08:06:27 2022 +0800 Refactor DR.task_instance_scheduling_decisions (#24774) (cherry picked from commit 5d5d62e41e93fe9845c96ab894047422761023d8) --- airflow/models/dagrun.py | 63 +++++++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 968f360de7..0e7d4e1374 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -638,14 +638,23 @@ class DagRun(Base, LoggingMixin): @provide_session def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision: + tis = self.get_task_instances(session=session, state=State.task_states) + self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) - schedulable_tis: List[TI] = [] - changed_tis = False + def _filter_tis_and_exclude_removed(dag: "DAG", tis: List[TI]) -> Iterable[TI]: + """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED.""" + for ti in tis: + try: + ti.task = dag.get_task(ti.task_id) + except TaskNotFound: + self.log.error("Failed to get task for ti %s. Marking it as removed.", ti) + ti.state = State.REMOVED + session.flush() + else: + yield ti - tis = list(self.get_task_instances(session=session, state=State.task_states)) - self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) - dag = self.get_dag() - missing_indexes = self._find_missing_task_indexes(dag, tis, session=session) + tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis)) + missing_indexes = self._find_missing_task_indexes(tis, session=session) if missing_indexes: self.verify_integrity(missing_indexes=missing_indexes, session=session) @@ -666,6 +675,9 @@ class DagRun(Base, LoggingMixin): new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished] finished_tis.extend(t for t in unfinished_tis if t.state in State.finished) unfinished_tis = new_unfinished_tis + else: + schedulable_tis = [] + changed_tis = False return TISchedulingDecision( tis=tis, @@ -1068,38 +1080,33 @@ class DagRun(Base, LoggingMixin): # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() - def _find_missing_task_indexes(self, dag, tis, *, session) -> Dict["MappedOperator", Sequence[int]]: - """ - Here we check if the length of the mapped task instances changed - at runtime. If so, we find the missing indexes. - - This function also marks task instances with missing tasks as REMOVED. + def _find_missing_task_indexes( + self, + tis: Iterable[TI], + *, + session: Session, + ) -> Dict["MappedOperator", Sequence[int]]: + """Check if the length of the mapped task instances changed at runtime and find the missing indexes. - :param dag: DAG object corresponding to the dagrun - :param tis: task instances to check - :param session: the session to use + :param tis: Task instances to check + :param session: The session to use """ - existing_indexes: Dict["MappedOperator", list] = defaultdict(list) - new_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list) - for ti in tis: - try: - task = ti.task = dag.get_task(ti.task_id) - except TaskNotFound: - self.log.error("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id) + from airflow.models.mappedoperator import MappedOperator - ti.state = State.REMOVED - session.flush() - continue - if not task.is_mapped: + existing_indexes: Dict[MappedOperator, List[int]] = defaultdict(list) + new_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list) + for ti in tis: + task = ti.task + if not isinstance(task, MappedOperator): continue # skip unexpanded tasks and also tasks that expands with literal arguments if ti.map_index < 0 or task.parse_time_mapped_ti_count: continue existing_indexes[task].append(ti.map_index) - task.run_time_mapped_ti_count.cache_clear() + task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined] new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0 new_indexes[task] = range(new_length) - missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list) + missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list) for k, v in existing_indexes.items(): missing_indexes.update({k: list(set(new_indexes[k]).difference(v))}) return missing_indexes
