This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 64ec6c7639323cec6fd8b3525fbf7e6a5537e5bd
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Nov 9 15:05:59 2022 +0100

    Fix mini scheduler expansion of mapped task (#27506)
    
    We have a case where the mini scheduler tries to expand a mapped task
    even when its upstream tasks are not yet done.
    
    The mini scheduler extracts a partial subset of the DAG, and in the
    process some upstream tasks are dropped.
    If one of the scheduled tasks is a mapped task, its expansion fails,
    because expanding it requires the output of its upstream tasks, which
    may be missing from the subset and may not have finished yet. When the
    expansion fails, the task is marked as `upstream_failed`, which in turn
    causes the tasks downstream of it to be marked as `upstream_failed` too.
    
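    The solution is to ignore this error, and not mark the mapped task as
    `upstream_failed`, when the expansion fails and the DAG is a partial
    subset. The mini-scheduler logic is also moved from
    `LocalTaskJob._run_mini_scheduler_on_child_tasks` to
    `TaskInstance.schedule_downstream_tasks`.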
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
    (cherry picked from commit ed92e5d521f958642615b038ec13068b527db1c4)
---
 airflow/jobs/local_task_job.py    | 59 +---------------------------
 airflow/models/mappedoperator.py  | 30 +++++++++-----
 airflow/models/taskinstance.py    | 61 ++++++++++++++++++++++++++++
 tests/jobs/test_local_task_job.py |  1 -
 tests/models/test_taskinstance.py | 83 +++++++++++++++++++++++++++++++++++++++
 5 files changed, 165 insertions(+), 69 deletions(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 1881511f9b..698c469dbb 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -18,25 +18,20 @@
 from __future__ import annotations
 
 import signal
-from typing import TYPE_CHECKING
 
 import psutil
-from sqlalchemy.exc import OperationalError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
 from airflow.listeners.events import register_task_instance_state_events
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
-from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State
 
 
@@ -165,7 +160,7 @@ class LocalTaskJob(BaseJob):
 
         if not self.task_instance.test_mode:
             if conf.getboolean('scheduler', 'schedule_after_task_execution', 
fallback=True):
-                self._run_mini_scheduler_on_child_tasks()
+                self.task_instance.schedule_downstream_tasks()
 
     def on_kill(self):
         self.task_runner.terminate()
@@ -230,58 +225,6 @@ class LocalTaskJob(BaseJob):
                 self.terminating = True
             self._state_change_checks += 1
 
-    @provide_session
-    @Sentry.enrich_errors
-    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
-        try:
-            # Re-select the row with a lock
-            dag_run = with_row_locks(
-                session.query(DagRun).filter_by(
-                    dag_id=self.dag_id,
-                    run_id=self.task_instance.run_id,
-                ),
-                session=session,
-            ).one()
-
-            task = self.task_instance.task
-            if TYPE_CHECKING:
-                assert task.dag
-
-            # Get a partial DAG with just the specific tasks we want to 
examine.
-            # In order for dep checks to work correctly, we include ourself (so
-            # TriggerRuleDep can check the state of the task we just executed).
-            partial_dag = task.dag.partial_subset(
-                task.downstream_task_ids,
-                include_downstream=True,
-                include_upstream=False,
-                include_direct_upstream=True,
-            )
-
-            dag_run.dag = partial_dag
-            info = dag_run.task_instance_scheduling_decisions(session)
-
-            skippable_task_ids = {
-                task_id for task_id in partial_dag.task_ids if task_id not in 
task.downstream_task_ids
-            }
-
-            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id 
not in skippable_task_ids]
-            for schedulable_ti in schedulable_tis:
-                if not hasattr(schedulable_ti, "task"):
-                    schedulable_ti.task = 
task.dag.get_task(schedulable_ti.task_id)
-
-            num = dag_run.schedule_tis(schedulable_tis)
-            self.log.info("%d downstream tasks scheduled from follow-on 
schedule check", num)
-
-            session.commit()
-        except OperationalError as e:
-            # Any kind of DB error here is _non fatal_ as this block is just 
an optimisation.
-            self.log.info(
-                "Skipping mini scheduling run due to exception: %s",
-                e.statement,
-                exc_info=True,
-            )
-            session.rollback()
-
     @staticmethod
     def _enable_task_listeners():
         """
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 62cc22f379..9c591bf364 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -619,13 +619,18 @@ class MappedOperator(AbstractOperator):
         try:
             total_length = 
self._get_specified_expand_input().get_total_map_length(run_id, session=session)
         except NotFullyPopulated as e:
-            self.log.info(
-                "Cannot expand %r for run %s; missing upstream values: %s",
-                self,
-                run_id,
-                sorted(e.missing),
-            )
             total_length = None
+            # partial dags comes from the mini scheduler. It's
+            # possible that the upstream tasks are not yet done,
+            # but we don't have upstream of upstreams in partial dags,
+            # so we ignore this exception.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )
 
         state: TaskInstanceState | None = None
         unmapped_ti: TaskInstance | None = (
@@ -646,10 +651,15 @@ class MappedOperator(AbstractOperator):
             # The unmapped task instance still exists and is unfinished, i.e. 
we
             # haven't tried to run it before.
             if total_length is None:
-                # If the map length cannot be calculated (due to unavailable
-                # upstream sources), fail the unmapped task.
-                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
-                indexes_to_map: Iterable[int] = ()
+                if self.dag and self.dag.partial:
+                    # If the DAG is partial, it's likely that the upstream 
tasks
+                    # are not done yet, so we do nothing
+                    indexes_to_map: Iterable[int] = ()
+                else:
+                    # If the map length cannot be calculated (due to 
unavailable
+                    # upstream sources), fail the unmapped task.
+                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
+                    indexes_to_map = ()
             elif total_length < 1:
                 # If the upstream maps this to a zero-length value, simply mark
                 # the unmapped task instance as SKIPPED (if needed).
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d4453ca842..4388d592e9 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2572,6 +2572,67 @@ class TaskInstance(Base, LoggingMixin):
             return filters[0]
         return or_(*filters)
 
+    @Sentry.enrich_errors
+    @provide_session
+    def schedule_downstream_tasks(self, session=None):
+        """
+        The mini-scheduler for scheduling downstream tasks of this task 
instance
+        :meta: private
+        """
+        from sqlalchemy.exc import OperationalError
+
+        from airflow.models import DagRun
+
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    run_id=self.run_id,
+                ),
+                session=session,
+            ).one()
+
+            task = self.task
+            if TYPE_CHECKING:
+                assert task.dag
+
+            # Get a partial DAG with just the specific tasks we want to 
examine.
+            # In order for dep checks to work correctly, we include ourself (so
+            # TriggerRuleDep can check the state of the task we just executed).
+            partial_dag = task.dag.partial_subset(
+                task.downstream_task_ids,
+                include_downstream=True,
+                include_upstream=False,
+                include_direct_upstream=True,
+            )
+
+            dag_run.dag = partial_dag
+            info = dag_run.task_instance_scheduling_decisions(session)
+
+            skippable_task_ids = {
+                task_id for task_id in partial_dag.task_ids if task_id not in 
task.downstream_task_ids
+            }
+
+            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id 
not in skippable_task_ids]
+            for schedulable_ti in schedulable_tis:
+                if not hasattr(schedulable_ti, "task"):
+                    schedulable_ti.task = 
task.dag.get_task(schedulable_ti.task_id)
+
+            num = dag_run.schedule_tis(schedulable_tis, session=session)
+            self.log.info("%d downstream tasks scheduled from follow-on 
schedule check", num)
+
+            session.flush()
+
+        except OperationalError as e:
+            # Any kind of DB error here is _non fatal_ as this block is just 
an optimisation.
+            self.log.info(
+                "Skipping mini scheduling run due to exception: %s",
+                e.statement,
+                exc_info=True,
+            )
+            session.rollback()
+
 
 # State of the task instance.
 # Stores string version of the task state.
diff --git a/tests/jobs/test_local_task_job.py 
b/tests/jobs/test_local_task_job.py
index 61bce800c8..26f5e629a7 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -739,7 +739,6 @@ class TestLocalTaskJob:
         ti2_l.refresh_from_db()
         assert ti2_k.state == State.SUCCESS
         assert ti2_l.state == State.NONE
-        assert "0 downstream tasks scheduled from follow-on schedule" in 
caplog.text
 
         failed_deps = list(ti2_l.get_failed_dep_statuses())
         assert len(failed_deps) == 1
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index f4edebcb08..7f7cca0da2 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3614,3 +3614,86 @@ def test_expand_non_templated_field(dag_maker, session):
 
     echo_task = dag.get_task("echo")
     assert "get_extra_env" in echo_task.upstream_task_ids
+
+
+def 
test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker,
 caplog, session):
+    """
+    This tests that when scheduling child tasks of a task and there's a mapped 
downstream task,
+    if the mapped downstream task has upstreams that are not yet done, the 
mapped downstream task is
+    not marked as `upstream_failed'
+    """
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dag_run = dag_maker.create_dagrun()
+    first_ti = dag_run.get_task_instance(task_id="first_task")
+    second_ti = dag_run.get_task_instance(task_id="second_task")
+    first_ti.state = State.SUCCESS
+    second_ti.state = State.RUNNING
+    session.merge(first_ti)
+    session.merge(second_ti)
+    session.commit()
+    first_ti.schedule_downstream_tasks(session=session)
+    middle_ti = dag_run.get_task_instance(task_id="middle_task")
+    assert middle_ti.state != State.UPSTREAM_FAILED
+    assert "0 downstream tasks scheduled from follow-on schedule" in 
caplog.text
+
+
+def 
test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, 
caplog, session):
+    """Test that mini scheduler expands mapped task"""
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dr = dag_maker.create_dagrun()
+
+    first_ti = dr.get_task_instance(task_id="first_task")
+    first_ti.state = State.SUCCESS
+    session.merge(first_ti)
+    session.commit()
+    second_task = dag.get_task("second_task")
+    second_ti = dr.get_task_instance(task_id="second_task")
+    second_ti.refresh_from_task(second_task)
+    second_ti.run()
+    second_ti.schedule_downstream_tasks(session=session)
+    for i in range(3):
+        middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
+        assert middle_ti.state == State.SCHEDULED
+    assert "3 downstream tasks scheduled from follow-on schedule" in 
caplog.text

Reply via email to