This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch show-mapped-task-in-tree-view
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0a315d88c8e35a2e6f832e6fb938086cbee7a025
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue Feb 15 23:18:50 2022 +0000

    Expand mapped tasks in the Scheduler
    
    Technically this is done inside
    DagRun.task_instance_scheduling_decisions, but the only place that
    method is currently called is the Scheduler.
    
    The way we are getting `upstream_ti` to pass to expand_mapped_task is
    all sorts of wrong and will need fixing. I think the interface for
    that method is wrong, and the mapped task should be responsible for
    finding the right upstream TI itself.
---
 airflow/models/dagrun.py | 48 ++++++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 5170ad3..69b003f 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -43,7 +43,6 @@ from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import false, select, true
 
 from airflow import settings
-from airflow.callbacks.callback_requests import DagCallbackRequest
 from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import AirflowException, TaskNotFound
 from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
@@ -53,7 +52,7 @@ from airflow.models.tasklog import LogTemplate
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
-from airflow.utils import timezone
+from airflow.utils import callback_requests, timezone
 from airflow.utils.helpers import is_container
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
@@ -488,7 +487,7 @@ class DagRun(Base, LoggingMixin):
     @provide_session
     def update_state(
         self, session: Session = NEW_SESSION, execute_callbacks: bool = True
-    ) -> Tuple[List[TI], Optional[DagCallbackRequest]]:
+    ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
         """
         Determines the overall state of the DagRun based on the state
         of its TaskInstances.
@@ -500,7 +499,7 @@ class DagRun(Base, LoggingMixin):
             needs to be executed
         """
         # Callback to execute in case of Task Failures
-        callback: Optional[DagCallbackRequest] = None
+        callback: Optional[callback_requests.DagCallbackRequest] = None
 
         start_dttm = timezone.utcnow()
         self.last_scheduling_decision = start_dttm
@@ -536,7 +535,7 @@ class DagRun(Base, LoggingMixin):
             if execute_callbacks:
                 dag.handle_callback(self, success=False, reason='task_failure', session=session)
             elif dag.has_on_failure_callback:
-                callback = DagCallbackRequest(
+                callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
                     run_id=self.run_id,
@@ -551,7 +550,7 @@ class DagRun(Base, LoggingMixin):
             if execute_callbacks:
                 dag.handle_callback(self, success=True, reason='success', session=session)
             elif dag.has_on_success_callback:
-                callback = DagCallbackRequest(
+                callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
                     run_id=self.run_id,
@@ -572,7 +571,7 @@ class DagRun(Base, LoggingMixin):
             if execute_callbacks:
                 dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
             elif dag.has_on_failure_callback:
-                callback = DagCallbackRequest(
+                callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
                     run_id=self.run_id,
@@ -652,7 +651,7 @@ class DagRun(Base, LoggingMixin):
 
     def _get_ready_tis(
         self,
-        schedulable_tis: List[TI],
+        scheduleable_tis: List[TI],
         finished_tis: List[TI],
         session: Session,
     ) -> Tuple[List[TI], bool]:
@@ -660,40 +659,41 @@ class DagRun(Base, LoggingMixin):
         ready_tis: List[TI] = []
         changed_tis = False
 
-        if not schedulable_tis:
+        if not scheduleable_tis:
             return ready_tis, changed_tis
 
         # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
-        # `schedulable_tis` in place and have the `for` loop pick them up
+        # `scheduleable_tis` in place and have the `for` loop pick them up
         expanded_tis: List[TI] = []
 
         # Check dependencies
-        for schedulable in itertools.chain(schedulable_tis, expanded_tis):
+        for st in itertools.chain(scheduleable_tis, expanded_tis):
 
             # Expansion of last resort! This is ideally handled in the mini-scheduler in LocalTaskJob, but if
             # for any reason it wasn't, we need to expand it now
-            if schedulable.map_index < 0 and schedulable.task.is_mapped:
+            if st.map_index < 0 and st.task.is_mapped:
+                # HACK. This needs a better way, one that copes with multiple upstreams!
                 for ti in finished_tis:
-                    if schedulable.task_id in ti.task.downstream_task_ids:
-
-                        assert isinstance(schedulable.task, MappedOperator)
-                        new_tis = schedulable.task.expand_mapped_task(self.run_id, session=session)
-                        if schedulable.state == TaskInstanceState.SKIPPED:
-                            # Task is now skipped (likely cos upstream returned 0 tasks
-                            continue
-                        assert new_tis[0] is schedulable
+                    if st.task_id in ti.task.downstream_task_ids:
+                        upstream = ti
+
+                        assert isinstance(st.task, MappedOperator)
+                        new_tis = st.task.expand_mapped_task(upstream, session=session)
+                        assert new_tis[0] is st
+                        # Add the new TIs to the list to be checked
+                        for new_ti in new_tis[1:]:
+                            new_ti.task = st.task
                         expanded_tis.extend(new_tis[1:])
                         break
 
-            old_state = schedulable.state
-            if schedulable.are_dependencies_met(
+            old_state = st.state
+            if st.are_dependencies_met(
                 dep_context=DepContext(flag_upstream_failed=True, finished_tis=finished_tis),
                 session=session,
             ):
-                ready_tis.append(schedulable)
+                ready_tis.append(st)
             else:
-                old_states[schedulable.key] = old_state
+                old_states[st.key] = old_state
 
         # Check if any ti changed state
         tis_filter = TI.filter_for_tis(old_states.keys())

Reply via email to