This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 9a7d48c287df5332a3afe6d4df1123516b77b624
Author: Aleksey Kirilishin <[email protected]>
AuthorDate: Fri Feb 9 15:53:04 2024 +0300

    Fix the bug that affected the DAG end date. (#36144)
    
    (cherry picked from commit 9f4f208b5da38bc2e82db682c636ec4fcf7ad617)
---
 airflow/api/common/mark_tasks.py                   |  5 --
 airflow/models/dagrun.py                           | 66 +++++++++++++++++++++-
 .../endpoints/test_dag_run_endpoint.py             |  4 +-
 tests/api_experimental/client/test_local_client.py |  4 +-
 tests/api_experimental/common/test_mark_tasks.py   | 49 +++++++++++-----
 tests/models/test_cleartasks.py                    |  5 +-
 6 files changed, 107 insertions(+), 26 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 3cc6dfdfd7..a175a61e20 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -366,11 +366,6 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
         select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
     ).scalar_one()
     dag_run.state = state
-    if state == DagRunState.RUNNING:
-        dag_run.start_date = timezone.utcnow()
-        dag_run.end_date = None
-    else:
-        dag_run.end_date = timezone.utcnow()
     session.merge(dag_run)
 
 
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index c7f48e1692..59cd7e58b7 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -272,11 +272,75 @@ class DagRun(Base, LoggingMixin):
         return self._state
 
     def set_state(self, state: DagRunState) -> None:
+        """Change the state of the DagRun.
+
+        Changes to attributes are implemented in accordance with the following table
+        (rows represent old states, columns represent new states):
+
+        .. list-table:: State transition matrix
+           :header-rows: 1
+           :stub-columns: 1
+
+           * -
+             - QUEUED
+             - RUNNING
+             - SUCCESS
+             - FAILED
+           * - None
+             - queued_at = timezone.utcnow()
+             - if empty: start_date = timezone.utcnow()
+               end_date = None
+             - end_date = timezone.utcnow()
+             - end_date = timezone.utcnow()
+           * - QUEUED
+             - queued_at = timezone.utcnow()
+             - if empty: start_date = timezone.utcnow()
+               end_date = None
+             - end_date = timezone.utcnow()
+             - end_date = timezone.utcnow()
+           * - RUNNING
+             - queued_at = timezone.utcnow()
+               start_date = None
+               end_date = None
+             -
+             - end_date = timezone.utcnow()
+             - end_date = timezone.utcnow()
+           * - SUCCESS
+             - queued_at = timezone.utcnow()
+               start_date = None
+               end_date = None
+             - start_date = timezone.utcnow()
+               end_date = None
+             -
+             -
+           * - FAILED
+             - queued_at = timezone.utcnow()
+               start_date = None
+               end_date = None
+             - start_date = timezone.utcnow()
+               end_date = None
+             -
+             -
+
+        """
         if state not in State.dag_states:
             raise ValueError(f"invalid DagRun state: {state}")
         if self._state != state:
+            if state == DagRunState.QUEUED:
+                self.queued_at = timezone.utcnow()
+                self.start_date = None
+                self.end_date = None
+            if state == DagRunState.RUNNING:
+                if self._state in State.finished_dr_states:
+                    self.start_date = timezone.utcnow()
+                else:
+                    self.start_date = self.start_date or timezone.utcnow()
+                self.end_date = None
+            if self._state in State.unfinished_dr_states or self._state is None:
+                if state in State.finished_dr_states:
+                    self.end_date = timezone.utcnow()
             self._state = state
-            self.end_date = timezone.utcnow() if self._state in State.finished_dr_states else None
+        else:
             if state == DagRunState.QUEUED:
                 self.queued_at = timezone.utcnow()
 
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 2c4c393dd3..face8f7d75 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -1442,11 +1442,11 @@ class TestPatchDagRunState(TestDagRunEndpoint):
             "conf": {},
             "dag_id": dag_id,
             "dag_run_id": dag_run_id,
-            "end_date": dr.end_date.isoformat(),
+            "end_date": dr.end_date.isoformat() if state != State.QUEUED else None,
             "execution_date": dr.execution_date.isoformat(),
             "external_trigger": False,
             "logical_date": dr.execution_date.isoformat(),
-            "start_date": dr.start_date.isoformat(),
+            "start_date": dr.start_date.isoformat() if state != State.QUEUED else None,
             "state": state,
             "data_interval_start": dr.data_interval_start.isoformat(),
             "data_interval_end": dr.data_interval_end.isoformat(),
diff --git a/tests/api_experimental/client/test_local_client.py b/tests/api_experimental/client/test_local_client.py
index b02a5a5c42..91a81a0caf 100644
--- a/tests/api_experimental/client/test_local_client.py
+++ b/tests/api_experimental/client/test_local_client.py
@@ -135,13 +135,11 @@ class TestLocalClient:
 
             # test output
             queued_at = pendulum.now()
-            started_at = pendulum.now()
             mock.return_value = DagRun(
                 dag_id=test_dag_id,
                 run_id=run_id,
                 queued_at=queued_at,
                 execution_date=EXECDATE,
-                start_date=started_at,
                 external_trigger=True,
                 state=DagRunState.QUEUED,
                 conf={},
@@ -159,7 +157,7 @@ class TestLocalClient:
                 "last_scheduling_decision": None,
                 "logical_date": EXECDATE,
                 "run_type": DagRunType.MANUAL,
-                "start_date": started_at,
+                "start_date": None,
                 "state": DagRunState.QUEUED,
             }
             dag_run = self.client.trigger_dag(dag_id=test_dag_id)
diff --git a/tests/api_experimental/common/test_mark_tasks.py b/tests/api_experimental/common/test_mark_tasks.py
index 47c10fa185..9b28136bba 100644
--- a/tests/api_experimental/common/test_mark_tasks.py
+++ b/tests/api_experimental/common/test_mark_tasks.py
@@ -555,20 +555,28 @@ class TestMarkDAGRun:
         assert dr.get_state() == state
 
     @provide_session
-    def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None):
+    def _verify_dag_run_dates(self, dag, date, state, middle_time=None, old_end_date=None, session=None):
         # When target state is RUNNING, we should set start_date,
         # otherwise we should set end_date.
         DR = DagRun
         dr = session.query(DR).filter(DR.dag_id == dag.dag_id, DR.execution_date == date).one()
         if state == State.RUNNING:
             # Since the DAG is running, the start_date must be updated after creation
-            assert dr.start_date > middle_time
+            if middle_time:
+                assert dr.start_date > middle_time
             # If the dag is still running, we don't have an end date
             assert dr.end_date is None
         else:
-            # If the dag is not running, there must be an end time
-            assert dr.start_date < middle_time
-            assert dr.end_date > middle_time
+            # If the dag is not running, there must be an end time,
+            # and the end time must not be changed if it has already been set.
+            if dr.start_date and middle_time:
+                assert dr.start_date < middle_time
+            if dr.end_date:
+                if old_end_date:
+                    assert dr.end_date == old_end_date
+                else:
+                    if middle_time:
+                        assert dr.end_date > middle_time
 
     def test_set_running_dag_run_to_success(self):
         date = self.execution_dates[0]
@@ -599,30 +607,42 @@ class TestMarkDAGRun:
         assert dr.get_task_instance("run_after_loop").state == State.FAILED
         self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)
 
-    @pytest.mark.parametrize(
-        "dag_run_alter_function, new_state",
-        [(set_dag_run_state_to_running, State.RUNNING), (set_dag_run_state_to_queued, State.QUEUED)],
-    )
-    def test_set_running_dag_run_to_activate_state(self, dag_run_alter_function: Callable, new_state: State):
+    def test_set_running_dag_run_to_running_state(self):
+        date = self.execution_dates[0]  # type: ignore
+        dr = self._create_test_dag_run(State.RUNNING, date)
+        self._set_default_task_instance_states(dr)
+
+        altered = set_dag_run_state_to_running(dag=self.dag1, run_id=dr.run_id, commit=True)  # type: ignore
+
+        # None of the tasks should be altered, only the dag itself
+        assert len(altered) == 0
+        new_state = State.RUNNING
+        self._verify_dag_run_state(self.dag1, date, new_state)  # type: ignore
+        self._verify_task_instance_states_remain_default(dr)
+        self._verify_dag_run_dates(self.dag1, date, new_state)  # type: ignore
+
+    def test_set_running_dag_run_to_queued_state(self):
         date = self.execution_dates[0]  # type: ignore
         dr = self._create_test_dag_run(State.RUNNING, date)
         middle_time = timezone.utcnow()
         self._set_default_task_instance_states(dr)
 
-        altered = dag_run_alter_function(dag=self.dag1, run_id=dr.run_id, commit=True)  # type: ignore
+        altered = set_dag_run_state_to_queued(dag=self.dag1, run_id=dr.run_id, commit=True)  # type: ignore
 
         # None of the tasks should be altered, only the dag itself
         assert len(altered) == 0
+        new_state = State.QUEUED
         self._verify_dag_run_state(self.dag1, date, new_state)  # type: ignore
         self._verify_task_instance_states_remain_default(dr)
         self._verify_dag_run_dates(self.dag1, date, new_state, middle_time)  # type: ignore
 
     @pytest.mark.parametrize("completed_state", [State.SUCCESS, State.FAILED])
-    def test_set_success_dag_run_to_success(self, completed_state):
+    def test_set_completed_dag_run_to_success(self, completed_state):
         date = self.execution_dates[0]
         dr = self._create_test_dag_run(completed_state, date)
         middle_time = timezone.utcnow()
         self._set_default_task_instance_states(dr)
+        old_end_date = dr.end_date
 
         altered = set_dag_run_state_to_success(dag=self.dag1, run_id=dr.run_id, commit=True)
 
@@ -631,13 +651,14 @@ class TestMarkDAGRun:
         assert len(altered) == expected
         self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
         self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
-        self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)
+        self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time, old_end_date)
 
     @pytest.mark.parametrize("completed_state", [State.SUCCESS, State.FAILED])
     def test_set_completed_dag_run_to_failed(self, completed_state):
         date = self.execution_dates[0]
         dr = self._create_test_dag_run(completed_state, date)
         middle_time = timezone.utcnow()
+        old_end_date = dr.end_date
         self._set_default_task_instance_states(dr)
 
         altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True)
@@ -646,7 +667,7 @@ class TestMarkDAGRun:
         assert len(altered) == expected
         self._verify_dag_run_state(self.dag1, date, State.FAILED)
         assert dr.get_task_instance("run_after_loop").state == State.FAILED
-        self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)
+        self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time, old_end_date)
 
     @pytest.mark.parametrize(
         "dag_run_alter_function,new_state",
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index ed0232926a..bce9dc4668 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -210,7 +210,10 @@ class TestClearTasks:
         session.refresh(dr)
 
         assert dr.state == state
-        assert dr.start_date
+        if state == DagRunState.QUEUED:
+            assert dr.start_date is None
+        if state == DagRunState.RUNNING:
+            assert dr.start_date
         assert dr.last_scheduling_decision == DEFAULT_DATE
 
     @pytest.mark.parametrize(

Reply via email to