This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4e81a780308a52e16f7fdd2ca8de8e2053636b7c Author: Ephraim Anierobi <[email protected]> AuthorDate: Fri Apr 15 12:26:53 2022 +0100 fixup! Apply suggestions from code review --- airflow/api/common/mark_tasks.py | 27 +++++++++----- airflow/models/dag.py | 79 +++++++++++++++++++--------------------- airflow/www/views.py | 52 ++++++++------------------ tests/models/test_dag.py | 2 +- 4 files changed, 73 insertions(+), 87 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index a9e4f4812e..594423305c 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -119,10 +119,7 @@ def set_state( if execution_date and not timezone.is_localized(execution_date): raise ValueError(f"Received non-localized date {execution_date}") - task_dags = { - task[0].dag if isinstance(task, tuple) else task.dag - for task in tasks - } + task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks} if len(task_dags) > 1: raise ValueError(f"Received tasks from multiple DAGs: {task_dags}") dag = next(iter(task_dags)) @@ -137,6 +134,12 @@ def set_state( dag_run_ids = get_run_ids(dag, run_id, future, past) task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream)) task_ids = [task_id for task_id, _ in task_id_map_index_list] + # check if task_id_map_index_list contains map_index of None + # if it contains None, there was no map_index supplied for the task + for _, index in task_id_map_index_list: + if index is None: + task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list] + break confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids)) confirmed_dates = [info.logical_date for info in confirmed_infos] @@ -183,20 +186,26 @@ def get_all_dag_task_query( dag: DAG, session: SASession, state: TaskInstanceState, - task_id_map_index_list: List[Tuple[str, int]], + task_ids: Union[List[str], List[Tuple[str, int]]], confirmed_dates: Iterable[datetime], ): 
"""Get all tasks of the main dag that will be affected by a state change""" + is_string_list = isinstance(task_ids[0], str) qry_dag = ( session.query(TaskInstance) .join(TaskInstance.dag_run) .filter( TaskInstance.dag_id == dag.dag_id, DagRun.execution_date.in_(confirmed_dates), - tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_id_map_index_list), ) - .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)) - .options(contains_eager(TaskInstance.dag_run)) + ) + + if is_string_list: + qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids)) + else: + qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids)) + qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options( + contains_eager(TaskInstance.dag_run) ) return qry_dag @@ -278,7 +287,7 @@ def find_task_relatives(tasks, downstream, upstream): if isinstance(item, tuple): task, map_index = item else: - task, map_index = item, -1 + task, map_index = item, None yield task.task_id, map_index if downstream: for relative in task.get_flat_relatives(upstream=False): diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 8856a841d3..0694f37550 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1349,7 +1349,6 @@ class DAG(LoggingMixin): Query, self._get_task_instances( task_ids=None, - task_ids_and_map_indexes=None, start_date=start_date, end_date=end_date, run_id=None, @@ -1358,7 +1357,6 @@ class DAG(LoggingMixin): include_parentdag=False, include_dependent_dags=False, exclude_task_ids=cast(List[str], []), - exclude_task_ids_and_map_indexes=None, session=session, ), ) @@ -1370,7 +1368,7 @@ class DAG(LoggingMixin): def _get_task_instances( self, *, - task_ids, + task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], task_ids_and_map_indexes, start_date: Optional[datetime], end_date: Optional[datetime], @@ -1379,8 +1377,7 @@ class DAG(LoggingMixin): include_subdags: bool, 
include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Collection[str], - exclude_task_ids_and_map_indexes, + exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], session: Session, dag_bag: Optional["DagBag"] = ..., ) -> Iterable[TaskInstance]: @@ -1390,8 +1387,7 @@ class DAG(LoggingMixin): def _get_task_instances( self, *, - task_ids, - task_ids_and_map_indexes, + task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], as_pk_tuple: Literal[True], start_date: Optional[datetime], end_date: Optional[datetime], @@ -1400,8 +1396,7 @@ class DAG(LoggingMixin): include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Collection[str], - exclude_task_ids_and_map_indexes, + exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], session: Session, dag_bag: Optional["DagBag"] = ..., recursion_depth: int = ..., @@ -1413,8 +1408,7 @@ class DAG(LoggingMixin): def _get_task_instances( self, *, - task_ids, - task_ids_and_map_indexes, + task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], as_pk_tuple: Literal[True, None] = None, start_date: Optional[datetime], end_date: Optional[datetime], @@ -1423,8 +1417,7 @@ class DAG(LoggingMixin): include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Collection[str], - exclude_task_ids_and_map_indexes, + exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], session: Session, dag_bag: Optional["DagBag"] = None, recursion_depth: int = 0, @@ -1448,10 +1441,17 @@ class DAG(LoggingMixin): tis = session.query(TaskInstance) tis = tis.join(TaskInstance.dag_run) - if task_ids is not None: # task not mapped - task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids] - if exclude_task_ids and len(exclude_task_ids) > 0: # task not mapped - exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids] + task_ids_and_map_indexes = None + 
if task_ids is not None: + task_ids_and_map_indexes = [item for item in task_ids if isinstance(item, tuple)] + if task_ids_and_map_indexes: + task_ids = None # nullify since we have indexes + + exclude_task_ids_and_map_indexes = None + if exclude_task_ids is not None: + exclude_task_ids_and_map_indexes = [item for item in exclude_task_ids if isinstance(item, tuple)] + if exclude_task_ids_and_map_indexes: + exclude_task_ids = None if include_subdags: # Crafting the right filter for dag_id and task_ids combo @@ -1467,6 +1467,8 @@ class DAG(LoggingMixin): tis = tis.filter(TaskInstance.run_id == run_id) if start_date: tis = tis.filter(DagRun.execution_date >= start_date) + if task_ids: + tis = tis.filter(TaskInstance.task_id.in_(task_ids)) if task_ids_and_map_indexes: tis = tis.filter( tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes) @@ -1509,7 +1511,6 @@ class DAG(LoggingMixin): result.update( p_dag._get_task_instances( task_ids=task_ids, - task_ids_and_map_indexes=task_ids_and_map_indexes, start_date=start_date, end_date=end_date, run_id=None, @@ -1519,7 +1520,6 @@ class DAG(LoggingMixin): include_dependent_dags=include_dependent_dags, as_pk_tuple=True, exclude_task_ids=exclude_task_ids, - exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes, session=session, dag_bag=dag_bag, recursion_depth=recursion_depth, @@ -1588,7 +1588,6 @@ class DAG(LoggingMixin): result.update( downstream._get_task_instances( task_ids=None, - task_ids_and_map_indexes=None, run_id=tii.run_id, start_date=None, end_date=None, @@ -1598,7 +1597,6 @@ class DAG(LoggingMixin): include_parentdag=False, as_pk_tuple=True, exclude_task_ids=exclude_task_ids, - exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes, dag_bag=dag_bag, session=session, recursion_depth=recursion_depth + 1, @@ -1614,7 +1612,15 @@ class DAG(LoggingMixin): else: result.update(ti.key for ti in tis.all()) - if exclude_task_ids_and_map_indexes: + if exclude_task_ids is not 
None: + result = set( + filter( + lambda key: key.task_id not in exclude_task_ids, + result, + ) + ) + + if exclude_task_ids_and_map_indexes is not None: result = set( filter( lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes, @@ -1639,7 +1645,7 @@ class DAG(LoggingMixin): self, *, task_id: str, - map_indexes: Optional[Iterable[int]] = None, + map_index: Optional[int] = None, execution_date: Optional[datetime] = None, run_id: Optional[str] = None, state: TaskInstanceState, @@ -1655,7 +1661,8 @@ class DAG(LoggingMixin): in failed or upstream_failed state. :param task_id: Task ID of the TaskInstance - :param map_indexes: Task instance map_index to set the state of + :param map_index: The TaskInstance map_index, if None, would set state for all mapped + TaskInstances of the task :param execution_date: Execution date of the TaskInstance :param run_id: The run_id of the TaskInstance :param state: State to set the TaskInstance to @@ -1680,10 +1687,8 @@ class DAG(LoggingMixin): task = self.get_task(task_id) task.dag = self - if not map_indexes: - map_indexes = [-1] - task_map_indexes = [(task, map_index) for map_index in map_indexes] - task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes] + task_map_indexes = [(task, map_index)] if map_index else [task] + task_id_map_indexes = {(task_id, map_index)} if map_index else {task_id} altered = set_state( tasks=task_map_indexes, @@ -1721,7 +1726,7 @@ class DAG(LoggingMixin): only_failed=True, session=session, # Exclude the task itself from being cleared - exclude_task_ids_and_map_indexes=task_id_map_indexes, + exclude_task_ids=task_id_map_indexes, ) return altered @@ -1779,8 +1784,7 @@ class DAG(LoggingMixin): @provide_session def clear( self, - task_ids=None, - task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None, + task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, 
only_failed: bool = False, @@ -1795,15 +1799,13 @@ class DAG(LoggingMixin): recursion_depth: int = 0, max_recursion_depth: Optional[int] = None, dag_bag: Optional["DagBag"] = None, - exclude_task_ids: FrozenSet[str] = frozenset(), - exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = frozenset({}), + exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(), ) -> Union[int, Iterable[TaskInstance]]: """ Clears a set of task instances associated with the current dag for a specified date range. - :param task_ids: List of task ids to clear - :param task_ids_and_map_indexes: List of tuple of task_id, map_index to clear + :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear :param start_date: The minimum execution_date to clear :param end_date: The maximum execution_date to clear :param only_failed: Only clear failed tasks @@ -1817,8 +1819,7 @@ class DAG(LoggingMixin): :param dry_run: Find the tasks to clear but don't clear them. 
:param session: The sqlalchemy session to use :param dag_bag: The DagBag used to find the dags subdags (Optional) - :param exclude_task_ids: A set of ``task_id`` that should not be cleared - :param exclude_task_ids_and_map_indexes: A set of ``task_id``,``map_index`` + :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) tuples that should not be cleared """ if get_tis: @@ -1848,12 +1849,9 @@ class DAG(LoggingMixin): if only_running: # Yes, having `+=` doesn't make sense, but this was the existing behaviour state += [State.RUNNING] - if task_ids: - task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids] tis = self._get_task_instances( task_ids=task_ids, - task_ids_and_map_indexes=task_ids_and_map_indexes, start_date=start_date, end_date=end_date, run_id=None, @@ -1864,7 +1862,6 @@ class DAG(LoggingMixin): session=session, dag_bag=dag_bag, exclude_task_ids=exclude_task_ids, - exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes, ) if dry_run: diff --git a/airflow/www/views.py b/airflow/www/views.py index 2a3ad1e913..437a60cca0 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView): start_date, end_date, origin, - task_id_map_index_list=None, + task_ids=None, recursive=False, confirmed=False, only_failed=False, @@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView): count = dag.clear( start_date=start_date, end_date=end_date, - task_ids_and_map_indexes=task_id_map_index_list, + task_ids=task_ids, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView): tis = dag.clear( start_date=start_date, end_date=end_date, - task_ids_and_map_indexes=task_id_map_index_list, + task_ids=task_ids, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -2026,14 +2026,7 @@ class Airflow(AirflowBaseView): task_id = request.form.get('task_id') origin = 
get_safe_url(request.form.get('origin')) dag = current_app.dag_bag.get_dag(dag_id) - - map_indexes = request.form.get('map_indexes') - if map_indexes: - if not isinstance(map_indexes, list): - map_indexes = list(map_indexes) - else: - map_indexes = [-1] - task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes] + map_index = request.form.get('map_index') execution_date = request.form.get('execution_date') execution_date = timezone.parse(execution_date) @@ -2052,13 +2045,13 @@ class Airflow(AirflowBaseView): ) end_date = execution_date if not future else None start_date = execution_date if not past else None - + task_ids = [(task_id, map_index)] if map_index else [task_id] return self._clear_dag_tis( dag, start_date, end_date, origin, - task_id_map_index_list=task_id_map_indexes, + task_ids=task_ids, recursive=recursive, confirmed=confirmed, only_failed=only_failed, @@ -2300,7 +2293,7 @@ class Airflow(AirflowBaseView): future, past, state, - map_indexes=None, + map_index=None, ): dag = current_app.dag_bag.get_dag(dag_id) latest_execution_date = dag.get_latest_execution_date() @@ -2311,7 +2304,7 @@ class Airflow(AirflowBaseView): altered = dag.set_task_instance_state( task_id=task_id, - map_indexes=map_indexes, + map_index=map_index, run_id=dag_run_id, state=state, upstream=upstream, @@ -2339,12 +2332,7 @@ class Airflow(AirflowBaseView): dag_run_id = args.get('dag_run_id') state = args.get('state') origin = args.get('origin') - map_indexes = args.get('map_indexes') - if map_indexes: - if not isinstance(map_indexes, list): - map_indexes = list(map_indexes) - else: - map_indexes = [-1] + map_index = args.get('map_index') upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2379,8 +2367,10 @@ class Airflow(AirflowBaseView): from airflow.api.common.mark_tasks import set_state + tasks = [(task, map_index)] if map_index else [task] + to_be_altered = set_state( - tasks=[(task, map_index) for map_index in 
map_indexes], + tasks=tasks, run_id=dag_run_id, upstream=upstream, downstream=downstream, @@ -2420,12 +2410,7 @@ class Airflow(AirflowBaseView): task_id = args.get('task_id') origin = get_safe_url(args.get('origin')) dag_run_id = args.get('dag_run_id') - map_indexes = args.get('map_indexes') - if map_indexes: - if not isinstance(map_indexes, list): - map_indexes = list(map_indexes) - else: - map_indexes = [-1] + map_index = args.get('map_index') upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2442,7 +2427,7 @@ class Airflow(AirflowBaseView): future, past, State.FAILED, - map_indexes=map_indexes, + map_index=map_index, ) @expose('/success', methods=['POST']) @@ -2460,12 +2445,7 @@ class Airflow(AirflowBaseView): task_id = args.get('task_id') origin = get_safe_url(args.get('origin')) dag_run_id = args.get('dag_run_id') - map_indexes = args.get('map_indexes') - if map_indexes: - if not isinstance(map_indexes, list): - map_indexes = list(map_indexes) - else: - map_indexes = [-1] + map_index = args.get('map_index') upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2482,7 +2462,7 @@ class Airflow(AirflowBaseView): future, past, State.SUCCESS, - map_indexes=map_indexes, + map_index=map_index, ) @expose('/dags/<string:dag_id>') diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index ed2119a490..6cd8ea660f 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase): session.flush() dag.clear( - task_ids_and_map_indexes=[(task_id, 0)], + task_ids=[(task_id, 0)], start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=1), dag_run_state=dag_run_state,
