This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4e81a780308a52e16f7fdd2ca8de8e2053636b7c
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Apr 15 12:26:53 2022 +0100

    fixup! Apply suggestions from code review
---
 airflow/api/common/mark_tasks.py | 27 +++++++++-----
 airflow/models/dag.py            | 79 +++++++++++++++++++---------------------
 airflow/www/views.py             | 52 ++++++++------------------
 tests/models/test_dag.py         |  2 +-
 4 files changed, 73 insertions(+), 87 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index a9e4f4812e..594423305c 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -119,10 +119,7 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    task_dags = {
-        task[0].dag if isinstance(task, tuple) else task.dag
-        for task in tasks
-    }
+    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task 
in tasks}
     if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
@@ -137,6 +134,12 @@ def set_state(
     dag_run_ids = get_run_ids(dag, run_id, future, past)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, 
upstream))
     task_ids = [task_id for task_id, _ in task_id_map_index_list]
+    # check if task_id_map_index_list contains map_index of None
+    # if it contains None, there was no map_index supplied for the task
+    for _, index in task_id_map_index_list:
+        if index is None:
+            task_id_map_index_list = [task_id for task_id, _ in 
task_id_map_index_list]
+            break
 
     confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -183,20 +186,26 @@ def get_all_dag_task_query(
     dag: DAG,
     session: SASession,
     state: TaskInstanceState,
-    task_id_map_index_list: List[Tuple[str, int]],
+    task_ids: Union[List[str], List[Tuple[str, int]]],
     confirmed_dates: Iterable[datetime],
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
+    is_string_list = isinstance(task_ids[0], str)
     qry_dag = (
         session.query(TaskInstance)
         .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
             DagRun.execution_date.in_(confirmed_dates),
-            tuple_(TaskInstance.task_id, 
TaskInstance.map_index).in_(task_id_map_index_list),
         )
-        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
-        .options(contains_eager(TaskInstance.dag_run))
+    )
+
+    if is_string_list:
+        qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
+    else:
+        qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, 
TaskInstance.map_index).in_(task_ids))
+    qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), 
TaskInstance.state != state)).options(
+        contains_eager(TaskInstance.dag_run)
     )
     return qry_dag
 
@@ -278,7 +287,7 @@ def find_task_relatives(tasks, downstream, upstream):
         if isinstance(item, tuple):
             task, map_index = item
         else:
-            task, map_index = item, -1
+            task, map_index = item, None
         yield task.task_id, map_index
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8856a841d3..0694f37550 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1349,7 +1349,6 @@ class DAG(LoggingMixin):
                 Query,
                 self._get_task_instances(
                     task_ids=None,
-                    task_ids_and_map_indexes=None,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1358,7 +1357,6 @@ class DAG(LoggingMixin):
                     include_parentdag=False,
                     include_dependent_dags=False,
                     exclude_task_ids=cast(List[str], []),
-                    exclude_task_ids_and_map_indexes=None,
                     session=session,
                 ),
             )
@@ -1370,7 +1368,7 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         task_ids_and_map_indexes,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1379,8 +1377,7 @@ class DAG(LoggingMixin):
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes,
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], 
None],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
     ) -> Iterable[TaskInstance]:
@@ -1390,8 +1387,7 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids,
-        task_ids_and_map_indexes,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1400,8 +1396,7 @@ class DAG(LoggingMixin):
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes,
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], 
None],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
         recursion_depth: int = ...,
@@ -1413,8 +1408,7 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids,
-        task_ids_and_map_indexes,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1423,8 +1417,7 @@ class DAG(LoggingMixin):
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes,
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], 
None],
         session: Session,
         dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
@@ -1448,10 +1441,17 @@ class DAG(LoggingMixin):
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
-        if task_ids is not None:  # task not mapped
-            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
-        if exclude_task_ids and len(exclude_task_ids) > 0:  # task not mapped
-            exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in 
task_ids]
+        task_ids_and_map_indexes = None
+        if task_ids is not None:
+            task_ids_and_map_indexes = [item for item in task_ids if 
isinstance(item, tuple)]
+        if task_ids_and_map_indexes:
+            task_ids = None  # nullify since we have indexes
+
+        exclude_task_ids_and_map_indexes = None
+        if exclude_task_ids is not None:
+            exclude_task_ids_and_map_indexes = [item for item in 
exclude_task_ids if isinstance(item, tuple)]
+        if exclude_task_ids_and_map_indexes:
+            exclude_task_ids = None
 
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
@@ -1467,6 +1467,8 @@ class DAG(LoggingMixin):
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
+        if task_ids:
+            tis = tis.filter(TaskInstance.task_id.in_(task_ids))
         if task_ids_and_map_indexes:
             tis = tis.filter(
                 tuple_(TaskInstance.task_id, 
TaskInstance.map_index).in_(task_ids_and_map_indexes)
@@ -1509,7 +1511,6 @@ class DAG(LoggingMixin):
             result.update(
                 p_dag._get_task_instances(
                     task_ids=task_ids,
-                    task_ids_and_map_indexes=task_ids_and_map_indexes,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1519,7 +1520,6 @@ class DAG(LoggingMixin):
                     include_dependent_dags=include_dependent_dags,
                     as_pk_tuple=True,
                     exclude_task_ids=exclude_task_ids,
-                    
exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                     session=session,
                     dag_bag=dag_bag,
                     recursion_depth=recursion_depth,
@@ -1588,7 +1588,6 @@ class DAG(LoggingMixin):
                     result.update(
                         downstream._get_task_instances(
                             task_ids=None,
-                            task_ids_and_map_indexes=None,
                             run_id=tii.run_id,
                             start_date=None,
                             end_date=None,
@@ -1598,7 +1597,6 @@ class DAG(LoggingMixin):
                             include_parentdag=False,
                             as_pk_tuple=True,
                             exclude_task_ids=exclude_task_ids,
-                            
exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                             dag_bag=dag_bag,
                             session=session,
                             recursion_depth=recursion_depth + 1,
@@ -1614,7 +1612,15 @@ class DAG(LoggingMixin):
             else:
                 result.update(ti.key for ti in tis.all())
 
-            if exclude_task_ids_and_map_indexes:
+            if exclude_task_ids is not None:
+                result = set(
+                    filter(
+                        lambda key: key.task_id not in exclude_task_ids,
+                        result,
+                    )
+                )
+
+            if exclude_task_ids_and_map_indexes is not None:
                 result = set(
                     filter(
                         lambda key: (key.task_id, key.map_index) not in 
exclude_task_ids_and_map_indexes,
@@ -1639,7 +1645,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
-        map_indexes: Optional[Iterable[int]] = None,
+        map_index: Optional[int] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1655,7 +1661,8 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
-        :param map_indexes: Task instance map_index to set the state of
+        :param map_index: The TaskInstance map_index, if None, would set state 
for all mapped
+            TaskInstances of the task
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1680,10 +1687,8 @@ class DAG(LoggingMixin):
 
         task = self.get_task(task_id)
         task.dag = self
-        if not map_indexes:
-            map_indexes = [-1]
-        task_map_indexes = [(task, map_index) for map_index in map_indexes]
-        task_id_map_indexes = [(task_id, map_index) for map_index in 
map_indexes]
+        task_map_indexes = [(task, map_index)] if map_index else [task]
+        task_id_map_indexes = {(task_id, map_index)} if map_index else 
{task_id}
 
         altered = set_state(
             tasks=task_map_indexes,
@@ -1721,7 +1726,7 @@ class DAG(LoggingMixin):
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids_and_map_indexes=task_id_map_indexes,
+            exclude_task_ids=task_id_map_indexes,
         )
 
         return altered
@@ -1779,8 +1784,7 @@ class DAG(LoggingMixin):
     @provide_session
     def clear(
         self,
-        task_ids=None,
-        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None,
+        task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
@@ -1795,15 +1799,13 @@ class DAG(LoggingMixin):
         recursion_depth: int = 0,
         max_recursion_depth: Optional[int] = None,
         dag_bag: Optional["DagBag"] = None,
-        exclude_task_ids: FrozenSet[str] = frozenset(),
-        exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = 
frozenset({}),
+        exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], 
None] = frozenset(),
     ) -> Union[int, Iterable[TaskInstance]]:
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
 
-        :param task_ids: List of task ids to clear
-        :param task_ids_and_map_indexes: List of tuple of task_id, map_index 
to clear
+        :param task_ids: List of task ids or (``task_id``, ``map_index``) 
tuples to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear
         :param only_failed: Only clear failed tasks
@@ -1817,8 +1819,7 @@ class DAG(LoggingMixin):
         :param dry_run: Find the tasks to clear but don't clear them.
         :param session: The sqlalchemy session to use
         :param dag_bag: The DagBag used to find the dags subdags (Optional)
-        :param exclude_task_ids: A set of ``task_id`` that should not be 
cleared
-        :param exclude_task_ids_and_map_indexes: A set of 
``task_id``,``map_index``
+        :param exclude_task_ids: A set of ``task_id`` or (``task_id``, 
``map_index``)
             tuples that should not be cleared
         """
         if get_tis:
@@ -1848,12 +1849,9 @@ class DAG(LoggingMixin):
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing 
behaviour
             state += [State.RUNNING]
-        if task_ids:
-            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
 
         tis = self._get_task_instances(
             task_ids=task_ids,
-            task_ids_and_map_indexes=task_ids_and_map_indexes,
             start_date=start_date,
             end_date=end_date,
             run_id=None,
@@ -1864,7 +1862,6 @@ class DAG(LoggingMixin):
             session=session,
             dag_bag=dag_bag,
             exclude_task_ids=exclude_task_ids,
-            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
         )
 
         if dry_run:
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 2a3ad1e913..437a60cca0 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView):
         start_date,
         end_date,
         origin,
-        task_id_map_index_list=None,
+        task_ids=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids_and_map_indexes=task_id_map_index_list,
+                task_ids=task_ids,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids_and_map_indexes=task_id_map_index_list,
+                task_ids=task_ids,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -2026,14 +2026,7 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
-
-        map_indexes = request.form.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
-        task_id_map_indexes = [(task_id, map_index) for map_index in 
map_indexes]
+        map_index = request.form.get('map_index')
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2052,13 +2045,13 @@ class Airflow(AirflowBaseView):
         )
         end_date = execution_date if not future else None
         start_date = execution_date if not past else None
-
+        task_ids = [(task_id, map_index)] if map_index else [task_id]
         return self._clear_dag_tis(
             dag,
             start_date,
             end_date,
             origin,
-            task_id_map_index_list=task_id_map_indexes,
+            task_ids=task_ids,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2300,7 +2293,7 @@ class Airflow(AirflowBaseView):
         future,
         past,
         state,
-        map_indexes=None,
+        map_index=None,
     ):
         dag = current_app.dag_bag.get_dag(dag_id)
         latest_execution_date = dag.get_latest_execution_date()
@@ -2311,7 +2304,7 @@ class Airflow(AirflowBaseView):
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_indexes=map_indexes,
+            map_index=map_index,
             run_id=dag_run_id,
             state=state,
             upstream=upstream,
@@ -2339,12 +2332,7 @@ class Airflow(AirflowBaseView):
         dag_run_id = args.get('dag_run_id')
         state = args.get('state')
         origin = args.get('origin')
-        map_indexes = args.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
+        map_index = args.get('map_index')
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2379,8 +2367,10 @@ class Airflow(AirflowBaseView):
 
         from airflow.api.common.mark_tasks import set_state
 
+        tasks = [(task, map_index)] if map_index else [task]
+
         to_be_altered = set_state(
-            tasks=[(task, map_index) for map_index in map_indexes],
+            tasks=tasks,
             run_id=dag_run_id,
             upstream=upstream,
             downstream=downstream,
@@ -2420,12 +2410,7 @@ class Airflow(AirflowBaseView):
         task_id = args.get('task_id')
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
-        map_indexes = args.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
+        map_index = args.get('map_index')
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2442,7 +2427,7 @@ class Airflow(AirflowBaseView):
             future,
             past,
             State.FAILED,
-            map_indexes=map_indexes,
+            map_index=map_index,
         )
 
     @expose('/success', methods=['POST'])
@@ -2460,12 +2445,7 @@ class Airflow(AirflowBaseView):
         task_id = args.get('task_id')
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
-        map_indexes = args.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
+        map_index = args.get('map_index')
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2482,7 +2462,7 @@ class Airflow(AirflowBaseView):
             future,
             past,
             State.SUCCESS,
-            map_indexes=map_indexes,
+            map_index=map_index,
         )
 
     @expose('/dags/<string:dag_id>')
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index ed2119a490..6cd8ea660f 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase):
         session.flush()
 
         dag.clear(
-            task_ids_and_map_indexes=[(task_id, 0)],
+            task_ids=[(task_id, 0)],
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=1),
             dag_run_state=dag_run_state,

Reply via email to