This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73bc49adb1 Fix depends_on_past work for dynamic tasks (#32397)
73bc49adb1 is described below

commit 73bc49adb17957e5bb8dee357c04534c6b41f9dd
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jul 24 07:53:52 2023 +0800

    Fix depends_on_past work for dynamic tasks (#32397)
    
    Co-authored-by: Zhyhimont Dmitry <[email protected]>
---
 airflow/ti_deps/deps/prev_dagrun_dep.py    | 129 +++++++++++++++++++++++------
 tests/ti_deps/deps/test_prev_dagrun_dep.py | 123 +++++++++++++++------------
 2 files changed, 174 insertions(+), 78 deletions(-)

diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py
index 62acdbca33..a21944d410 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -17,13 +17,22 @@
 # under the License.
 from __future__ import annotations
 
-from sqlalchemy import func
+from typing import TYPE_CHECKING
+
+from sqlalchemy import func, literal, or_, select
+from sqlalchemy.orm import Session
 
 from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
+if TYPE_CHECKING:
+    from airflow.models.dagrun import DagRun
+    from airflow.models.operator import Operator
+
+_SUCCESSFUL_STATES = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)  # only states treated as successful
+
 
 class PrevDagrunDep(BaseTIDep):
     """
@@ -42,8 +51,87 @@ class PrevDagrunDep(BaseTIDep):
         if dep_context.wait_for_past_depends_before_skipping:
             ti.xcom_push(key=PAST_DEPENDS_MET, value=True)
 
+    @staticmethod
+    def _has_tis(dagrun: DagRun, task_id: str, *, session: Session) -> bool:
+        """Check if a task has presence in the specified DAG run.
+
+        Returns True if at least one task instance of ``task_id`` exists in
+        ``dagrun``. Neither the map index nor the state of the task instance is
+        taken into account, so any mapped instance counts.
+
+        This function exists for easy mocking in tests.
+        """
+        return (
+            session.scalar(
+                select(literal(True))
+                .where(TI.dag_id == dagrun.dag_id, TI.task_id == task_id, 
TI.run_id == dagrun.run_id)
+                .limit(1)  # existence check: scalar() gives None when no row matches
+            )
+            is not None
+        )
+
+    @staticmethod
+    def _has_any_prior_tis(ti: TI, *, session: Session) -> bool:
+        """Check if the task has a task instance in any earlier run.
+
+        Returns True if a task instance with the same ``dag_id`` and ``task_id``
+        as ``ti`` exists with an ``execution_date`` strictly earlier than
+        ``ti.execution_date``. Map index and state are not considered; only the
+        existence of such a row matters.
+
+        This function exists for easy mocking in tests.
+        """
+        return (
+            session.scalar(
+                select(literal(True))
+                .where(
+                    TI.dag_id == ti.dag_id,
+                    TI.task_id == ti.task_id,
+                    TI.execution_date < ti.execution_date,
+                )
+                .limit(1)  # existence check: scalar() gives None when no row matches
+            )
+            is not None
+        )
+
+    @staticmethod
+    def _count_unsuccessful_tis(dagrun: DagRun, task_id: str, *, session: 
Session) -> int:
+        """Get a count of unsuccessful task instances in a given run.
+
+        Due to historical design considerations, "unsuccessful" here means the
+        task instance is not in either SUCCESS or SKIPPED state. This means that
+        unfinished states such as RUNNING, as well as no state at all (NULL),
+        are considered unsuccessful.
+
+        Every task instance of ``task_id`` in ``dagrun`` is counted, across all
+        map indexes. A single unsuccessful mapped instance is therefore enough
+        to make the count non-zero.
+
+        This function exists for easy mocking in tests.
+        """
+        return session.scalar(
+            select(func.count()).where(
+                TI.dag_id == dagrun.dag_id,
+                TI.task_id == task_id,
+                TI.run_id == dagrun.run_id,
+                or_(TI.state.is_(None), TI.state.not_in(_SUCCESSFUL_STATES)),  # NOT IN skips NULLs
+            )
+        )
+
+    @staticmethod
+    def _has_unsuccessful_dependants(dagrun: DagRun, task: Operator, *, 
session: Session) -> bool:
+        """Check if any of the task's dependants are unsuccessful in a given run.
+
+        "Dependants" are the tasks whose IDs appear in ``task.downstream_task_ids``.
+        All map indexes of those tasks in ``dagrun`` are checked. A downstream
+        task that has no task instance row in ``dagrun`` is not matched by the
+        query and so does not make this return True.
+
+        Due to historical design considerations, "unsuccessful" here means the
+        task instance is not in either SUCCESS or SKIPPED state. This means that
+        unfinished states such as RUNNING, as well as no state at all (NULL),
+        are considered unsuccessful.
+
+        This function exists for easy mocking in tests.
+        """
+        if not task.downstream_task_ids:  # nothing downstream; skip the query
+            return False
+        return (
+            session.scalar(
+                select(literal(True))
+                .where(
+                    TI.dag_id == dagrun.dag_id,
+                    TI.task_id.in_(task.downstream_task_ids),
+                    TI.run_id == dagrun.run_id,
+                    or_(TI.state.is_(None), 
TI.state.not_in(_SUCCESSFUL_STATES)),
+                )
+                .limit(1)  # existence check: scalar() gives None when no row matches
+            )
+            is not None
+        )
+
     @provide_session
-    def _get_dep_statuses(self, ti: TI, session, dep_context):
+    def _get_dep_statuses(self, ti: TI, session: Session, dep_context):
         if dep_context.ignore_depends_on_past:
             self._push_past_deps_met_xcom_if_needed(ti, dep_context)
             reason = "The context specified that the state of past DAGs could 
be ignored."
@@ -80,20 +168,9 @@ class PrevDagrunDep(BaseTIDep):
             yield self._passing_status(reason="This task instance was the 
first task instance for its task.")
             return
 
-        previous_ti = last_dagrun.get_task_instance(ti.task_id, 
map_index=ti.map_index, session=session)
-        if not previous_ti:
+        if not self._has_tis(last_dagrun, ti.task_id, session=session):
             if ti.task.ignore_first_depends_on_past:
-                has_historical_ti = (
-                    session.query(func.count(TI.dag_id))
-                    .filter(
-                        TI.dag_id == ti.dag_id,
-                        TI.task_id == ti.task_id,
-                        TI.execution_date < ti.execution_date,
-                    )
-                    .scalar()
-                    > 0
-                )
-                if not has_historical_ti:
+                if not self._has_any_prior_tis(ti, session=session):
                     self._push_past_deps_met_xcom_if_needed(ti, dep_context)
                     yield self._passing_status(
                         reason="ignore_first_depends_on_past is true for this 
task "
@@ -107,22 +184,24 @@ class PrevDagrunDep(BaseTIDep):
             )
             return
 
-        if previous_ti.state not in {TaskInstanceState.SKIPPED, 
TaskInstanceState.SUCCESS}:
-            yield self._failing_status(
-                reason=(
-                    f"depends_on_past is true for this task, but the previous 
task instance {previous_ti} "
-                    f"is in the state '{previous_ti.state}' which is not a 
successful state."
-                )
+        unsuccessful_tis_count = self._count_unsuccessful_tis(last_dagrun, 
ti.task_id, session=session)
+        if unsuccessful_tis_count > 0:
+            reason = (
+                f"depends_on_past is true for this task, but 
{unsuccessful_tis_count} "
+                f"previous task instance(s) are not in a successful state."
             )
+            yield self._failing_status(reason=reason)
             return
 
-        previous_ti.task = ti.task
-        if ti.task.wait_for_downstream and not 
previous_ti.are_dependents_done(session=session):
+        if ti.task.wait_for_downstream and self._has_unsuccessful_dependants(
+            last_dagrun, ti.task, session=session
+        ):
             yield self._failing_status(
                 reason=(
-                    f"The tasks downstream of the previous task instance 
{previous_ti} haven't completed "
-                    f"(and wait_for_downstream is True)."
+                    "The tasks downstream of the previous task instance(s) "
+                    "haven't completed, and wait_for_downstream is True."
                 )
             )
             return
+
         self._push_past_deps_met_xcom_if_needed(ti, dep_context)
diff --git a/tests/ti_deps/deps/test_prev_dagrun_dep.py b/tests/ti_deps/deps/test_prev_dagrun_dep.py
index bf641ad6b5..262b90027d 100644
--- a/tests/ti_deps/deps/test_prev_dagrun_dep.py
+++ b/tests/ti_deps/deps/test_prev_dagrun_dep.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import Mock
+from unittest.mock import ANY, Mock, patch
 
 import pytest
 
@@ -25,7 +25,7 @@ from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.timezone import convert_to_utc, datetime
 from airflow.utils.types import DagRunType
 from tests.test_utils.db import clear_db_runs
@@ -51,7 +51,7 @@ class TestPrevDagrunDep:
         # Old DAG run will include only TaskInstance of old_task
         dag.create_dagrun(
             run_id="old_run",
-            state=State.SUCCESS,
+            state=TaskInstanceState.SUCCESS,
             execution_date=old_task.start_date,
             run_type=DagRunType.SCHEDULED,
         )
@@ -67,7 +67,7 @@ class TestPrevDagrunDep:
         # New DAG run will include 1st TaskInstance of new_task
         dr = dag.create_dagrun(
             run_id="new_run",
-            state=State.RUNNING,
+            state=DagRunState.RUNNING,
             execution_date=convert_to_utc(datetime(2016, 1, 2)),
             run_type=DagRunType.SCHEDULED,
         )
@@ -75,16 +75,24 @@ class TestPrevDagrunDep:
         ti = dr.get_task_instance(new_task.task_id)
         ti.task = new_task
 
-        # this is important, we need to assert there is no previous_ti of this 
ti
-        assert ti.previous_ti is None
-
         dep_context = DepContext(ignore_depends_on_past=False)
-        assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
+        dep = PrevDagrunDep()
+
+        with patch.object(dep, "_has_any_prior_tis", Mock(return_value=False)) 
as mock_has_any_prior_tis:
+            assert dep.is_met(ti=ti, dep_context=dep_context)
+            mock_has_any_prior_tis.assert_called_once_with(ti, session=ANY)
 
 
 @pytest.mark.parametrize(
-    "depends_on_past, wait_for_past_depends_before_skipping, 
wait_for_downstream, prev_ti,"
-    " context_ignore_depends_on_past, dep_met, past_depends_met_xcom_sent",
+    (
+        "depends_on_past",
+        "wait_for_past_depends_before_skipping",
+        "wait_for_downstream",
+        "prev_tis",
+        "context_ignore_depends_on_past",
+        "expected_dep_met",
+        "past_depends_met_xcom_sent",
+    ),
     [
         # If the task does not set depends_on_past, the previous dagrun should
         # be ignored, even though previous_ti would otherwise fail the dep.
@@ -93,10 +101,7 @@ class TestPrevDagrunDep:
             False,
             False,
             False,  # wait_for_downstream=True overrides depends_on_past=False.
-            Mock(
-                state=State.NONE,
-                **{"are_dependents_done.return_value": False},
-            ),
+            [Mock(state=None, **{"are_dependents_done.return_value": False})],
             False,
             True,
             False,
@@ -109,10 +114,7 @@ class TestPrevDagrunDep:
             False,
             True,
             False,  # wait_for_downstream=True overrides depends_on_past=False.
-            Mock(
-                state=State.NONE,
-                **{"are_dependents_done.return_value": False},
-            ),
+            [Mock(state=None, **{"are_dependents_done.return_value": False})],
             False,
             True,
             True,
@@ -125,10 +127,7 @@ class TestPrevDagrunDep:
             True,
             False,
             False,
-            Mock(
-                state=State.SUCCESS,
-                **{"are_dependents_done.return_value": True},
-            ),
+            [Mock(state=TaskInstanceState.SUCCESS, 
**{"are_dependents_done.return_value": True})],
             True,
             True,
             False,
@@ -141,10 +140,7 @@ class TestPrevDagrunDep:
             True,
             True,
             False,
-            Mock(
-                state=State.SUCCESS,
-                **{"are_dependents_done.return_value": True},
-            ),
+            [Mock(state=TaskInstanceState.SUCCESS, 
**{"are_dependents_done.return_value": True})],
             True,
             True,
             True,
@@ -152,19 +148,16 @@ class TestPrevDagrunDep:
         ),
         # The first task run should pass since it has no previous dagrun.
         # wait_for_past_depends_before_skipping is False, past_depends_met 
xcom should not be sent
-        pytest.param(True, False, False, None, False, True, False, 
id="first_task_run"),
+        pytest.param(True, False, False, [], False, True, False, 
id="first_task_run"),
         # The first task run should pass since it has no previous dagrun.
         # wait_for_past_depends_before_skipping is True, past_depends_met xcom 
should be sent
-        pytest.param(True, True, False, None, False, True, True, 
id="first_task_run"),
+        pytest.param(True, True, False, [], False, True, True, 
id="first_task_run_wait"),
         # Previous TI did not complete execution. This dep should fail.
         pytest.param(
             True,
             False,
             False,
-            Mock(
-                state=State.NONE,
-                **{"are_dependents_done.return_value": True},
-            ),
+            [Mock(state=None, **{"are_dependents_done.return_value": True})],
             False,
             False,
             False,
@@ -177,10 +170,7 @@ class TestPrevDagrunDep:
             True,
             False,
             True,
-            Mock(
-                state=State.SUCCESS,
-                **{"are_dependents_done.return_value": False},
-            ),
+            [Mock(state=TaskInstanceState.SUCCESS, 
**{"are_dependents_done.return_value": False})],
             False,
             False,
             False,
@@ -192,10 +182,7 @@ class TestPrevDagrunDep:
             True,
             False,
             True,
-            Mock(
-                state=State.SUCCESS,
-                **{"are_dependents_done.return_value": True},
-            ),
+            [Mock(state=TaskInstanceState.SUCCESS, 
**{"are_dependents_done.return_value": True})],
             False,
             True,
             False,
@@ -207,10 +194,7 @@ class TestPrevDagrunDep:
             True,
             True,
             True,
-            Mock(
-                state=State.SUCCESS,
-                **{"are_dependents_done.return_value": True},
-            ),
+            [Mock(state=TaskInstanceState.SUCCESS, 
**{"are_dependents_done.return_value": True})],
             False,
             True,
             True,
@@ -222,9 +206,9 @@ def test_dagrun_dep(
     depends_on_past,
     wait_for_past_depends_before_skipping,
     wait_for_downstream,
-    prev_ti,
+    prev_tis,
     context_ignore_depends_on_past,
-    dep_met,
+    expected_dep_met,
     past_depends_met_xcom_sent,
 ):
     task = BaseOperator(
@@ -234,26 +218,59 @@ def test_dagrun_dep(
         start_date=datetime(2016, 1, 1),
         wait_for_downstream=wait_for_downstream,
     )
-    if prev_ti:
-        prev_dagrun = Mock(
-            execution_date=datetime(2016, 1, 2),
-            **{"get_task_instance.return_value": prev_ti},
-        )
+    if prev_tis:
+        prev_dagrun = Mock(execution_date=datetime(2016, 1, 2))
     else:
         prev_dagrun = None
+
     dagrun = Mock(
         **{
             "get_previous_scheduled_dagrun.return_value": prev_dagrun,
             "get_previous_dagrun.return_value": prev_dagrun,
         },
     )
-    ti = Mock(task=task, **{"get_dagrun.return_value": dagrun, 
"xcom_push.return_value": None})
+    ti = Mock(
+        task=task,
+        task_id=task.task_id,
+        **{"get_dagrun.return_value": dagrun, "xcom_push.return_value": None},
+    )
     dep_context = DepContext(
         ignore_depends_on_past=context_ignore_depends_on_past,
         
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
     )
 
-    assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context) == dep_met
+    unsuccessful_tis_count = sum(
+        int(ti.state not in {TaskInstanceState.SUCCESS, 
TaskInstanceState.SKIPPED}) for ti in prev_tis
+    )
+
+    mock_has_tis = Mock(return_value=bool(prev_tis))
+    mock_has_any_prior_tis = Mock(return_value=bool(prev_tis))
+    mock_count_unsuccessful_tis = Mock(return_value=unsuccessful_tis_count)
+    mock_has_unsuccessful_dependants = Mock(return_value=any(not 
ti.are_dependents_done() for ti in prev_tis))
+
+    dep = PrevDagrunDep()
+    with patch.multiple(
+        dep,
+        _has_tis=mock_has_tis,
+        _has_any_prior_tis=mock_has_any_prior_tis,
+        _count_unsuccessful_tis=mock_count_unsuccessful_tis,
+        _has_unsuccessful_dependants=mock_has_unsuccessful_dependants,
+    ):
+        actual_dep_met = dep.is_met(ti=ti, dep_context=dep_context)
+
+        mock_has_any_prior_tis.assert_not_called()
+        if depends_on_past and not context_ignore_depends_on_past and prev_tis:
+            mock_has_tis.assert_called_once_with(prev_dagrun, "test_task", 
session=ANY)
+            mock_count_unsuccessful_tis.assert_called_once_with(prev_dagrun, 
"test_task", session=ANY)
+        else:
+            mock_has_tis.assert_not_called()
+            mock_count_unsuccessful_tis.assert_not_called()
+        if depends_on_past and not context_ignore_depends_on_past and prev_tis 
and not unsuccessful_tis_count:
+            
mock_has_unsuccessful_dependants.assert_called_once_with(prev_dagrun, task, 
session=ANY)
+        else:
+            mock_has_unsuccessful_dependants.assert_not_called()
+
+    assert actual_dep_met == expected_dep_met
     if past_depends_met_xcom_sent:
         ti.xcom_push.assert_called_with(key="past_depends_met", value=True)
     else:

Reply via email to