This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-0-test by this push:
     new 6008c06ae73 [v3-0-test] Mark DagRun as success when no teardown tasks are running (#49752) (#50019)
6008c06ae73 is described below

commit 6008c06ae731006a4d8ebf34a5a112dabbe31748
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Apr 30 15:39:44 2025 +0200

    [v3-0-test] Mark DagRun as success when no teardown tasks are running (#49752) (#50019)
    
    * Mark DagRun as success when no teardown tasks are running
    
    * Change to unfinished teardown tasks, modified unit tests
    
    * Fix in unit tests
    
    * Fix pytest.xdist issues
    
    * Only select task_ids and convert to set, update unit tests
    (cherry picked from commit b552bca55dd9af9125f5046e515ab76b4fd9b940)
    
    Co-authored-by: Renze Post <[email protected]>
---
 airflow-core/src/airflow/api/common/mark_tasks.py  | 24 ++++++++---
 .../tests/unit/api/common/test_mark_tasks.py       | 46 +++++++++++++++++++++-
 2 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index c957a5cd53a..ad8c7fe4928 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -215,17 +215,31 @@ def set_dag_run_state_to_success(
     if not run_id:
         raise ValueError(f"Invalid dag_run_id: {run_id}")
 
-    # Mark all task instances of the dag run to success - except for teardown 
as they need to complete work.
+    # Mark all task instances of the dag run to success - except for 
unfinished teardown as they need to complete work.
     normal_tasks = [task for task in dag.tasks if not task.is_teardown]
+    teardown_tasks = [task for task in dag.tasks if task.is_teardown]
+    unfinished_teardown_task_ids = set(
+        session.scalars(
+            select(TaskInstance.task_id).where(
+                TaskInstance.dag_id == dag.dag_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.task_id.in_([task.task_id for task in 
teardown_tasks]),
+                or_(TaskInstance.state.is_(None), 
TaskInstance.state.in_(State.unfinished)),
+            )
+        ).all()
+    )
 
-    # Mark the dag run to success.
-    if commit and len(normal_tasks) == len(dag.tasks):
+    # Mark the dag run to success if there are no unfinished teardown tasks.
+    if commit and len(unfinished_teardown_task_ids) == 0:
         _set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)
 
-    for task in normal_tasks:
+    tasks_to_mark_success = normal_tasks + [
+        task for task in teardown_tasks if task.task_id not in 
unfinished_teardown_task_ids
+    ]
+    for task in tasks_to_mark_success:
         task.dag = dag
     return set_state(
-        tasks=normal_tasks,
+        tasks=tasks_to_mark_success,
         run_id=run_id,
         state=TaskInstanceState.SUCCESS,
         commit=commit,
diff --git a/airflow-core/tests/unit/api/common/test_mark_tasks.py b/airflow-core/tests/unit/api/common/test_mark_tasks.py
index 59fca1f4caa..35dfe7b5879 100644
--- a/airflow-core/tests/unit/api/common/test_mark_tasks.py
+++ b/airflow-core/tests/unit/api/common/test_mark_tasks.py
@@ -19,10 +19,12 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import pytest
+from sqlalchemy import select
 
 from airflow.api.common.mark_tasks import set_dag_run_state_to_failed, 
set_dag_run_state_to_success
+from airflow.models.dagrun import DagRun
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.utils.state import TaskInstanceState
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
@@ -54,23 +56,63 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker):
     assert "teardown" not in task_dict
 
 
-def test_set_dag_run_state_to_success(dag_maker: DagMaker):
+@pytest.mark.parametrize(
+    "unfinished_state", sorted([state for state in State.unfinished if state 
is not None])
+)
+def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, 
unfinished_state):
     with dag_maker("TEST_DAG_1"):
         with EmptyOperator(task_id="teardown").as_teardown():
             EmptyOperator(task_id="running")
             EmptyOperator(task_id="pending")
+
     dr = dag_maker.create_dagrun()
     for ti in dr.get_task_instances():
         if ti.task_id == "running":
             ti.set_state(TaskInstanceState.RUNNING)
+        if ti.task_id == "teardown":
+            ti.set_state(unfinished_state)
+
     dag_maker.session.flush()
     assert dr.dag
+    assert dr.state == DagRunState.RUNNING
 
     updated_tis: list[TaskInstance] = set_dag_run_state_to_success(
         dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session
     )
+    run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, 
run_id=dr.run_id))
+    assert run.state != DagRunState.SUCCESS
     assert len(updated_tis) == 2
     task_dict = {ti.task_id: ti for ti in updated_tis}
     assert task_dict["running"].state == TaskInstanceState.SUCCESS
     assert task_dict["pending"].state == TaskInstanceState.SUCCESS
     assert "teardown" not in task_dict
+
+
+@pytest.mark.parametrize("finished_state", sorted(list(State.finished)))
+def test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker, 
finished_state):
+    with dag_maker("TEST_DAG_1"):
+        with EmptyOperator(task_id="teardown").as_teardown():
+            EmptyOperator(task_id="failed")
+    dr = dag_maker.create_dagrun()
+    for ti in dr.get_task_instances():
+        if ti.task_id == "failed":
+            ti.set_state(TaskInstanceState.FAILED)
+        if ti.task_id == "teardown":
+            ti.set_state(finished_state)
+    dag_maker.session.flush()
+    dr.set_state(DagRunState.FAILED)
+    assert dr.dag
+
+    updated_tis: list[TaskInstance] = set_dag_run_state_to_success(
+        dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session
+    )
+    run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, 
run_id=dr.run_id))
+    assert run.state == DagRunState.SUCCESS
+    if finished_state == TaskInstanceState.SUCCESS:
+        assert len(updated_tis) == 1
+    else:
+        assert len(updated_tis) == 2
+    task_dict = {ti.task_id: ti for ti in updated_tis}
+    assert task_dict["failed"].state == TaskInstanceState.SUCCESS
+    if finished_state != TaskInstanceState.SUCCESS:
+        assert task_dict["teardown"].state == TaskInstanceState.SUCCESS

Reply via email to