This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 73bc49adb1 Fix depends_on_past work for dynamic tasks (#32397)
73bc49adb1 is described below
commit 73bc49adb17957e5bb8dee357c04534c6b41f9dd
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jul 24 07:53:52 2023 +0800
Fix depends_on_past work for dynamic tasks (#32397)
Co-authored-by: Zhyhimont Dmitry <[email protected]>
---
airflow/ti_deps/deps/prev_dagrun_dep.py | 129 +++++++++++++++++++++++------
tests/ti_deps/deps/test_prev_dagrun_dep.py | 123 +++++++++++++++------------
2 files changed, 174 insertions(+), 78 deletions(-)
diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py
index 62acdbca33..a21944d410 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -17,13 +17,22 @@
# under the License.
from __future__ import annotations
-from sqlalchemy import func
+from typing import TYPE_CHECKING
+
+from sqlalchemy import func, literal, or_, select
+from sqlalchemy.orm import Session
from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState
+if TYPE_CHECKING:
+ from airflow.models.dagrun import DagRun
+ from airflow.models.operator import Operator
+
+_SUCCESSFUL_STATES = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)
+
class PrevDagrunDep(BaseTIDep):
"""
@@ -42,8 +51,87 @@ class PrevDagrunDep(BaseTIDep):
if dep_context.wait_for_past_depends_before_skipping:
ti.xcom_push(key=PAST_DEPENDS_MET, value=True)
+ @staticmethod
+ def _has_tis(dagrun: DagRun, task_id: str, *, session: Session) -> bool:
+ """Check if a task has presence in the specified DAG run.
+
+ This function exists for easy mocking in tests.
+ """
+ return (
+ session.scalar(
+ select(literal(True))
+ .where(TI.dag_id == dagrun.dag_id, TI.task_id == task_id, TI.run_id == dagrun.run_id)
+ .limit(1)
+ )
+ is not None
+ )
+
+ @staticmethod
+ def _has_any_prior_tis(ti: TI, *, session: Session) -> bool:
+ """Check if a task has ever been run before.
+
+ This function exists for easy mocking in tests.
+ """
+ return (
+ session.scalar(
+ select(literal(True))
+ .where(
+ TI.dag_id == ti.dag_id,
+ TI.task_id == ti.task_id,
+ TI.execution_date < ti.execution_date,
+ )
+ .limit(1)
+ )
+ is not None
+ )
+
+ @staticmethod
+ def _count_unsuccessful_tis(dagrun: DagRun, task_id: str, *, session: Session) -> int:
+ """Get a count of unsuccessful task instances in a given run.
+
+ Due to historical design considerations, "unsuccessful" here means the
+ task instance is not in either SUCCESS or SKIPPED state. This means that
+ unfinished states such as RUNNING are considered unsuccessful.
+
+ This function exists for easy mocking in tests.
+ """
+ return session.scalar(
+ select(func.count()).where(
+ TI.dag_id == dagrun.dag_id,
+ TI.task_id == task_id,
+ TI.run_id == dagrun.run_id,
+ or_(TI.state.is_(None), TI.state.not_in(_SUCCESSFUL_STATES)),
+ )
+ )
+
+ @staticmethod
+ def _has_unsuccessful_dependants(dagrun: DagRun, task: Operator, *, session: Session) -> bool:
+ """Check if any of the task's dependants are unsuccessful in a given run.
+
+ Due to historical design considerations, "unsuccessful" here means the
+ task instance is not in either SUCCESS or SKIPPED state. This means that
+ unfinished states such as RUNNING are considered unsuccessful.
+
+ This function exists for easy mocking in tests.
+ """
+ if not task.downstream_task_ids:
+ return False
+ return (
+ session.scalar(
+ select(literal(True))
+ .where(
+ TI.dag_id == dagrun.dag_id,
+ TI.task_id.in_(task.downstream_task_ids),
+ TI.run_id == dagrun.run_id,
+ or_(TI.state.is_(None), TI.state.not_in(_SUCCESSFUL_STATES)),
+ )
+ .limit(1)
+ )
+ is not None
+ )
+
@provide_session
- def _get_dep_statuses(self, ti: TI, session, dep_context):
+ def _get_dep_statuses(self, ti: TI, session: Session, dep_context):
if dep_context.ignore_depends_on_past:
self._push_past_deps_met_xcom_if_needed(ti, dep_context)
reason = "The context specified that the state of past DAGs could be ignored."
@@ -80,20 +168,9 @@ class PrevDagrunDep(BaseTIDep):
yield self._passing_status(reason="This task instance was the first task instance for its task.")
return
- previous_ti = last_dagrun.get_task_instance(ti.task_id, map_index=ti.map_index, session=session)
- if not previous_ti:
+ if not self._has_tis(last_dagrun, ti.task_id, session=session):
if ti.task.ignore_first_depends_on_past:
- has_historical_ti = (
- session.query(func.count(TI.dag_id))
- .filter(
- TI.dag_id == ti.dag_id,
- TI.task_id == ti.task_id,
- TI.execution_date < ti.execution_date,
- )
- .scalar()
- > 0
- )
- if not has_historical_ti:
+ if not self._has_any_prior_tis(ti, session=session):
self._push_past_deps_met_xcom_if_needed(ti, dep_context)
yield self._passing_status(
reason="ignore_first_depends_on_past is true for this task "
@@ -107,22 +184,24 @@ class PrevDagrunDep(BaseTIDep):
)
return
- if previous_ti.state not in {TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS}:
- yield self._failing_status(
- reason=(
- f"depends_on_past is true for this task, but the previous task instance {previous_ti} "
- f"is in the state '{previous_ti.state}' which is not a successful state."
- )
+ unsuccessful_tis_count = self._count_unsuccessful_tis(last_dagrun, ti.task_id, session=session)
+ if unsuccessful_tis_count > 0:
+ reason = (
+ f"depends_on_past is true for this task, but {unsuccessful_tis_count} "
+ f"previous task instance(s) are not in a successful state."
)
+ yield self._failing_status(reason=reason)
return
- previous_ti.task = ti.task
- if ti.task.wait_for_downstream and not previous_ti.are_dependents_done(session=session):
+ if ti.task.wait_for_downstream and self._has_unsuccessful_dependants(
+ last_dagrun, ti.task, session=session
+ ):
yield self._failing_status(
reason=(
- f"The tasks downstream of the previous task instance {previous_ti} haven't completed "
- f"(and wait_for_downstream is True)."
+ "The tasks downstream of the previous task instance(s) "
+ "haven't completed, and wait_for_downstream is True."
)
)
return
+
self._push_past_deps_met_xcom_if_needed(ti, dep_context)
diff --git a/tests/ti_deps/deps/test_prev_dagrun_dep.py b/tests/ti_deps/deps/test_prev_dagrun_dep.py
index bf641ad6b5..262b90027d 100644
--- a/tests/ti_deps/deps/test_prev_dagrun_dep.py
+++ b/tests/ti_deps/deps/test_prev_dagrun_dep.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from unittest.mock import Mock
+from unittest.mock import ANY, Mock, patch
import pytest
@@ -25,7 +25,7 @@ from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.timezone import convert_to_utc, datetime
from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_runs
@@ -51,7 +51,7 @@ class TestPrevDagrunDep:
# Old DAG run will include only TaskInstance of old_task
dag.create_dagrun(
run_id="old_run",
- state=State.SUCCESS,
+ state=TaskInstanceState.SUCCESS,
execution_date=old_task.start_date,
run_type=DagRunType.SCHEDULED,
)
@@ -67,7 +67,7 @@ class TestPrevDagrunDep:
# New DAG run will include 1st TaskInstance of new_task
dr = dag.create_dagrun(
run_id="new_run",
- state=State.RUNNING,
+ state=DagRunState.RUNNING,
execution_date=convert_to_utc(datetime(2016, 1, 2)),
run_type=DagRunType.SCHEDULED,
)
@@ -75,16 +75,24 @@ class TestPrevDagrunDep:
ti = dr.get_task_instance(new_task.task_id)
ti.task = new_task
- # this is important, we need to assert there is no previous_ti of this ti
- assert ti.previous_ti is None
-
dep_context = DepContext(ignore_depends_on_past=False)
- assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
+ dep = PrevDagrunDep()
+
+ with patch.object(dep, "_has_any_prior_tis", Mock(return_value=False)) as mock_has_any_prior_tis:
+ assert dep.is_met(ti=ti, dep_context=dep_context)
+ mock_has_any_prior_tis.assert_called_once_with(ti, session=ANY)
@pytest.mark.parametrize(
- "depends_on_past, wait_for_past_depends_before_skipping, wait_for_downstream, prev_ti,"
- " context_ignore_depends_on_past, dep_met, past_depends_met_xcom_sent",
+ (
+ "depends_on_past",
+ "wait_for_past_depends_before_skipping",
+ "wait_for_downstream",
+ "prev_tis",
+ "context_ignore_depends_on_past",
+ "expected_dep_met",
+ "past_depends_met_xcom_sent",
+ ),
[
# If the task does not set depends_on_past, the previous dagrun should
# be ignored, even though previous_ti would otherwise fail the dep.
@@ -93,10 +101,7 @@ class TestPrevDagrunDep:
False,
False,
False, # wait_for_downstream=True overrides depends_on_past=False.
- Mock(
- state=State.NONE,
- **{"are_dependents_done.return_value": False},
- ),
+ [Mock(state=None, **{"are_dependents_done.return_value": False})],
False,
True,
False,
@@ -109,10 +114,7 @@ class TestPrevDagrunDep:
False,
True,
False, # wait_for_downstream=True overrides depends_on_past=False.
- Mock(
- state=State.NONE,
- **{"are_dependents_done.return_value": False},
- ),
+ [Mock(state=None, **{"are_dependents_done.return_value": False})],
False,
True,
True,
@@ -125,10 +127,7 @@ class TestPrevDagrunDep:
True,
False,
False,
- Mock(
- state=State.SUCCESS,
- **{"are_dependents_done.return_value": True},
- ),
+ [Mock(state=TaskInstanceState.SUCCESS, **{"are_dependents_done.return_value": True})],
True,
True,
False,
@@ -141,10 +140,7 @@ class TestPrevDagrunDep:
True,
True,
False,
- Mock(
- state=State.SUCCESS,
- **{"are_dependents_done.return_value": True},
- ),
+ [Mock(state=TaskInstanceState.SUCCESS, **{"are_dependents_done.return_value": True})],
True,
True,
True,
@@ -152,19 +148,16 @@ class TestPrevDagrunDep:
),
# The first task run should pass since it has no previous dagrun.
# wait_for_past_depends_before_skipping is False, past_depends_met xcom should not be sent
- pytest.param(True, False, False, None, False, True, False, id="first_task_run"),
+ pytest.param(True, False, False, [], False, True, False, id="first_task_run"),
# The first task run should pass since it has no previous dagrun.
# wait_for_past_depends_before_skipping is True, past_depends_met xcom should be sent
- pytest.param(True, True, False, None, False, True, True, id="first_task_run"),
+ pytest.param(True, True, False, [], False, True, True, id="first_task_run_wait"),
# Previous TI did not complete execution. This dep should fail.
pytest.param(
True,
False,
False,
- Mock(
- state=State.NONE,
- **{"are_dependents_done.return_value": True},
- ),
+ [Mock(state=None, **{"are_dependents_done.return_value": True})],
False,
False,
False,
@@ -177,10 +170,7 @@ class TestPrevDagrunDep:
True,
False,
True,
- Mock(
- state=State.SUCCESS,
- **{"are_dependents_done.return_value": False},
- ),
+ [Mock(state=TaskInstanceState.SUCCESS, **{"are_dependents_done.return_value": False})],
False,
False,
False,
@@ -192,10 +182,7 @@ class TestPrevDagrunDep:
True,
False,
True,
- Mock(
- state=State.SUCCESS,
- **{"are_dependents_done.return_value": True},
- ),
+ [Mock(state=TaskInstanceState.SUCCESS, **{"are_dependents_done.return_value": True})],
False,
True,
False,
@@ -207,10 +194,7 @@ class TestPrevDagrunDep:
True,
True,
True,
- Mock(
- state=State.SUCCESS,
- **{"are_dependents_done.return_value": True},
- ),
+ [Mock(state=TaskInstanceState.SUCCESS, **{"are_dependents_done.return_value": True})],
False,
True,
True,
@@ -222,9 +206,9 @@ def test_dagrun_dep(
depends_on_past,
wait_for_past_depends_before_skipping,
wait_for_downstream,
- prev_ti,
+ prev_tis,
context_ignore_depends_on_past,
- dep_met,
+ expected_dep_met,
past_depends_met_xcom_sent,
):
task = BaseOperator(
@@ -234,26 +218,59 @@ def test_dagrun_dep(
start_date=datetime(2016, 1, 1),
wait_for_downstream=wait_for_downstream,
)
- if prev_ti:
- prev_dagrun = Mock(
- execution_date=datetime(2016, 1, 2),
- **{"get_task_instance.return_value": prev_ti},
- )
+ if prev_tis:
+ prev_dagrun = Mock(execution_date=datetime(2016, 1, 2))
else:
prev_dagrun = None
+
dagrun = Mock(
**{
"get_previous_scheduled_dagrun.return_value": prev_dagrun,
"get_previous_dagrun.return_value": prev_dagrun,
},
)
- ti = Mock(task=task, **{"get_dagrun.return_value": dagrun, "xcom_push.return_value": None})
+ ti = Mock(
+ task=task,
+ task_id=task.task_id,
+ **{"get_dagrun.return_value": dagrun, "xcom_push.return_value": None},
+ )
dep_context = DepContext(
ignore_depends_on_past=context_ignore_depends_on_past,
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
)
- assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context) == dep_met
+ unsuccessful_tis_count = sum(
+ int(ti.state not in {TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED}) for ti in prev_tis
+ )
+
+ mock_has_tis = Mock(return_value=bool(prev_tis))
+ mock_has_any_prior_tis = Mock(return_value=bool(prev_tis))
+ mock_count_unsuccessful_tis = Mock(return_value=unsuccessful_tis_count)
+ mock_has_unsuccessful_dependants = Mock(return_value=any(not ti.are_dependents_done() for ti in prev_tis))
+
+ dep = PrevDagrunDep()
+ with patch.multiple(
+ dep,
+ _has_tis=mock_has_tis,
+ _has_any_prior_tis=mock_has_any_prior_tis,
+ _count_unsuccessful_tis=mock_count_unsuccessful_tis,
+ _has_unsuccessful_dependants=mock_has_unsuccessful_dependants,
+ ):
+ actual_dep_met = dep.is_met(ti=ti, dep_context=dep_context)
+
+ mock_has_any_prior_tis.assert_not_called()
+ if depends_on_past and not context_ignore_depends_on_past and prev_tis:
+ mock_has_tis.assert_called_once_with(prev_dagrun, "test_task", session=ANY)
+ mock_count_unsuccessful_tis.assert_called_once_with(prev_dagrun, "test_task", session=ANY)
+ else:
+ mock_has_tis.assert_not_called()
+ mock_count_unsuccessful_tis.assert_not_called()
+ if depends_on_past and not context_ignore_depends_on_past and prev_tis and not unsuccessful_tis_count:
+ mock_has_unsuccessful_dependants.assert_called_once_with(prev_dagrun, task, session=ANY)
+ else:
+ mock_has_unsuccessful_dependants.assert_not_called()
+
+ assert actual_dep_met == expected_dep_met
if past_depends_met_xcom_sent:
ti.xcom_push.assert_called_with(key="past_depends_met", value=True)
else: