This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new eda6a8fcf0 Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43482)
eda6a8fcf0 is described below

commit eda6a8fcf009eb224ec556f7117a97965dbd4dd5
Author: Abhishek <[email protected]>
AuthorDate: Tue Oct 29 23:16:22 2024 +0530

    Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43482)
---
 airflow/api/common/mark_tasks.py     | 11 ++++--
 tests/www/views/test_views_dagrun.py | 76 ++++++++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 3 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index a828d140c9..957e82e7de 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
 
-from sqlalchemy import or_, select
+from sqlalchemy import and_, or_, select
 from sqlalchemy.orm import lazyload
 
 from airflow.models.dagrun import DagRun
@@ -402,8 +402,13 @@ def set_dag_run_state_to_failed(
         select(TaskInstance).filter(
             TaskInstance.dag_id == dag.dag_id,
             TaskInstance.run_id == run_id,
-            TaskInstance.state.not_in(State.finished),
-            TaskInstance.state.not_in(running_states),
+            or_(
+                TaskInstance.state.is_(None),
+                and_(
+                    TaskInstance.state.not_in(State.finished),
+                    TaskInstance.state.not_in(running_states),
+                ),
+            ),
         )
     ).all()
 
diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py
index df55cc952e..b8e5d84351 100644
--- a/tests/www/views/test_views_dagrun.py
+++ b/tests/www/views/test_views_dagrun.py
@@ -307,3 +307,79 @@ def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_ru
     check_content_in_response("runme_2", resp)
     check_content_not_in_response("runme_1", resp)
     assert resp.status_code == 200
+
+
+@pytest.fixture
+def dag_run_with_all_done_task(session):
+    """Creates a DAG run for example_bash_decorator with tasks in various states and an ALL_DONE task not yet run."""
+    dag = DagBag().get_dag("example_bash_decorator")
+
+    # Re-sync the DAG to the DB
+    dag.sync_to_db()
+
+    execution_date = timezone.datetime(2016, 1, 9)
+    triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
+    dr = dag.create_dagrun(
+        state="running",
+        execution_date=execution_date,
+        data_interval=(execution_date, execution_date),
+        run_id="test_dagrun_failed",
+        session=session,
+        **triggered_by_kwargs,
+    )
+
+    # Create task instances in various states to test the ALL_DONE trigger rule
+    tis = [
+        # runme_loop tasks
+        TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"),
+        TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"),
+        TaskInstance(dag.get_task("runme_2"), run_id=dr.run_id, state="running"),
+        # Other tasks before run_this_last
+        TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id, state="success"),
+        TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id, state="success"),
+        TaskInstance(dag.get_task("also_run_this_again"), run_id=dr.run_id, state="skipped"),
+        TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id, state="running"),
+        # The task with trigger_rule=ALL_DONE
+        TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id, state=None),
+    ]
+    session.bulk_save_objects(tis)
+    session.commit()
+
+    return dag, dr
+
+
+def test_dagrun_failed(session, admin_client, dag_run_with_all_done_task):
+    """Test marking a dag run as failed with a task having trigger_rule='all_done'"""
+    dag, dr = dag_run_with_all_done_task
+
+    # Verify task instances were created
+    task_instances = (
+        session.query(TaskInstance)
+        .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
+        .all()
+    )
+    assert len(task_instances) > 0
+
+    resp = admin_client.post(
+        "/dagrun_failed",
+        data={"dag_id": dr.dag_id, "dag_run_id": dr.run_id, "confirmed": "true"},
+        follow_redirects=True,
+    )
+
+    assert resp.status_code == 200
+
+    with create_session() as session:
+        updated_dr = (
+            session.query(DagRun).filter(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id).first()
+        )
+        assert updated_dr.state == "failed"
+
+        task_instances = (
+            session.query(TaskInstance)
+            .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
+            .all()
+        )
+
+        done_states = {"success", "failed", "skipped", "upstream_failed"}
+        for ti in task_instances:
+            assert ti.state in done_states

Reply via email to