This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3b12a914db3962b40141498f9eb43af91ecd80b2 Author: Tzu-ping Chung <[email protected]> AuthorDate: Mon Apr 18 22:13:10 2022 +0800 Refactor to straighten up types --- airflow/models/dag.py | 122 +++++++++++++++++++++----------------------------- 1 file changed, 52 insertions(+), 70 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 0694f37550..755505b5d0 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -39,6 +39,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, Type, @@ -1340,40 +1341,29 @@ class DAG(LoggingMixin): start_date = (timezone.utcnow() - timedelta(30)).replace( hour=0, minute=0, second=0, microsecond=0 ) - - if state is None: - state = [] - - return ( - cast( - Query, - self._get_task_instances( - task_ids=None, - start_date=start_date, - end_date=end_date, - run_id=None, - state=state, - include_subdags=False, - include_parentdag=False, - include_dependent_dags=False, - exclude_task_ids=cast(List[str], []), - session=session, - ), - ) - .order_by(DagRun.execution_date) - .all() + query = self._get_task_instances( + task_ids=None, + start_date=start_date, + end_date=end_date, + run_id=None, + state=state or (), + include_subdags=False, + include_parentdag=False, + include_dependent_dags=False, + exclude_task_ids=(), + session=session, ) + return cast(Query, query).order_by(DagRun.execution_date).all() @overload def _get_task_instances( self, *, task_ids: Union[Collection[str], Collection[Tuple[str, int]], None], - task_ids_and_map_indexes, start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], - state: Union[TaskInstanceState, List[TaskInstanceState]], + state: Union[TaskInstanceState, Sequence[TaskInstanceState]], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, @@ -1392,7 +1382,7 @@ class DAG(LoggingMixin): start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], - state: Union[TaskInstanceState, List[TaskInstanceState]], + state: Union[TaskInstanceState, Sequence[TaskInstanceState]], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, @@ -1413,7 +1403,7 @@ class DAG(LoggingMixin): start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], - state: Union[TaskInstanceState, List[TaskInstanceState]], + state: Union[TaskInstanceState, Sequence[TaskInstanceState]], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, @@ -1441,18 +1431,6 @@ class DAG(LoggingMixin): tis = session.query(TaskInstance) tis = tis.join(TaskInstance.dag_run) - task_ids_and_map_indexes = None - if task_ids is not None: - task_ids_and_map_indexes = [item for item in task_ids if isinstance(item, tuple)] - if task_ids_and_map_indexes: - task_ids = None # nullify since we have indexes - - exclude_task_ids_and_map_indexes = None - if exclude_task_ids is not None: - exclude_task_ids_and_map_indexes = [item for item in exclude_task_ids if isinstance(item, tuple)] - if exclude_task_ids_and_map_indexes: - exclude_task_ids = None - if include_subdags: # Crafting the right filter for dag_id and task_ids combo conditions = [] @@ -1467,12 +1445,13 @@ class DAG(LoggingMixin): tis = tis.filter(TaskInstance.run_id == run_id) if start_date: tis = tis.filter(DagRun.execution_date >= start_date) - if task_ids: - tis = tis.filter(TaskInstance.task_id.in_(task_ids)) - if task_ids_and_map_indexes: - tis = tis.filter( - tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes) - ) + + if task_ids is None: + pass # Disable filter if not set. + elif isinstance(next(iter(task_ids), None), str): + tis = tis.filter(TI.task_id.in_(task_ids)) + else: + tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids)) # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC if end_date or not self.allow_future_exec_dates: @@ -1610,33 +1589,29 @@ class DAG(LoggingMixin): if as_pk_tuple: result.update(TaskInstanceKey(*cols) for cols in tis.all()) else: - result.update(ti.key for ti in tis.all()) + result.update(ti.key for ti in tis) if exclude_task_ids is not None: - result = set( - filter( - lambda key: key.task_id not in exclude_task_ids, - result, - ) - ) - - if exclude_task_ids_and_map_indexes is not None: - result = set( - filter( - lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes, - result, - ) - ) + result = { + task + for task in result + if task.task_id not in exclude_task_ids + and (task.task_id, task.map_index) not in exclude_task_ids + } if as_pk_tuple: return result - elif result: + if result: # We've been asked for objects, lets combine it all back in to a result set - tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) - - tis = session.query(TI).filter(TI.filter_for_tis(result)) - elif exclude_task_ids_and_map_indexes: - tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes)) + ti_filters = TI.filter_for_tis(result) + if ti_filters is not None: + tis = session.query(TI).filter(ti_filters) + elif exclude_task_ids is None: + pass # Disable filter if not set. + elif isinstance(next(iter(exclude_task_ids), None), str): + tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) + else: + tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids)) return tis @@ -1687,11 +1662,18 @@ class DAG(LoggingMixin): task = self.get_task(task_id) task.dag = self - task_map_indexes = [(task, map_index)] if map_index else [task] - task_id_map_indexes = {(task_id, map_index)} if map_index else {task_id} + + tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]] + task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]] + if map_index is None: + tasks_to_set_state = [task] + task_ids_to_exclude_from_clear = {task_id} + else: + tasks_to_set_state = [(task, map_index)] + task_ids_to_exclude_from_clear = {(task_id, map_index)} altered = set_state( - tasks=task_map_indexes, + tasks=tasks_to_set_state, execution_date=execution_date, run_id=run_id, upstream=upstream, @@ -1726,7 +1708,7 @@ class DAG(LoggingMixin): only_failed=True, session=session, # Exclude the task itself from being cleared - exclude_task_ids=task_id_map_indexes, + exclude_task_ids=task_ids_to_exclude_from_clear, ) return altered @@ -1784,7 +1766,7 @@ class DAG(LoggingMixin): @provide_session def clear( self, - task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None, + task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, only_failed: bool = False,
