This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5620df94692a8802202789b21a05104999f8494c
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Apr 12 16:56:11 2022 +0100

    fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/www/views.py | 168 ++++++++++++++++++++++++---------------------------
 1 file changed, 80 insertions(+), 88 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index ae0186e493..7be1289144 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -95,7 +95,6 @@ from airflow.api.common.mark_tasks import (
     set_dag_run_state_to_failed,
     set_dag_run_state_to_queued,
     set_dag_run_state_to_success,
-    set_state,
 )
 from airflow.compat.functools import cached_property
 from airflow.configuration import AIRFLOW_CONFIG, conf
@@ -108,7 +107,6 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task
 from airflow.models.abstractoperator import AbstractOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
-from airflow.models.operator import Operator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers_manager import ProvidersManager
@@ -1960,11 +1958,11 @@ class Airflow(AirflowBaseView):
 
     def _clear_dag_tis(
         self,
-        dag: DAG,
+        dag,
         start_date,
         end_date,
         origin,
-        task_ids=None,
+        map_indexes=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1973,7 +1971,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids=task_ids,
+                map_indexes=map_indexes,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1986,7 +1984,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids=task_ids,
+                map_indexes=map_indexes,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1995,19 +1993,24 @@ class Airflow(AirflowBaseView):
         except AirflowException as ex:
             return redirect_or_json(origin, msg=str(ex), status="error")
 
-        assert isinstance(tis, collections.abc.Iterable)
-        details = [str(t) for t in tis]
-
-        if not details:
-            return redirect_or_json(origin, "No task instances to clear", status="error")
+        if not tis:
+            msg = "No task instances to clear"
+            return redirect_or_json(origin, msg, status="error")
         elif request.headers.get('Accept') == 'application/json':
+            details = [str(t) for t in tis]
+
             return htmlsafe_json_dumps(details, separators=(',', ':'))
-        return self.render_template(
-            'airflow/confirm.html',
-            endpoint=None,
-            message="Task instances you are about to clear:",
-            details="\n".join(details),
-        )
+        else:
+            details = "\n".join(str(t) for t in tis)
+
+            response = self.render_template(
+                'airflow/confirm.html',
+                endpoint=None,
+                message="Task instances you are about to clear:",
+                details=details,
+            )
+
+        return response
 
     @expose('/clear', methods=['POST'])
     @auth.has_access(
@@ -2023,11 +2026,9 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
-
-        if 'map_index' not in request.form:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = request.form.getlist('map_index', type=int)
+        map_indexes = request.form.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2047,17 +2048,12 @@ class Airflow(AirflowBaseView):
         end_date = execution_date if not future else None
         start_date = execution_date if not past else None
 
-        if map_indexes is None:
-            task_ids: Union[List[str], List[Tuple[str, int]]] = [task_id]
-        else:
-            task_ids = [(task_id, map_index) for map_index in map_indexes]
-
         return self._clear_dag_tis(
             dag,
             start_date,
             end_date,
             origin,
-            task_ids=task_ids,
+            map_indexes=map_indexes,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2076,6 +2072,9 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         dag_run_id = request.form.get('dag_run_id')
         confirmed = request.form.get('confirmed') == "true"
+        map_indexes = request.form.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         dag = current_app.dag_bag.get_dag(dag_id)
         dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2086,6 +2085,7 @@ class Airflow(AirflowBaseView):
             dag,
             start_date,
             end_date,
+            map_indexes=map_indexes,
             origin=None,
             recursive=True,
             confirmed=confirmed,
@@ -2290,28 +2290,28 @@ class Airflow(AirflowBaseView):
 
     def _mark_task_instance_state(
         self,
-        *,
-        dag_id: str,
-        run_id: str,
-        task_id: str,
-        map_indexes: Optional[List[int]],
-        origin: str,
-        upstream: bool,
-        downstream: bool,
-        future: bool,
-        past: bool,
-        state: TaskInstanceState,
+        dag_id,
+        task_id,
+        map_indexes,
+        origin,
+        dag_run_id,
+        upstream,
+        downstream,
+        future,
+        past,
+        state,
     ):
-        dag: DAG = current_app.dag_bag.get_dag(dag_id)
+        dag = current_app.dag_bag.get_dag(dag_id)
+        latest_execution_date = dag.get_latest_execution_date()
 
-        if not run_id:
-            flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error")
+        if not latest_execution_date:
+            flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error")
             return redirect(origin)
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_indexes=map_indexes,
-            run_id=run_id,
+            map_index=map_indexes,
+            run_id=dag_run_id,
             state=state,
             upstream=upstream,
             downstream=downstream,
@@ -2338,11 +2338,9 @@ class Airflow(AirflowBaseView):
         dag_run_id = args.get('dag_run_id')
         state = args.get('state')
         origin = args.get('origin')
-
-        if 'map_index' not in args:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = args.getlist('map_index', type=int)
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2375,13 +2373,11 @@ class Airflow(AirflowBaseView):
             msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run"
             return redirect_or_json(origin, msg, status='error')
 
-        if map_indexes is None:
-            tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task]
-        else:
-            tasks = [(task, map_index) for map_index in map_indexes]
+        from airflow.api.common.mark_tasks import set_state
 
         to_be_altered = set_state(
-            tasks=tasks,
+            tasks=[task],
+            map_indexes=map_indexes,
             run_id=dag_run_id,
             upstream=upstream,
             downstream=downstream,
@@ -2419,30 +2415,28 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        run_id = args.get('dag_run_id')
-
-        if 'map_index' not in args:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = args.getlist('map_index', type=int)
-
         origin = get_safe_url(args.get('origin'))
+        dag_run_id = args.get('dag_run_id')
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
+
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id=dag_id,
-            run_id=run_id,
-            task_id=task_id,
-            map_indexes=map_indexes,
-            origin=origin,
-            upstream=upstream,
-            downstream=downstream,
-            future=future,
-            past=past,
-            state=TaskInstanceState.FAILED,
+            dag_id,
+            task_id,
+            map_indexes,
+            origin,
+            dag_run_id,
+            upstream,
+            downstream,
+            future,
+            past,
+            State.FAILED,
         )
 
     @expose('/success', methods=['POST'])
@@ -2458,30 +2452,28 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        run_id = args.get('dag_run_id')
-
-        if 'map_index' not in args:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = args.getlist('map_index', type=int)
-
         origin = get_safe_url(args.get('origin'))
+        dag_run_id = args.get('dag_run_id')
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
+
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id=dag_id,
-            run_id=run_id,
-            task_id=task_id,
-            map_indexes=map_indexes,
-            origin=origin,
-            upstream=upstream,
-            downstream=downstream,
-            future=future,
-            past=past,
-            state=TaskInstanceState.SUCCESS,
+            dag_id,
+            task_id,
+            map_indexes,
+            origin,
+            dag_run_id,
+            upstream,
+            downstream,
+            future,
+            past,
+            State.SUCCESS,
         )
 
     @expose('/dags/<string:dag_id>')

Reply via email to