This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch sync_2-10-test-rc2
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b8ab1cc043a7a04c9422f83effe19bfc7b7cd8f5
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Fri Nov 1 13:12:19 2024 +0530

    Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43572)
    
    * Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43482)
    
    (cherry picked from commit eda6a8fcf009eb224ec556f7117a97965dbd4dd5)
    
    * Fix tests for 2.10.x
    
    ---------
    
    Co-authored-by: Abhishek <[email protected]>
    (cherry picked from commit 72eef0f85ef22e82c673c16914229c8838216f7f)
---
 airflow/api/common/mark_tasks.py     | 11 ++++--
 tests/www/views/test_views_dagrun.py | 74 ++++++++++++++++++++++++++++++++++++
 2 files changed, 82 insertions(+), 3 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index fa6ce835a9..58ca737a57 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
 
-from sqlalchemy import or_, select
+from sqlalchemy import and_, or_, select
 from sqlalchemy.orm import lazyload
 
 from airflow.models.dagrun import DagRun
@@ -500,8 +500,13 @@ def set_dag_run_state_to_failed(
         select(TaskInstance).filter(
             TaskInstance.dag_id == dag.dag_id,
             TaskInstance.run_id == run_id,
-            TaskInstance.state.not_in(State.finished),
-            TaskInstance.state.not_in(running_states),
+            or_(
+                TaskInstance.state.is_(None),
+                and_(
+                    TaskInstance.state.not_in(State.finished),
+                    TaskInstance.state.not_in(running_states),
+                ),
+            ),
         )
     ).all()
 
diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py
index b7e048e0ea..7d9f9d73ab 100644
--- a/tests/www/views/test_views_dagrun.py
+++ b/tests/www/views/test_views_dagrun.py
@@ -290,3 +290,77 @@ def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_ru
     check_content_in_response("runme_2", resp)
     check_content_not_in_response("runme_1", resp)
     assert resp.status_code == 200
+
+
+@pytest.fixture
+def dag_run_with_all_done_task(session):
+    """Creates a DAG run for example_bash_decorator with tasks in various states and an ALL_DONE task not yet run."""
+    dag = DagBag().get_dag("example_bash_decorator")
+
+    # Re-sync the DAG to the DB
+    dag.sync_to_db()
+
+    execution_date = timezone.datetime(2016, 1, 9)
+    dr = dag.create_dagrun(
+        state="running",
+        execution_date=execution_date,
+        data_interval=(execution_date, execution_date),
+        run_id="test_dagrun_failed",
+        session=session,
+    )
+
+    # Create task instances in various states to test the ALL_DONE trigger rule
+    tis = [
+        # runme_loop tasks
+        TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"),
+        TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"),
+        TaskInstance(dag.get_task("runme_2"), run_id=dr.run_id, state="running"),
+        # Other tasks before run_this_last
+        TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id, state="success"),
+        TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id, state="success"),
+        TaskInstance(dag.get_task("also_run_this_again"), run_id=dr.run_id, state="skipped"),
+        TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id, state="running"),
+        # The task with trigger_rule=ALL_DONE
+        TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id, state=None),
+    ]
+    session.bulk_save_objects(tis)
+    session.commit()
+
+    return dag, dr
+
+
+def test_dagrun_failed(session, admin_client, dag_run_with_all_done_task):
+    """Test marking a dag run as failed with a task having trigger_rule='all_done'"""
+    dag, dr = dag_run_with_all_done_task
+
+    # Verify task instances were created
+    task_instances = (
+        session.query(TaskInstance)
+        .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
+        .all()
+    )
+    assert len(task_instances) > 0
+
+    resp = admin_client.post(
+        "/dagrun_failed",
+        data={"dag_id": dr.dag_id, "dag_run_id": dr.run_id, "confirmed": "true"},
+        follow_redirects=True,
+    )
+
+    assert resp.status_code == 200
+
+    with create_session() as session:
+        updated_dr = (
+            session.query(DagRun).filter(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id).first()
+        )
+        assert updated_dr.state == "failed"
+
+        task_instances = (
+            session.query(TaskInstance)
+            .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
+            .all()
+        )
+
+        done_states = {"success", "failed", "skipped", "upstream_failed"}
+        for ti in task_instances:
+            assert ti.state in done_states

Reply via email to