This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b0ce3db4d Add ability to clear downstream tis in "List Task Instances" view  (#34529)
5b0ce3db4d is described below

commit 5b0ce3db4d36e2a7f20a78903daf538bbde5e38a
Author: Jean-Eudes Peloye <[email protected]>
AuthorDate: Fri Sep 22 13:54:58 2023 -0400

    Add ability to clear downstream tis in "List Task Instances" view  (#34529)
    
    * Add "clear including downstream" action in task instance view
    
    * Extract logic into helper + support dynamic tasks
    
    * Add unit test
    
    * Restore quick path for ti clear without downstream
    
    * Fix wording
    
    * Call clear_task_instances once per dag + split cleared ti count
    
    * Handle plural
    
    * Update airflow/www/views.py
    
    ---------
    
    Co-authored-by: Jean-Eudes Peloye <[email protected]>
    Co-authored-by: Hussein Awala <[email protected]>
---
 airflow/www/views.py                | 105 ++++++++++++++++++++++++++++++++----
 tests/www/views/test_views_tasks.py |  61 ++++++++++++++++++++-
 2 files changed, 156 insertions(+), 10 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index 656e52ad49..377b6fbadd 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -5657,6 +5657,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
         "list": "read",
         "delete": "delete",
         "action_clear": "edit",
+        "action_clear_downstream": "edit",
         "action_muldelete": "delete",
         "action_set_running": "edit",
         "action_set_failed": "edit",
@@ -5793,6 +5794,68 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
         "duration": duration_f,
     }
 
+    def _clear_task_instances(
+        self, task_instances: list[TaskInstance], session: Session, clear_downstream: bool = False
+    ) -> tuple[int, int]:
+        """
+        Clears task instances, optionally including their downstream dependencies.
+
+        :param task_instances: list of TIs to clear
+        :param clear_downstream: should downstream task instances be cleared as well?
+
+        :return: a tuple with:
+            - count of cleared task instances actually selected by the user
+            - count of downstream task instances that were additionally cleared
+        """
+        cleared_tis_count = 0
+        cleared_downstream_tis_count = 0
+
+        # Group TIs by dag id in order to call `get_dag` only once per dag
+        tis_grouped_by_dag_id = itertools.groupby(task_instances, lambda ti: ti.dag_id)
+
+        for dag_id, dag_tis in tis_grouped_by_dag_id:
+            dag = get_airflow_app().dag_bag.get_dag(dag_id)
+
+            tis_to_clear = list(dag_tis)
+            downstream_tis_to_clear = []
+
+            if clear_downstream:
+                tis_to_clear_grouped_by_dag_run = itertools.groupby(tis_to_clear, lambda ti: ti.dag_run)
+
+                for dag_run, dag_run_tis in tis_to_clear_grouped_by_dag_run:
+                    # Determine tasks that are downstream of the cleared TIs and fetch associated TIs
+                    # This has to be run for each dag run because the user may clear different TIs across runs
+                    task_ids_to_clear = [ti.task_id for ti in dag_run_tis]
+
+                    partial_dag = dag.partial_subset(
+                        task_ids_or_regex=task_ids_to_clear, include_downstream=True, include_upstream=False
+                    )
+
+                    downstream_task_ids_to_clear = [
+                        task_id for task_id in partial_dag.task_dict if task_id not in task_ids_to_clear
+                    ]
+
+                    # dag.clear returns TIs when in dry run mode
+                    downstream_tis_to_clear.extend(
+                        dag.clear(
+                            start_date=dag_run.execution_date,
+                            end_date=dag_run.execution_date,
+                            task_ids=downstream_task_ids_to_clear,
+                            include_subdags=False,
+                            include_parentdag=False,
+                            session=session,
+                            dry_run=True,
+                        )
+                    )
+
+            # Once all TIs are fetched, perform the actual clearing
+            models.clear_task_instances(tis=tis_to_clear + downstream_tis_to_clear, session=session, dag=dag)
+
+            cleared_tis_count += len(tis_to_clear)
+            cleared_downstream_tis_count += len(downstream_tis_to_clear)
+
+        return cleared_tis_count, cleared_downstream_tis_count
+
     @action(
         "clear",
         lazy_gettext("Clear"),
@@ -5806,21 +5869,45 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     @provide_session
     @action_logging
     def action_clear(self, task_instances, session: Session = NEW_SESSION):
-        """Clears the action."""
+        """Clears an arbitrary number of task instances."""
         try:
-            dag_to_tis = collections.defaultdict(list)
-
-            for ti in task_instances:
-                dag = get_airflow_app().dag_bag.get_dag(ti.dag_id)
-                dag_to_tis[dag].append(ti)
+            count, _ = self._clear_task_instances(
+                task_instances=task_instances, session=session, clear_downstream=False
+            )
+            session.commit()
+            flash(f"{count} task instance{'s have' if count > 1 else ' has'} been cleared")
+        except Exception as e:
+            flash(f'Failed to clear task instances: "{e}"', "error")
 
-            for dag, task_instances_list in dag_to_tis.items():
-                models.clear_task_instances(task_instances_list, session, dag=dag)
+        self.update_redirect()
+        return redirect(self.get_redirect())
 
+    @action(
+        "clear_downstream",
+        lazy_gettext("Clear (including downstream tasks)"),
+        lazy_gettext(
+            "Are you sure you want to clear the state of the selected task"
+            " instance(s) and all their downstream dependencies, and set their dagruns to the QUEUED state?"
+        ),
+        single=False,
+    )
+    @action_has_dag_edit_access
+    @provide_session
+    @action_logging
+    def action_clear_downstream(self, task_instances, session: Session = NEW_SESSION):
+        """Clears an arbitrary number of task instances, including downstream dependencies."""
+        try:
+            selected_ti_count, downstream_ti_count = self._clear_task_instances(
+                task_instances=task_instances, session=session, clear_downstream=True
+            )
             session.commit()
-            flash(f"{len(task_instances)} task instances have been cleared")
+            flash(
+                f"Cleared {selected_ti_count} selected task instance{'s' if selected_ti_count > 1 else ''} "
+                f"and {downstream_ti_count} downstream dependencies"
+            )
         except Exception as e:
             flash(f'Failed to clear task instances: "{e}"', "error")
+
         self.update_redirect()
         return redirect(self.get_redirect())
 
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 01fdc820ae..dfcb8e5d8c 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -24,6 +24,7 @@ import unittest.mock
 import urllib.parse
 from getpass import getuser
 
+import pendulum
 import pytest
 import time_machine
 
@@ -32,12 +33,13 @@ from airflow.exceptions import AirflowException
 from airflow.models import DAG, DagBag, DagModel, TaskFail, TaskInstance, TaskReschedule, XCom
 from airflow.models.dagcode import DagCode
 from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
 from airflow.security import permissions
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin
 from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType
 from airflow.www.views import TaskInstanceModelView
 from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user
@@ -857,6 +859,63 @@ def test_task_instance_clear(session, request, client_fixture, should_succeed):
     assert state == (State.NONE if should_succeed else initial_state)
 
 
+def test_task_instance_clear_downstream(session, admin_client, dag_maker):
+    """Ensures clearing a task instance clears its downstream dependencies exclusively"""
+    with dag_maker(
+        dag_id="test_dag_id",
+        serialized=True,
+        session=session,
+        start_date=pendulum.DateTime(2023, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC),
+    ):
+        EmptyOperator(task_id="task_1") >> EmptyOperator(task_id="task_2")
+        EmptyOperator(task_id="task_3")
+
+    run1 = dag_maker.create_dagrun(
+        run_id="run_1",
+        state=DagRunState.SUCCESS,
+        run_type=DagRunType.SCHEDULED,
+        execution_date=dag_maker.dag.start_date,
+        start_date=dag_maker.dag.start_date,
+        session=session,
+    )
+
+    run2 = dag_maker.create_dagrun(
+        run_id="run_2",
+        state=DagRunState.SUCCESS,
+        run_type=DagRunType.SCHEDULED,
+        execution_date=dag_maker.dag.start_date.add(days=1),
+        start_date=dag_maker.dag.start_date.add(days=1),
+        session=session,
+    )
+
+    for run in (run1, run2):
+        for ti in run.task_instances:
+            ti.state = State.SUCCESS
+
+    # Clear task_1 from dag run 1
+    run1_ti1 = run1.get_task_instance(task_id="task_1")
+    rowid = _get_appbuilder_pk_string(TaskInstanceModelView, run1_ti1)
+    resp = admin_client.post(
+        "/taskinstance/action_post",
+        data={"action": "clear_downstream", "rowid": rowid},
+        follow_redirects=True,
+    )
+    assert resp.status_code == 200
+
+    # Assert that task_1 and task_2 of dag run 1 are cleared, but task_3 is left untouched
+    run1_ti1.refresh_from_db(session=session)
+    run1_ti2 = run1.get_task_instance(task_id="task_2")
+    run1_ti3 = run1.get_task_instance(task_id="task_3")
+
+    assert run1_ti1.state == State.NONE
+    assert run1_ti2.state == State.NONE
+    assert run1_ti3.state == State.SUCCESS
+
+    # Assert that task_1 of dag run 2 is left untouched
+    run2_ti1 = run2.get_task_instance(task_id="task_1")
+    assert run2_ti1.state == State.SUCCESS
+
+
 def test_task_instance_clear_failure(admin_client):
     rowid = '["12345"]'  # F.A.B. crashes if the rowid is *too* invalid.
     resp = admin_client.post(

Reply via email to