This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-6-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 403cf86b561b58d65cb0bff7d5884434670c5794 Author: Sean Rose <[email protected]> AuthorDate: Fri Apr 14 10:17:38 2023 -0700 When clearing task instances try to get associated DAGs from database (#29065) * When clearing task instances try to get associated DAGs from database. This fixes problems when recursively clearing task instances across multiple DAGs: * Task instances in downstream DAGs weren't having their `max_tries` property incremented, which could cause downstream external task sensors in reschedule mode to instantly time out (issue #29049). * Task instances in downstream DAGs could have some of their properties overridden by an unrelated task in the upstream DAG if they had the same task ID. * Use session fixture for new `test_clear_task_instances_without_dag_param` test. * Use session fixture for new `test_clear_task_instances_in_multiple_dags` test. --------- Co-authored-by: eladkal <[email protected]> (cherry picked from commit 0d2e6dce709acebdb46288faef17d322196f29a2) --- airflow/models/taskinstance.py | 11 +++-- tests/models/test_cleartasks.py | 97 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b02076c076..d4e37741c3 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -87,6 +87,7 @@ from airflow.exceptions import ( ) from airflow.listeners.listener import get_listener_manager from airflow.models.base import Base, StringID +from airflow.models.dagbag import DagBag from airflow.models.log import Log from airflow.models.mappedoperator import MappedOperator from airflow.models.param import process_params @@ -203,6 +204,7 @@ def clear_task_instances( task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set))) ) + dag_bag = DagBag(read_dags_from_db=True) for ti in tis: if ti.state == TaskInstanceState.RUNNING: if ti.job_id: @@ -211,15 +213,16 @@ def clear_task_instances( ti.state = TaskInstanceState.RESTARTING job_ids.append(ti.job_id) else: + ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session) task_id = ti.task_id - if dag and dag.has_task(task_id): - task = dag.get_task(task_id) + if ti_dag and ti_dag.has_task(task_id): + task = ti_dag.get_task(task_id) ti.refresh_from_task(task) task_retries = task.retries ti.max_tries = ti.try_number + task_retries - 1 else: - # Ignore errors when updating max_tries if dag is None or - # task not found in dag since database records could be + # Ignore errors when updating max_tries if the DAG or + # task are not found since database records could be # outdated. We make max_tries the maximum value of its # original max_tries or the last attempted try number. ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries) diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index ebec03d94e..f0ef8002c1 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -23,6 +23,7 @@ import pytest from airflow import settings from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances +from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator from airflow.sensors.python import PythonSensor from airflow.utils.session import create_session @@ -202,9 +203,9 @@ class TestClearTasks: # but it works for our case because we specifically constructed test DAGS # in the way that those two sort methods are equivalent qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all() - clear_task_instances(qry, session) + clear_task_instances(qry, session, dag=dag) - # When dag is None, max_tries will be maximum of original max_tries or try_number. + # When no task is found, max_tries will be maximum of original max_tries or try_number. ti0.refresh_from_db() ti1.refresh_from_db() # Next try to run will be try 2 @@ -214,6 +215,7 @@ class TestClearTasks: assert ti1.max_tries == 2 def test_clear_task_instances_without_dag(self, dag_maker): + # Don't write DAG to the database, so no DAG is found by clear_task_instances(). with dag_maker( "test_clear_task_instances_without_dag", start_date=DEFAULT_DATE, @@ -242,7 +244,7 @@ class TestClearTasks: qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all() clear_task_instances(qry, session) - # When dag is None, max_tries will be maximum of original max_tries or try_number. + # When no DAG is found, max_tries will be maximum of original max_tries or try_number. ti0.refresh_from_db() ti1.refresh_from_db() # Next try to run will be try 2 @@ -251,6 +253,95 @@ class TestClearTasks: assert ti1.try_number == 2 assert ti1.max_tries == 2 + def test_clear_task_instances_without_dag_param(self, dag_maker, session): + with dag_maker( + "test_clear_task_instances_without_dag_param", + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + session=session, + ) as dag: + task0 = EmptyOperator(task_id="task0") + task1 = EmptyOperator(task_id="task1", retries=2) + + # Write DAG to the database so it can be found by clear_task_instances(). + SerializedDagModel.write_dag(dag, session=session) + + dr = dag_maker.create_dagrun( + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + + ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id) + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) + + ti0.run(session=session) + ti1.run(session=session) + + # we use order_by(task_id) here because for the test DAG structure of ours + # this is equivalent to topological sort. It would not work in general case + # but it works for our case because we specifically constructed test DAGS + # in the way that those two sort methods are equivalent + qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all() + clear_task_instances(qry, session) + + ti0.refresh_from_db(session=session) + ti1.refresh_from_db(session=session) + # Next try to run will be try 2 + assert ti0.try_number == 2 + assert ti0.max_tries == 1 + assert ti1.try_number == 2 + assert ti1.max_tries == 3 + + def test_clear_task_instances_in_multiple_dags(self, dag_maker, session): + with dag_maker( + "test_clear_task_instances_in_multiple_dags0", + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + session=session, + ) as dag0: + task0 = EmptyOperator(task_id="task0") + + dr0 = dag_maker.create_dagrun( + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + + with dag_maker( + "test_clear_task_instances_in_multiple_dags1", + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + session=session, + ) as dag1: + task1 = EmptyOperator(task_id="task1", retries=2) + + # Write secondary DAG to the database so it can be found by clear_task_instances(). + SerializedDagModel.write_dag(dag1, session=session) + + dr1 = dag_maker.create_dagrun( + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + + ti0 = dr0.task_instances[0] + ti1 = dr1.task_instances[0] + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) + + ti0.run(session=session) + ti1.run(session=session) + + qry = session.query(TI).filter(TI.dag_id.in_((dag0.dag_id, dag1.dag_id))).all() + clear_task_instances(qry, session, dag=dag0) + + ti0.refresh_from_db(session=session) + ti1.refresh_from_db(session=session) + # Next try to run will be try 2 + assert ti0.try_number == 2 + assert ti0.max_tries == 1 + assert ti1.try_number == 2 + assert ti1.max_tries == 3 + def test_clear_task_instances_with_task_reschedule(self, dag_maker): """Test that TaskReschedules are deleted correctly when TaskInstances are cleared"""
