This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit c5cc48f9f9161b187368d99b551c652e1da03de5 Author: Ephraim Anierobi <[email protected]> AuthorDate: Wed Apr 13 20:16:56 2022 +0100 fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI --- airflow/api/common/mark_tasks.py | 38 ++++----- airflow/models/dag.py | 159 ++++++++++++++++++++---------------- airflow/www/views.py | 44 ++++++---- tests/api/common/test_mark_tasks.py | 6 +- tests/models/test_dag.py | 2 +- 5 files changed, 133 insertions(+), 116 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 1d4709fb82..84fd48f4e4 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -18,9 +18,9 @@ """Marks tasks APIs.""" from datetime import datetime -from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union -from sqlalchemy import or_ +from sqlalchemy import or_, tuple_ from sqlalchemy.orm import contains_eager from sqlalchemy.orm.session import Session as SASession @@ -32,7 +32,6 @@ from airflow.operators.subdag import SubDagOperator from airflow.utils import timezone from airflow.utils.helpers import exactly_one from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import tuple_in_condition from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType @@ -79,7 +78,7 @@ def _create_dagruns( @provide_session def set_state( *, - tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]], + tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]], run_id: Optional[str] = None, execution_date: Optional[datetime] = None, upstream: bool = False, @@ -97,7 +96,7 @@ def set_state( tasks that did not exist. It will not create dag runs that are missing on the schedule (but it will as for subdag dag runs if needed). - :param tasks: the iterable of tasks or (task, map_index) tuples from which to work. + :param tasks: the iterable of tasks or task, map_index tuple from which to work. task.task.dag needs to be set :param run_id: the run_id of the dagrun to start looking from :param execution_date: the execution date from which to start looking(deprecated) @@ -120,7 +119,9 @@ def set_state( if execution_date and not timezone.is_localized(execution_date): raise ValueError(f"Received non-localized date {execution_date}") - task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks} + t_dags = {task.dag for task in tasks if not isinstance(task, tuple)} + t_dags_2 = {item[0].dag for item in tasks if isinstance(item, tuple)} + task_dags = t_dags | t_dags_2 if len(task_dags) > 1: raise ValueError(f"Received tasks from multiple DAGs: {task_dags}") dag = next(iter(task_dags)) @@ -135,12 +136,6 @@ def set_state( dag_run_ids = get_run_ids(dag, run_id, future, past) task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream)) task_ids = [task_id for task_id, _ in task_id_map_index_list] - # check if task_id_map_index_list contains map_index of None - # if it contains None, there was no map_index supplied for the task - for _, index in task_id_map_index_list: - if index is None: - task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list] - break confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids)) confirmed_dates = [info.logical_date for info in confirmed_infos] @@ -187,26 +182,20 @@ def get_all_dag_task_query( dag: DAG, session: SASession, state: TaskInstanceState, - task_ids: Union[List[str], List[Tuple[str, int]]], + task_id_map_index_list: List[Tuple[str, int]], confirmed_dates: Iterable[datetime], ): """Get all tasks of the main dag that will be affected by a state change""" - is_string_list = isinstance(task_ids[0], str) qry_dag = ( session.query(TaskInstance) .join(TaskInstance.dag_run) .filter( TaskInstance.dag_id == dag.dag_id, DagRun.execution_date.in_(confirmed_dates), + tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_id_map_index_list), ) - ) - - if is_string_list: - qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids)) - else: - qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids)) - qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options( - contains_eager(TaskInstance.dag_run) + .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)) + .options(contains_eager(TaskInstance.dag_run)) ) return qry_dag @@ -282,13 +271,14 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run)) -def find_task_relatives(tasks, downstream, upstream): +@provide_session +def find_task_relatives(tasks, downstream, upstream, session: SASession = NEW_SESSION): """Yield task ids and optionally ancestor and descendant ids.""" for item in tasks: if isinstance(item, tuple): task, map_index = item else: - task, map_index = item, None + task, map_index = item, -1 yield task.task_id, map_index if downstream: for relative in task.get_flat_relatives(upstream=False): diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e9c33acb72..931fd469d7 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -39,7 +39,6 @@ from typing import ( Iterable, List, Optional, - Sequence, Set, Tuple, Type, @@ -52,7 +51,7 @@ import jinja2 import pendulum from dateutil.relativedelta import relativedelta from pendulum.tz.timezone import Timezone -from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_ +from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_ from sqlalchemy.orm import backref, joinedload, relationship from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session @@ -85,7 +84,7 @@ from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import exactly_one, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks +from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType @@ -1341,33 +1340,47 @@ class DAG(LoggingMixin): start_date = (timezone.utcnow() - timedelta(30)).replace( hour=0, minute=0, second=0, microsecond=0 ) - query = self._get_task_instances( - task_ids=None, - start_date=start_date, - end_date=end_date, - run_id=None, - state=state or (), - include_subdags=False, - include_parentdag=False, - include_dependent_dags=False, - exclude_task_ids=(), - session=session, + + if state is None: + state = [] + + return ( + cast( + Query, + self._get_task_instances( + task_ids=None, + task_ids_and_map_indexes=None, + start_date=start_date, + end_date=end_date, + run_id=None, + state=state, + include_subdags=False, + include_parentdag=False, + include_dependent_dags=False, + exclude_task_ids=cast(List[str], []), + exclude_task_ids_and_map_indexes=None, + session=session, + ), + ) + .order_by(DagRun.execution_date) + .all() ) - return cast(Query, query).order_by(DagRun.execution_date).all() @overload def _get_task_instances( self, *, - task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], + task_ids: Iterable[str], + task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]], start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], - state: Union[TaskInstanceState, Sequence[TaskInstanceState]], + state: Union[TaskInstanceState, List[TaskInstanceState]], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], + exclude_task_ids: Collection[str], + exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]], session: Session, dag_bag: Optional["DagBag"] = ..., ) -> Iterable[TaskInstance]: @@ -1377,16 +1390,18 @@ class DAG(LoggingMixin): def _get_task_instances( self, *, - task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], + task_ids: Iterable[str], + task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]], as_pk_tuple: Literal[True], start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], - state: Union[TaskInstanceState, Sequence[TaskInstanceState]], + state: Union[TaskInstanceState, List[TaskInstanceState]], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], + exclude_task_ids: Collection[str], + exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]], session: Session, dag_bag: Optional["DagBag"] = ..., recursion_depth: int = ..., @@ -1398,16 +1413,18 @@ class DAG(LoggingMixin): def _get_task_instances( self, *, - task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], + task_ids: Iterable[str], + task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]], as_pk_tuple: Literal[True, None] = None, start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], - state: Union[TaskInstanceState, Sequence[TaskInstanceState]], + state: Union[TaskInstanceState, List[TaskInstanceState]], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], + exclude_task_ids: Collection[str], + exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]], session: Session, dag_bag: Optional["DagBag"] = None, recursion_depth: int = 0, @@ -1431,6 +1448,11 @@ class DAG(LoggingMixin): tis = session.query(TaskInstance) tis = tis.join(TaskInstance.dag_run) + if task_ids is not None: # task not mapped + task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids] + if exclude_task_ids and len(exclude_task_ids) > 0: # task not mapped + exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids] + if include_subdags: # Crafting the right filter for dag_id and task_ids combo conditions = [] @@ -1445,13 +1467,10 @@ class DAG(LoggingMixin): tis = tis.filter(TaskInstance.run_id == run_id) if start_date: tis = tis.filter(DagRun.execution_date >= start_date) - - if task_ids is None: - pass # Disable filter if not set. - elif isinstance(next(iter(task_ids), None), str): - tis = tis.filter(TI.task_id.in_(task_ids)) - else: - tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids)) + if task_ids_and_map_indexes: + tis = tis.filter( + tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes) + ) # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC if end_date or not self.allow_future_exec_dates: @@ -1490,6 +1509,7 @@ class DAG(LoggingMixin): result.update( p_dag._get_task_instances( task_ids=task_ids, + task_ids_and_map_indexes=task_ids_and_map_indexes, start_date=start_date, end_date=end_date, run_id=None, @@ -1499,6 +1519,7 @@ class DAG(LoggingMixin): include_dependent_dags=include_dependent_dags, as_pk_tuple=True, exclude_task_ids=exclude_task_ids, + exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes, session=session, dag_bag=dag_bag, recursion_depth=recursion_depth, @@ -1566,7 +1587,7 @@ class DAG(LoggingMixin): ) result.update( downstream._get_task_instances( - task_ids=None, + task_ids_and_map_indexes=None, run_id=tii.run_id, start_date=None, end_date=None, @@ -1575,7 +1596,7 @@ class DAG(LoggingMixin): include_dependent_dags=include_dependent_dags, include_parentdag=False, as_pk_tuple=True, - exclude_task_ids=exclude_task_ids, + exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes, dag_bag=dag_bag, session=session, recursion_depth=recursion_depth + 1, @@ -1589,29 +1610,25 @@ class DAG(LoggingMixin): if as_pk_tuple: result.update(TaskInstanceKey(*cols) for cols in tis.all()) else: - result.update(ti.key for ti in tis) - - if exclude_task_ids is not None: - result = { - task - for task in result - if task.task_id not in exclude_task_ids - and (task.task_id, task.map_index) not in exclude_task_ids - } + result.update(ti.key for ti in tis.all()) + + if exclude_task_ids_and_map_indexes: + result = set( + filter( + lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes, + result, + ) + ) if as_pk_tuple: return result - if result: + elif result: # We've been asked for objects, lets combine it all back in to a result set - ti_filters = TI.filter_for_tis(result) - if ti_filters is not None: - tis = session.query(TI).filter(ti_filters) - elif exclude_task_ids is None: - pass # Disable filter if not set. - elif isinstance(next(iter(exclude_task_ids), None), str): - tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) - else: - tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) + tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) + + tis = session.query(TI).filter(TI.filter_for_tis(result)) + elif exclude_task_ids_and_map_indexes: + tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes)) return tis @@ -1620,7 +1637,7 @@ class DAG(LoggingMixin): self, *, task_id: str, - map_indexes: Optional[Collection[int]] = None, + map_indexes: Optional[Iterable[int]] = None, execution_date: Optional[datetime] = None, run_id: Optional[str] = None, state: TaskInstanceState, @@ -1636,8 +1653,7 @@ class DAG(LoggingMixin): in failed or upstream_failed state. :param task_id: Task ID of the TaskInstance - :param map_indexes: Only set TaskInstance if its map_index matches. - If None (default), all mapped TaskInstances of the task are set. + :param map_indexes: Task instance map_index to set the state of :param execution_date: Execution date of the TaskInstance :param run_id: The run_id of the TaskInstance :param state: State to set the TaskInstance to @@ -1662,18 +1678,13 @@ class DAG(LoggingMixin): task = self.get_task(task_id) task.dag = self - - tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]] - task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]] - if map_indexes is None: - tasks_to_set_state = [task] - task_ids_to_exclude_from_clear = {task_id} - else: - tasks_to_set_state = [(task, map_index) for map_index in map_indexes] - task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes} + if not map_indexes: + map_indexes = [-1] + task_map_indexes = [(task, map_index) for map_index in map_indexes] + task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes] altered = set_state( - tasks=tasks_to_set_state, + tasks=task_map_indexes, execution_date=execution_date, run_id=run_id, upstream=upstream, @@ -1703,13 +1714,12 @@ class DAG(LoggingMixin): subdag.clear( start_date=start_date, end_date=end_date, - map_indexes=map_indexes, include_subdags=True, include_parentdag=True, only_failed=True, session=session, # Exclude the task itself from being cleared - exclude_task_ids=task_ids_to_exclude_from_clear, + exclude_task_ids_and_map_indexes=task_id_map_indexes, ) return altered @@ -1767,7 +1777,8 @@ class DAG(LoggingMixin): @provide_session def clear( self, - task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None, + task_ids=None, + task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, only_failed: bool = False, @@ -1782,13 +1793,14 @@ class DAG(LoggingMixin): recursion_depth: int = 0, max_recursion_depth: Optional[int] = None, dag_bag: Optional["DagBag"] = None, - exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(), + exclude_task_ids: FrozenSet[str] = frozenset(), + exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = frozenset({}), ) -> Union[int, Iterable[TaskInstance]]: """ Clears a set of task instances associated with the current dag for a specified date range. - :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear + :param task_ids_and_map_indexes: List of tuple of task_id, map_index to clear :param start_date: The minimum execution_date to clear :param end_date: The maximum execution_date to clear :param only_failed: Only clear failed tasks @@ -1802,7 +1814,8 @@ class DAG(LoggingMixin): :param dry_run: Find the tasks to clear but don't clear them. :param session: The sqlalchemy session to use :param dag_bag: The DagBag used to find the dags subdags (Optional) - :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) + :param exclude_task_ids: A set of ``task_id`` that should not be cleared + :param exclude_task_ids_and_map_indexes: A set of ``task_id``,``map_index`` tuples that should not be cleared """ if get_tis: @@ -1832,9 +1845,12 @@ class DAG(LoggingMixin): if only_running: # Yes, having `+=` doesn't make sense, but this was the existing behaviour state += [State.RUNNING] + if task_ids: + task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids] tis = self._get_task_instances( task_ids=task_ids, + task_ids_and_map_indexes=task_ids_and_map_indexes, start_date=start_date, end_date=end_date, run_id=None, @@ -1845,6 +1861,7 @@ class DAG(LoggingMixin): session=session, dag_bag=dag_bag, exclude_task_ids=exclude_task_ids, + exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes, ) if dry_run: diff --git a/airflow/www/views.py b/airflow/www/views.py index 5e7f7a01e6..7b0f81af80 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView): start_date, end_date, origin, - map_indexes=None, + task_id_map_index_list=None, recursive=False, confirmed=False, only_failed=False, @@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView): count = dag.clear( start_date=start_date, end_date=end_date, - map_indexes=map_indexes, + task_ids_and_map_indexes=task_id_map_index_list, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView): tis = dag.clear( start_date=start_date, end_date=end_date, - map_indexes=map_indexes, + task_ids_and_map_indexes=task_id_map_index_list, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -2026,9 +2026,14 @@ class Airflow(AirflowBaseView): task_id = request.form.get('task_id') origin = get_safe_url(request.form.get('origin')) dag = current_app.dag_bag.get_dag(dag_id) + map_indexes = request.form.get('map_indexes') - if map_indexes and not isinstance(map_indexes, list): - map_indexes = list(map_indexes) + if map_indexes: + if not isinstance(map_indexes, list): + map_indexes = list(map_indexes) + else: + map_indexes = [-1] + task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes] execution_date = request.form.get('execution_date') execution_date = timezone.parse(execution_date) @@ -2053,7 +2058,7 @@ class Airflow(AirflowBaseView): start_date, end_date, origin, - map_indexes=map_indexes, + task_id_map_index_list=task_id_map_indexes, recursive=recursive, confirmed=confirmed, only_failed=only_failed, @@ -2072,9 +2077,6 @@ class Airflow(AirflowBaseView): dag_id = request.form.get('dag_id') dag_run_id = request.form.get('dag_run_id') confirmed = request.form.get('confirmed') == "true" - map_indexes = request.form.get('map_indexes') - if map_indexes and not isinstance(map_indexes, list): - map_indexes = list(map_indexes) dag = current_app.dag_bag.get_dag(dag_id) dr = dag.get_dagrun(run_id=dag_run_id) @@ -2085,7 +2087,6 @@ class Airflow(AirflowBaseView): dag, start_date, end_date, - map_indexes=map_indexes, origin=None, recursive=True, confirmed=confirmed, @@ -2339,8 +2340,11 @@ class Airflow(AirflowBaseView): state = args.get('state') origin = args.get('origin') map_indexes = args.get('map_indexes') - if map_indexes and not isinstance(map_indexes, list): - map_indexes = list(map_indexes) + if map_indexes: + if not isinstance(map_indexes, list): + map_indexes = list(map_indexes) + else: + map_indexes = [-1] upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2376,7 +2380,7 @@ class Airflow(AirflowBaseView): from airflow.api.common.mark_tasks import set_state to_be_altered = set_state( - tasks=[task], + tasks=[(task, map_index) for map_index in map_indexes], map_indexes=map_indexes, run_id=dag_run_id, upstream=upstream, @@ -2418,8 +2422,11 @@ class Airflow(AirflowBaseView): origin = get_safe_url(args.get('origin')) dag_run_id = args.get('dag_run_id') map_indexes = args.get('map_indexes') - if map_indexes and not isinstance(map_indexes, list): - map_indexes = list(map_indexes) + if map_indexes: + if not isinstance(map_indexes, list): + map_indexes = list(map_indexes) + else: + map_indexes = [-1] upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2455,8 +2462,11 @@ class Airflow(AirflowBaseView): origin = get_safe_url(args.get('origin')) dag_run_id = args.get('dag_run_id') map_indexes = args.get('map_indexes') - if map_indexes and not isinstance(map_indexes, list): - map_indexes = list(map_indexes) + if map_indexes: + if not isinstance(map_indexes, list): + map_indexes = list(map_indexes) + else: + map_indexes = [-1] upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py index dedb624b1f..4c1d3c604b 100644 --- a/tests/api/common/test_mark_tasks.py +++ b/tests/api/common/test_mark_tasks.py @@ -439,12 +439,12 @@ class TestMarkTasks: def test_mark_mapped_task_instance_state(self): # set mapped task instance to success snapshot = TestMarkTasks.snapshot_state(self.dag4, self.execution_dates) - tasks = [self.dag4.get_task("consumer_literal")] + task = self.dag4.get_task("consumer_literal") + tasks = [(task, 0), (task, 1)] map_indexes = [0, 1] dr = DagRun.find(dag_id=self.dag4.dag_id, execution_date=self.execution_dates[0])[0] altered = set_state( tasks=tasks, - map_indexes=map_indexes, run_id=dr.run_id, upstream=False, downstream=False, @@ -456,7 +456,7 @@ class TestMarkTasks: assert len(altered) == 2 self.verify_state( self.dag4, - [task.task_id for task in tasks], + [task.task_id for task, _ in tasks], [self.execution_dates[0]], State.SUCCESS, snapshot, diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 41219ef87b..ed2119a490 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase): session.flush() dag.clear( - map_indexes=[0], + task_ids_and_map_indexes=[(task_id, 0)], start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=1), dag_run_state=dag_run_state,
