This is an automated email from the ASF dual-hosted git repository. utkarsharma pushed a commit to branch sync_2-10-test-rc2 in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b8ab1cc043a7a04c9422f83effe19bfc7b7cd8f5 Author: Utkarsh Sharma <[email protected]> AuthorDate: Fri Nov 1 13:12:19 2024 +0530 Mark all tasks as skipped when failing a dag_run manually including t… (#43572) * Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43482) (cherry picked from commit eda6a8fcf009eb224ec556f7117a97965dbd4dd5) * Fix tests for 2.10.x --------- Co-authored-by: Abhishek <[email protected]> (cherry picked from commit 72eef0f85ef22e82c673c16914229c8838216f7f) --- airflow/api/common/mark_tasks.py | 11 ++++-- tests/www/views/test_views_dagrun.py | 74 ++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index fa6ce835a9..58ca737a57 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -21,7 +21,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple -from sqlalchemy import or_, select +from sqlalchemy import and_, or_, select from sqlalchemy.orm import lazyload from airflow.models.dagrun import DagRun @@ -500,8 +500,13 @@ def set_dag_run_state_to_failed( select(TaskInstance).filter( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id == run_id, - TaskInstance.state.not_in(State.finished), - TaskInstance.state.not_in(running_states), + or_( + TaskInstance.state.is_(None), + and_( + TaskInstance.state.not_in(State.finished), + TaskInstance.state.not_in(running_states), + ), + ), ) ).all() diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index b7e048e0ea..7d9f9d73ab 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -290,3 +290,77 @@ def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_ru check_content_in_response("runme_2", resp) check_content_not_in_response("runme_1", resp) assert 
resp.status_code == 200 + + +@pytest.fixture +def dag_run_with_all_done_task(session): + """Creates a DAG run for example_bash_decorator with tasks in various states and an ALL_DONE task not yet run.""" + dag = DagBag().get_dag("example_bash_decorator") + + # Re-sync the DAG to the DB + dag.sync_to_db() + + execution_date = timezone.datetime(2016, 1, 9) + dr = dag.create_dagrun( + state="running", + execution_date=execution_date, + data_interval=(execution_date, execution_date), + run_id="test_dagrun_failed", + session=session, + ) + + # Create task instances in various states to test the ALL_DONE trigger rule + tis = [ + # runme_loop tasks + TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"), + TaskInstance(dag.get_task("runme_2"), run_id=dr.run_id, state="running"), + # Other tasks before run_this_last + TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("also_run_this_again"), run_id=dr.run_id, state="skipped"), + TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id, state="running"), + # The task with trigger_rule=ALL_DONE + TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id, state=None), + ] + session.bulk_save_objects(tis) + session.commit() + + return dag, dr + + +def test_dagrun_failed(session, admin_client, dag_run_with_all_done_task): + """Test marking a dag run as failed with a task having trigger_rule='all_done'""" + dag, dr = dag_run_with_all_done_task + + # Verify task instances were created + task_instances = ( + session.query(TaskInstance) + .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id) + .all() + ) + assert len(task_instances) > 0 + + resp = admin_client.post( + "/dagrun_failed", + data={"dag_id": dr.dag_id, "dag_run_id": dr.run_id, "confirmed": "true"}, + 
follow_redirects=True, + ) + + assert resp.status_code == 200 + + with create_session() as session: + updated_dr = ( + session.query(DagRun).filter(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id).first() + ) + assert updated_dr.state == "failed" + + task_instances = ( + session.query(TaskInstance) + .filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id) + .all() + ) + + done_states = {"success", "failed", "skipped", "upstream_failed"} + for ti in task_instances: + assert ti.state in done_states
