This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-6-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 403cf86b561b58d65cb0bff7d5884434670c5794
Author: Sean Rose <[email protected]>
AuthorDate: Fri Apr 14 10:17:38 2023 -0700

    When clearing task instances try to get associated DAGs from database (#29065)
    
    * When clearing task instances try to get associated DAGs from database.
    
    This fixes problems when recursively clearing task instances across multiple DAGs:
      * Task instances in downstream DAGs weren't having their `max_tries` property incremented, which could cause downstream external task sensors in reschedule mode to instantly time out (issue #29049).
      * Task instances in downstream DAGs could have some of their properties overridden by an unrelated task in the upstream DAG if they had the same task ID.
    
    * Use session fixture for new `test_clear_task_instances_without_dag_param` test.
    
    * Use session fixture for new `test_clear_task_instances_in_multiple_dags` test.
    
    ---------
    
    Co-authored-by: eladkal <[email protected]>
    (cherry picked from commit 0d2e6dce709acebdb46288faef17d322196f29a2)
---
 airflow/models/taskinstance.py  | 11 +++--
 tests/models/test_cleartasks.py | 97 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 101 insertions(+), 7 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b02076c076..d4e37741c3 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -87,6 +87,7 @@ from airflow.exceptions import (
 )
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import Base, StringID
+from airflow.models.dagbag import DagBag
 from airflow.models.log import Log
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import process_params
@@ -203,6 +204,7 @@ def clear_task_instances(
     task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = 
defaultdict(
         lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
     )
+    dag_bag = DagBag(read_dags_from_db=True)
     for ti in tis:
         if ti.state == TaskInstanceState.RUNNING:
             if ti.job_id:
@@ -211,15 +213,16 @@ def clear_task_instances(
                 ti.state = TaskInstanceState.RESTARTING
                 job_ids.append(ti.job_id)
         else:
+            ti_dag = dag if dag and dag.dag_id == ti.dag_id else 
dag_bag.get_dag(ti.dag_id, session=session)
             task_id = ti.task_id
-            if dag and dag.has_task(task_id):
-                task = dag.get_task(task_id)
+            if ti_dag and ti_dag.has_task(task_id):
+                task = ti_dag.get_task(task_id)
                 ti.refresh_from_task(task)
                 task_retries = task.retries
                 ti.max_tries = ti.try_number + task_retries - 1
             else:
-                # Ignore errors when updating max_tries if dag is None or
-                # task not found in dag since database records could be
+                # Ignore errors when updating max_tries if the DAG or
+                # task are not found since database records could be
                 # outdated. We make max_tries the maximum value of its
                 # original max_tries or the last attempted try number.
                 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index ebec03d94e..f0ef8002c1 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -23,6 +23,7 @@ import pytest
 
 from airflow import settings
 from airflow.models import DAG, TaskInstance as TI, TaskReschedule, 
clear_task_instances
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.sensors.python import PythonSensor
 from airflow.utils.session import create_session
@@ -202,9 +203,9 @@ class TestClearTasks:
             # but it works for our case because we specifically constructed 
test DAGS
             # in the way that those two sort methods are equivalent
             qry = session.query(TI).filter(TI.dag_id == 
dag.dag_id).order_by(TI.task_id).all()
-            clear_task_instances(qry, session)
+            clear_task_instances(qry, session, dag=dag)
 
-        # When dag is None, max_tries will be maximum of original max_tries or 
try_number.
+        # When no task is found, max_tries will be maximum of original 
max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
         # Next try to run will be try 2
@@ -214,6 +215,7 @@ class TestClearTasks:
         assert ti1.max_tries == 2
 
     def test_clear_task_instances_without_dag(self, dag_maker):
+        # Don't write DAG to the database, so no DAG is found by 
clear_task_instances().
         with dag_maker(
             "test_clear_task_instances_without_dag",
             start_date=DEFAULT_DATE,
@@ -242,7 +244,7 @@ class TestClearTasks:
             qry = session.query(TI).filter(TI.dag_id == 
dag.dag_id).order_by(TI.task_id).all()
             clear_task_instances(qry, session)
 
-        # When dag is None, max_tries will be maximum of original max_tries or 
try_number.
+        # When no DAG is found, max_tries will be maximum of original 
max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
         # Next try to run will be try 2
@@ -251,6 +253,95 @@ class TestClearTasks:
         assert ti1.try_number == 2
         assert ti1.max_tries == 2
 
+    def test_clear_task_instances_without_dag_param(self, dag_maker, session):
+        with dag_maker(
+            "test_clear_task_instances_without_dag_param",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            session=session,
+        ) as dag:
+            task0 = EmptyOperator(task_id="task0")
+            task1 = EmptyOperator(task_id="task1", retries=2)
+
+        # Write DAG to the database so it can be found by 
clear_task_instances().
+        SerializedDagModel.write_dag(dag, session=session)
+
+        dr = dag_maker.create_dagrun(
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
+
+        ti0.run(session=session)
+        ti1.run(session=session)
+
+        # we use order_by(task_id) here because for the test DAG structure of 
ours
+        # this is equivalent to topological sort. It would not work in general 
case
+        # but it works for our case because we specifically constructed test 
DAGS
+        # in the way that those two sort methods are equivalent
+        qry = session.query(TI).filter(TI.dag_id == 
dag.dag_id).order_by(TI.task_id).all()
+        clear_task_instances(qry, session)
+
+        ti0.refresh_from_db(session=session)
+        ti1.refresh_from_db(session=session)
+        # Next try to run will be try 2
+        assert ti0.try_number == 2
+        assert ti0.max_tries == 1
+        assert ti1.try_number == 2
+        assert ti1.max_tries == 3
+
+    def test_clear_task_instances_in_multiple_dags(self, dag_maker, session):
+        with dag_maker(
+            "test_clear_task_instances_in_multiple_dags0",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            session=session,
+        ) as dag0:
+            task0 = EmptyOperator(task_id="task0")
+
+        dr0 = dag_maker.create_dagrun(
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        with dag_maker(
+            "test_clear_task_instances_in_multiple_dags1",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            session=session,
+        ) as dag1:
+            task1 = EmptyOperator(task_id="task1", retries=2)
+
+        # Write secondary DAG to the database so it can be found by 
clear_task_instances().
+        SerializedDagModel.write_dag(dag1, session=session)
+
+        dr1 = dag_maker.create_dagrun(
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        ti0 = dr0.task_instances[0]
+        ti1 = dr1.task_instances[0]
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
+
+        ti0.run(session=session)
+        ti1.run(session=session)
+
+        qry = session.query(TI).filter(TI.dag_id.in_((dag0.dag_id, 
dag1.dag_id))).all()
+        clear_task_instances(qry, session, dag=dag0)
+
+        ti0.refresh_from_db(session=session)
+        ti1.refresh_from_db(session=session)
+        # Next try to run will be try 2
+        assert ti0.try_number == 2
+        assert ti0.max_tries == 1
+        assert ti1.try_number == 2
+        assert ti1.max_tries == 3
+
     def test_clear_task_instances_with_task_reschedule(self, dag_maker):
         """Test that TaskReschedules are deleted correctly when TaskInstances 
are cleared"""
 

Reply via email to