This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-6-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 94582bddde32b8b7ccd0f0e42b8ee096c3acaa37
Author: Dmytro Suvorov <[email protected]>
AuthorDate: Wed Apr 26 18:27:48 2023 +0300

    Prevent DagRun's `start_date` from reset (#30124) (#30125)
    
    (cherry picked from commit 070ecbd87c5ac067418b2814f554555da0a4f30c)
---
 airflow/models/taskinstance.py                     | 17 ++--
 airflow/utils/state.py                             |  3 +
 .../endpoints/test_dag_run_endpoint.py             |  4 +-
 tests/models/test_cleartasks.py                    | 90 +++++++++++++++++++++-
 tests/models/test_dagrun.py                        | 26 ++++++-
 5 files changed, 125 insertions(+), 15 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 3b56583b1e..49b3e7ebad 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -190,7 +190,11 @@ def clear_task_instances(
 ) -> None:
     """
     Clears a set of task instances, but makes sure the running ones
-    get killed.
+    get killed. Also sets Dagrun's `state` to QUEUED and `start_date`
+    to the time of execution. But only for finished DRs (SUCCESS and FAILED).
+    Doesn't clear DR's `state` and `start_date` for running
+    DRs (QUEUED and RUNNING) because clearing the state for already
+    running DR is redundant and clearing `start_date` affects DR's duration.
 
     :param tis: a list of task instances
     :param session: current session
@@ -302,11 +306,12 @@ def clear_task_instances(
         )
         dag_run_state = DagRunState(dag_run_state)  # Validate the state value.
         for dr in drs:
-            dr.state = dag_run_state
-            dr.start_date = timezone.utcnow()
-            if dag_run_state == DagRunState.QUEUED:
-                dr.last_scheduling_decision = None
-                dr.start_date = None
+            if dr.state in State.finished_dr_states:
+                dr.state = dag_run_state
+                dr.start_date = timezone.utcnow()
+                if dag_run_state == DagRunState.QUEUED:
+                    dr.last_scheduling_decision = None
+                    dr.start_date = None
     session.flush()
 
 
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 6f89174bd6..f4a8dc1a0a 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -95,6 +95,9 @@ class State:
     SKIPPED = TaskInstanceState.SKIPPED
     DEFERRED = TaskInstanceState.DEFERRED
 
+    finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED])
+    unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING])
+
     task_states: tuple[TaskInstanceState | None, ...] = (None,) + tuple(TaskInstanceState)
 
     dag_states: tuple[DagRunState, ...] = (
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 09ecb9e497..7510b41e0a 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -31,7 +31,7 @@ from airflow.operators.empty import EmptyOperator
 from airflow.security import permissions
 from airflow.utils import timezone
 from airflow.utils.session import create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType
 from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user
 from tests.test_utils.config import conf_vars
@@ -1440,7 +1440,7 @@ class TestClearDagRun(TestDagRunEndpoint):
         with dag_maker(dag_id) as dag:
             task = EmptyOperator(task_id="task_id", dag=dag)
         self.app.dag_bag.bag_dag(dag, root_dag=dag)
-        dr = dag_maker.create_dagrun(run_id=dag_run_id)
+        dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED)
         ti = dr.get_task_instance(task_id="task_id")
         ti.task = task
         ti.state = State.SUCCESS
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index f0ef8002c1..ec504ba186 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -27,7 +27,7 @@ from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.sensors.python import PythonSensor
 from airflow.utils.session import create_session
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
 from tests.test_utils import db
@@ -132,7 +132,7 @@ class TestClearTasks:
         assert ti0.next_kwargs is None
 
     @pytest.mark.parametrize(
-        ["state", "last_scheduling"], [(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)]
+        ["state", "last_scheduling"], [(DagRunState.QUEUED, None), (DagRunState.RUNNING, DEFAULT_DATE)]
     )
     def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
         """Test that DR state is set to None after clear.
@@ -147,7 +147,7 @@ class TestClearTasks:
             EmptyOperator(task_id="0")
             EmptyOperator(task_id="1", retries=2)
         dr = dag_maker.create_dagrun(
-            state=State.RUNNING,
+            state=DagRunState.SUCCESS,
             run_type=DagRunType.SCHEDULED,
         )
         ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
@@ -168,9 +168,91 @@ class TestClearTasks:
         session.refresh(dr)
 
         assert dr.state == state
-        assert dr.start_date is None if state == State.QUEUED else dr.start_date
+        assert dr.start_date is None if state == DagRunState.QUEUED else dr.start_date
         assert dr.last_scheduling_decision == last_scheduling
 
+    @pytest.mark.parametrize("state", [DagRunState.QUEUED, DagRunState.RUNNING])
+    def test_clear_task_instances_on_running_dr(self, state, dag_maker):
+        """Test that DagRun state, start_date and last_scheduling_decision
+        are not changed after clearing TI in an unfinished DagRun.
+        """
+        with dag_maker(
+            "test_clear_task_instances",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        ) as dag:
+            EmptyOperator(task_id="0")
+            EmptyOperator(task_id="1", retries=2)
+        dr = dag_maker.create_dagrun(
+            state=state,
+            run_type=DagRunType.SCHEDULED,
+        )
+        ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
+        dr.last_scheduling_decision = DEFAULT_DATE
+        ti0.state = TaskInstanceState.SUCCESS
+        ti1.state = TaskInstanceState.SUCCESS
+        session = dag_maker.session
+        session.flush()
+
+        # we use order_by(task_id) here because for the test DAG structure of ours
+        # this is equivalent to topological sort. It would not work in general case
+        # but it works for our case because we specifically constructed test DAGS
+        # in the way that those two sort methods are equivalent
+        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        clear_task_instances(qry, session, dag=dag)
+        session.flush()
+
+        session.refresh(dr)
+
+        assert dr.state == state
+        assert dr.start_date
+        assert dr.last_scheduling_decision == DEFAULT_DATE
+
+    @pytest.mark.parametrize(
+        ["state", "last_scheduling"],
+        [
+            (DagRunState.SUCCESS, None),
+            (DagRunState.SUCCESS, DEFAULT_DATE),
+            (DagRunState.FAILED, None),
+            (DagRunState.FAILED, DEFAULT_DATE),
+        ],
+    )
+    def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_maker):
+        """Test that DagRun state, start_date and last_scheduling_decision
+        are changed after clearing TI in a finished DagRun.
+        """
+        with dag_maker(
+            "test_clear_task_instances",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        ) as dag:
+            EmptyOperator(task_id="0")
+            EmptyOperator(task_id="1", retries=2)
+        dr = dag_maker.create_dagrun(
+            state=state,
+            run_type=DagRunType.SCHEDULED,
+        )
+        ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
+        dr.last_scheduling_decision = DEFAULT_DATE
+        ti0.state = TaskInstanceState.SUCCESS
+        ti1.state = TaskInstanceState.SUCCESS
+        session = dag_maker.session
+        session.flush()
+
+        # we use order_by(task_id) here because for the test DAG structure of ours
+        # this is equivalent to topological sort. It would not work in general case
+        # but it works for our case because we specifically constructed test DAGS
+        # in the way that those two sort methods are equivalent
+        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        clear_task_instances(qry, session, dag=dag)
+        session.flush()
+
+        session.refresh(dr)
+
+        assert dr.state == DagRunState.QUEUED
+        assert dr.start_date is None
+        assert dr.last_scheduling_decision is None
+
     def test_clear_task_instances_without_task(self, dag_maker):
         with dag_maker(
             "test_clear_task_instances_without_task",
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 5ffbdaae02..8df2e9e0c1 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -81,6 +81,7 @@ class TestDagRun:
         task_states: Mapping[str, TaskInstanceState] | None = None,
         execution_date: datetime.datetime | None = None,
         is_backfill: bool = False,
+        state: DagRunState = DagRunState.RUNNING,
         session: Session,
     ):
         now = timezone.utcnow()
@@ -98,7 +99,7 @@ class TestDagRun:
             execution_date=execution_date,
             data_interval=data_interval,
             start_date=now,
-            state=DagRunState.RUNNING,
+            state=state,
             external_trigger=False,
         )
 
@@ -110,11 +111,30 @@ class TestDagRun:
 
         return dag_run
 
-    def test_clear_task_instances_for_backfill_dagrun(self, session):
+    @pytest.mark.parametrize("state", [DagRunState.QUEUED, DagRunState.RUNNING])
+    def test_clear_task_instances_for_backfill_unfinished_dagrun(self, state, session):
+        now = timezone.utcnow()
+        dag_id = "test_clear_task_instances_for_backfill_dagrun"
+        dag = DAG(dag_id=dag_id, start_date=now)
+        dag_run = self.create_dag_run(dag, execution_date=now, is_backfill=True, state=state, session=session)
+
+        task0 = EmptyOperator(task_id="backfill_task_0", owner="test", dag=dag)
+        ti0 = TI(task=task0, run_id=dag_run.run_id)
+        ti0.run()
+
+        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        clear_task_instances(qry, session)
+        session.commit()
+        ti0.refresh_from_db()
+        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
+        assert dr0.state == state
+
+    @pytest.mark.parametrize("state", [DagRunState.SUCCESS, DagRunState.FAILED])
+    def test_clear_task_instances_for_backfill_finished_dagrun(self, state, session):
         now = timezone.utcnow()
         dag_id = "test_clear_task_instances_for_backfill_dagrun"
         dag = DAG(dag_id=dag_id, start_date=now)
-        dag_run = self.create_dag_run(dag, execution_date=now, is_backfill=True, session=session)
+        dag_run = self.create_dag_run(dag, execution_date=now, is_backfill=True, state=state, session=session)
 
         task0 = EmptyOperator(task_id="backfill_task_0", owner="test", dag=dag)
         ti0 = TI(task=task0, run_id=dag_run.run_id)

Reply via email to