This is an automated email from the ASF dual-hosted git repository.
utkarsharma pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new 72eef0f85e Mark all tasks as skipped when failing a dag_run manually
including tasks with None state (#43572)
72eef0f85e is described below
commit 72eef0f85ef22e82c673c16914229c8838216f7f
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Fri Nov 1 13:12:19 2024 +0530
Mark all tasks as skipped when failing a dag_run manually including tasks
with None state (#43572)
* Mark all tasks as skipped when failing a dag_run manually including tasks
with None state (#43482)
(cherry picked from commit eda6a8fcf009eb224ec556f7117a97965dbd4dd5)
* Fix tests for 2.10.x
---------
Co-authored-by: Abhishek <[email protected]>
---
airflow/api/common/mark_tasks.py | 11 ++++--
tests/www/views/test_views_dagrun.py | 74 ++++++++++++++++++++++++++++++++++++
2 files changed, 82 insertions(+), 3 deletions(-)
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index fa6ce835a9..58ca737a57 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -21,7 +21,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
-from sqlalchemy import or_, select
+from sqlalchemy import and_, or_, select
from sqlalchemy.orm import lazyload
from airflow.models.dagrun import DagRun
@@ -500,8 +500,13 @@ def set_dag_run_state_to_failed(
select(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
- TaskInstance.state.not_in(State.finished),
- TaskInstance.state.not_in(running_states),
+ or_(
+ TaskInstance.state.is_(None),
+ and_(
+ TaskInstance.state.not_in(State.finished),
+ TaskInstance.state.not_in(running_states),
+ ),
+ ),
)
).all()
diff --git a/tests/www/views/test_views_dagrun.py
b/tests/www/views/test_views_dagrun.py
index b7e048e0ea..7d9f9d73ab 100644
--- a/tests/www/views/test_views_dagrun.py
+++ b/tests/www/views/test_views_dagrun.py
@@ -290,3 +290,77 @@ def test_dag_runs_queue_new_tasks_action(session,
admin_client, completed_dag_ru
check_content_in_response("runme_2", resp)
check_content_not_in_response("runme_1", resp)
assert resp.status_code == 200
+
+
+@pytest.fixture
+def dag_run_with_all_done_task(session):
+ """Creates a DAG run for example_bash_decorator with tasks in various
states and an ALL_DONE task not yet run."""
+ dag = DagBag().get_dag("example_bash_decorator")
+
+ # Re-sync the DAG to the DB
+ dag.sync_to_db()
+
+ execution_date = timezone.datetime(2016, 1, 9)
+ dr = dag.create_dagrun(
+ state="running",
+ execution_date=execution_date,
+ data_interval=(execution_date, execution_date),
+ run_id="test_dagrun_failed",
+ session=session,
+ )
+
+ # Create task instances in various states to test the ALL_DONE trigger rule
+ tis = [
+ # runme_loop tasks
+ TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id,
state="success"),
+ TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id,
state="failed"),
+ TaskInstance(dag.get_task("runme_2"), run_id=dr.run_id,
state="running"),
+ # Other tasks before run_this_last
+ TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id,
state="success"),
+ TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id,
state="success"),
+ TaskInstance(dag.get_task("also_run_this_again"), run_id=dr.run_id,
state="skipped"),
+ TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id,
state="running"),
+ # The task with trigger_rule=ALL_DONE
+ TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id,
state=None),
+ ]
+ session.bulk_save_objects(tis)
+ session.commit()
+
+ return dag, dr
+
+
+def test_dagrun_failed(session, admin_client, dag_run_with_all_done_task):
+ """Test marking a dag run as failed with a task having
trigger_rule='all_done'"""
+ dag, dr = dag_run_with_all_done_task
+
+ # Verify task instances were created
+ task_instances = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id ==
dr.run_id)
+ .all()
+ )
+ assert len(task_instances) > 0
+
+ resp = admin_client.post(
+ "/dagrun_failed",
+ data={"dag_id": dr.dag_id, "dag_run_id": dr.run_id, "confirmed":
"true"},
+ follow_redirects=True,
+ )
+
+ assert resp.status_code == 200
+
+ with create_session() as session:
+ updated_dr = (
+ session.query(DagRun).filter(DagRun.dag_id == dr.dag_id,
DagRun.run_id == dr.run_id).first()
+ )
+ assert updated_dr.state == "failed"
+
+ task_instances = (
+ session.query(TaskInstance)
+ .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id ==
dr.run_id)
+ .all()
+ )
+
+ done_states = {"success", "failed", "skipped", "upstream_failed"}
+ for ti in task_instances:
+ assert ti.state in done_states