This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ed92e5d521 Fix mini scheduler expansion of mapped task (#27506)
ed92e5d521 is described below
commit ed92e5d521f958642615b038ec13068b527db1c4
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Nov 9 15:05:59 2022 +0100
Fix mini scheduler expansion of mapped task (#27506)
We have a case where the mini scheduler tries to expand a mapped task even
when that task's upstream tasks are not yet done.
The mini scheduler extracts a partial subset of the DAG, and in the process
some upstream tasks are dropped.
If the task happens to be a mapped task, the expansion will fail because it
needs the upstream output to perform the expansion. When the expansion fails,
the task is marked as `upstream_failed`, which in turn causes its downstream
tasks to be marked as `upstream_failed` as well.
The solution is to ignore this error and not mark the mapped task as
`upstream_failed` when the expansion fails and the DAG is a partial subset
(see the illustrative sketch below).
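For illustration, a minimal sketch of a DAG shape that can trigger the
issue. The names are hypothetical and mirror the regression test added
below; the sketch assumes the Airflow 2.4-era TaskFlow API. When first_task
finishes, the mini scheduler builds a partial subset around its downstream
tasks; middle_task is mapped over the output of second_task, which may
still be running, so its map length cannot be computed yet.

    # A minimal sketch, assuming the Airflow 2.4-era TaskFlow API.
    # Hypothetical names, mirroring the regression test added below.
    import pendulum

    from airflow.decorators import dag, task

    @dag(start_date=pendulum.datetime(2022, 11, 1), schedule=None)
    def mini_scheduler_repro():
        @task
        def first_task():
            print(2)

        @task
        def second_task():
            return [0, 1, 2]

        @task
        def middle_task(id):
            return id

        @task
        def last_task():
            print(3)

        # middle_task's map length depends on second_task's XCom output. If
        # first_task succeeds while second_task is still running, the partial
        # subset built by the mini scheduler cannot determine the map length,
        # and the expansion fails.
        [first_task(), middle_task.expand(id=second_task())] >> last_task()

    mini_scheduler_repro()
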
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
airflow/jobs/local_task_job.py | 59 +---------------------------
airflow/models/mappedoperator.py | 30 +++++++++-----
airflow/models/taskinstance.py | 61 ++++++++++++++++++++++++++++
tests/jobs/test_local_task_job.py | 1 -
tests/models/test_taskinstance.py | 83 +++++++++++++++++++++++++++++++++++++++
5 files changed, 165 insertions(+), 69 deletions(-)
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 1881511f9b..698c469dbb 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -18,25 +18,20 @@
 from __future__ import annotations

 import signal
-from typing import TYPE_CHECKING

 import psutil
-from sqlalchemy.exc import OperationalError

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
 from airflow.listeners.events import register_task_instance_state_events
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
-from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State
@@ -165,7 +160,7 @@ class LocalTaskJob(BaseJob):
         if not self.task_instance.test_mode:
             if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
-                self._run_mini_scheduler_on_child_tasks()
+                self.task_instance.schedule_downstream_tasks()

     def on_kill(self):
         self.task_runner.terminate()
@@ -230,58 +225,6 @@ class LocalTaskJob(BaseJob):
             self.terminating = True
         self._state_change_checks += 1

-    @provide_session
-    @Sentry.enrich_errors
-    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
-        try:
-            # Re-select the row with a lock
-            dag_run = with_row_locks(
-                session.query(DagRun).filter_by(
-                    dag_id=self.dag_id,
-                    run_id=self.task_instance.run_id,
-                ),
-                session=session,
-            ).one()
-
-            task = self.task_instance.task
-            if TYPE_CHECKING:
-                assert task.dag
-
-            # Get a partial DAG with just the specific tasks we want to examine.
-            # In order for dep checks to work correctly, we include ourself (so
-            # TriggerRuleDep can check the state of the task we just executed).
-            partial_dag = task.dag.partial_subset(
-                task.downstream_task_ids,
-                include_downstream=True,
-                include_upstream=False,
-                include_direct_upstream=True,
-            )
-
-            dag_run.dag = partial_dag
-            info = dag_run.task_instance_scheduling_decisions(session)
-
-            skippable_task_ids = {
-                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
-            }
-
-            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
-            for schedulable_ti in schedulable_tis:
-                if not hasattr(schedulable_ti, "task"):
-                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
-
-            num = dag_run.schedule_tis(schedulable_tis)
-            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
-            session.commit()
-        except OperationalError as e:
-            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
-            self.log.info(
-                "Skipping mini scheduling run due to exception: %s",
-                e.statement,
-                exc_info=True,
-            )
-            session.rollback()
-
     @staticmethod
     def _enable_task_listeners():
         """
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 2d5d00cf3d..e6bdf815b3 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -620,13 +620,18 @@ class MappedOperator(AbstractOperator):
         try:
             total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
         except NotFullyPopulated as e:
-            self.log.info(
-                "Cannot expand %r for run %s; missing upstream values: %s",
-                self,
-                run_id,
-                sorted(e.missing),
-            )
             total_length = None
+            # Partial DAGs come from the mini scheduler. It's
+            # possible that the upstream tasks are not yet done,
+            # but partial DAGs do not include the upstreams of
+            # upstreams, so we ignore this exception.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )

         state: TaskInstanceState | None = None
         unmapped_ti: TaskInstance | None = (
@@ -647,10 +652,15 @@ class MappedOperator(AbstractOperator):
         # The unmapped task instance still exists and is unfinished, i.e. we
         # haven't tried to run it before.
         if total_length is None:
-            # If the map length cannot be calculated (due to unavailable
-            # upstream sources), fail the unmapped task.
-            unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
-            indexes_to_map: Iterable[int] = ()
+            if self.dag and self.dag.partial:
+                # If the DAG is partial, it's likely that the upstream tasks
+                # are not done yet, so we do nothing
+                indexes_to_map: Iterable[int] = ()
+            else:
+                # If the map length cannot be calculated (due to unavailable
+                # upstream sources), fail the unmapped task.
+                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
+                indexes_to_map = ()
         elif total_length < 1:
             # If the upstream maps this to a zero-length value, simply mark
             # the unmapped task instance as SKIPPED (if needed).
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index e024b4db18..75968f32af 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2459,6 +2459,67 @@ class TaskInstance(Base, LoggingMixin):
             return filters[0]
         return or_(*filters)

+    @Sentry.enrich_errors
+    @provide_session
+    def schedule_downstream_tasks(self, session=None):
+        """
+        The mini-scheduler for scheduling downstream tasks of this task instance
+        :meta: private
+        """
+        from sqlalchemy.exc import OperationalError
+
+        from airflow.models import DagRun
+
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    run_id=self.run_id,
+                ),
+                session=session,
+            ).one()
+
+            task = self.task
+            if TYPE_CHECKING:
+                assert task.dag
+
+            # Get a partial DAG with just the specific tasks we want to examine.
+            # In order for dep checks to work correctly, we include ourself (so
+            # TriggerRuleDep can check the state of the task we just executed).
+            partial_dag = task.dag.partial_subset(
+                task.downstream_task_ids,
+                include_downstream=True,
+                include_upstream=False,
+                include_direct_upstream=True,
+            )
+
+            dag_run.dag = partial_dag
+            info = dag_run.task_instance_scheduling_decisions(session)
+
+            skippable_task_ids = {
+                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+            }
+
+            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+            for schedulable_ti in schedulable_tis:
+                if not hasattr(schedulable_ti, "task"):
+                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+            num = dag_run.schedule_tis(schedulable_tis, session=session)
+            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+            session.flush()
+
+        except OperationalError as e:
+            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+            self.log.info(
+                "Skipping mini scheduling run due to exception: %s",
+                e.statement,
+                exc_info=True,
+            )
+            session.rollback()
+
     # State of the task instance.
     # Stores string version of the task state.
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 7ac654b857..40eab105d4 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -739,7 +739,6 @@ class TestLocalTaskJob:
         ti2_l.refresh_from_db()
         assert ti2_k.state == State.SUCCESS
         assert ti2_l.state == State.NONE
-        assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text

         failed_deps = list(ti2_l.get_failed_dep_statuses())
         assert len(failed_deps) == 1
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index bc670af122..29d19ea223 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3613,3 +3613,86 @@ def test_expand_non_templated_field(dag_maker, session):
     echo_task = dag.get_task("echo")
     assert "get_extra_env" in echo_task.upstream_task_ids
+
+
+def test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker, caplog, session):
+    """
+    Test that when scheduling the child tasks of a task, a mapped downstream task
+    whose upstreams are not yet done is not marked as `upstream_failed`
+    """
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dag_run = dag_maker.create_dagrun()
+    first_ti = dag_run.get_task_instance(task_id="first_task")
+    second_ti = dag_run.get_task_instance(task_id="second_task")
+    first_ti.state = State.SUCCESS
+    second_ti.state = State.RUNNING
+    session.merge(first_ti)
+    session.merge(second_ti)
+    session.commit()
+    first_ti.schedule_downstream_tasks(session=session)
+    middle_ti = dag_run.get_task_instance(task_id="middle_task")
+    assert middle_ti.state != State.UPSTREAM_FAILED
+    assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text
+
+
+def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session):
+    """Test that the mini scheduler expands a mapped task once its upstreams are done"""
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dr = dag_maker.create_dagrun()
+
+    first_ti = dr.get_task_instance(task_id="first_task")
+    first_ti.state = State.SUCCESS
+    session.merge(first_ti)
+    session.commit()
+    second_task = dag.get_task("second_task")
+    second_ti = dr.get_task_instance(task_id="second_task")
+    second_ti.refresh_from_task(second_task)
+    second_ti.run()
+    second_ti.schedule_downstream_tasks(session=session)
+    for i in range(3):
+        middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
+        assert middle_ti.state == State.SCHEDULED
+    assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text