This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7a78a3efdb6820bcf4b3ee3fef27d8c1e4454b5d Author: Ephraim Anierobi <[email protected]> AuthorDate: Tue Apr 12 16:37:35 2022 +0100 Allow marking/clearing mapped taskinstances from the UI --- airflow/api/common/mark_tasks.py | 7 ++++++- airflow/models/dag.py | 18 ++++++++++++++++-- airflow/www/views.py | 21 +++++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index d11f490247..fe9fa0f490 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -79,6 +79,7 @@ def _create_dagruns( def set_state( *, tasks: Iterable[Operator], + map_indexes: Optional[Iterable[int]] = None, run_id: Optional[str] = None, execution_date: Optional[datetime] = None, upstream: bool = False, @@ -97,6 +98,7 @@ def set_state( on the schedule (but it will as for subdag dag runs if needed). :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set + :param map_indexes: the map indexes of the tasks to set :param run_id: the run_id of the dagrun to start looking from :param execution_date: the execution date from which to start looking(deprecated) :param upstream: Mark all parents (upstream tasks) @@ -143,7 +145,7 @@ def set_state( # now look for the task instances that are affected - qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates) + qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates, map_indexes) if commit: tis_altered = qry_dag.with_for_update().all() @@ -181,6 +183,7 @@ def get_all_dag_task_query( state: TaskInstanceState, task_ids: List[str], confirmed_dates: Iterable[datetime], + map_indexes: Optional[Iterable[int]] = None, ): """Get all tasks of the main dag that will be affected by a state change""" qry_dag = ( @@ -194,6 +197,8 @@ def get_all_dag_task_query( .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)) 
.options(contains_eager(TaskInstance.dag_run)) ) + if map_indexes: + qry_dag = qry_dag.filter(TaskInstance.map_index.in_(map_indexes)) return qry_dag diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 8d5e8eacd6..8efedbef60 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1348,6 +1348,7 @@ class DAG(LoggingMixin): Query, self._get_task_instances( task_ids=None, + map_indexes=None, start_date=start_date, end_date=end_date, run_id=None, @@ -1368,6 +1369,7 @@ class DAG(LoggingMixin): self, *, task_ids, + map_indexes: Optional[Iterable[int]] = None, start_date: Optional[datetime], end_date: Optional[datetime], run_id: Optional[str], @@ -1386,6 +1388,7 @@ class DAG(LoggingMixin): self, *, task_ids, + map_indexes: Optional[Iterable[int]] = None, as_pk_tuple: Literal[True], start_date: Optional[datetime], end_date: Optional[datetime], @@ -1407,6 +1410,7 @@ class DAG(LoggingMixin): self, *, task_ids, + map_indexes: Optional[Iterable[int]] = None, as_pk_tuple: Literal[True, None] = None, start_date: Optional[datetime], end_date: Optional[datetime], @@ -1434,7 +1438,7 @@ class DAG(LoggingMixin): # Do we want full objects, or just the primary columns? 
if as_pk_tuple: - tis = session.query(TI.dag_id, TI.task_id, TI.run_id) + tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) else: tis = session.query(TaskInstance) tis = tis.join(TaskInstance.dag_run) @@ -1455,6 +1459,8 @@ class DAG(LoggingMixin): tis = tis.filter(DagRun.execution_date >= start_date) if task_ids: tis = tis.filter(TaskInstance.task_id.in_(task_ids)) + if map_indexes: + tis = tis.filter(TaskInstance.map_index.in_(map_indexes)) # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC if end_date or not self.allow_future_exec_dates: @@ -1493,6 +1499,7 @@ class DAG(LoggingMixin): result.update( p_dag._get_task_instances( task_ids=task_ids, + map_indexes=map_indexes, start_date=start_date, end_date=end_date, run_id=None, @@ -1570,6 +1577,7 @@ class DAG(LoggingMixin): result.update( downstream._get_task_instances( task_ids=None, + map_indexes=None, run_id=tii.run_id, start_date=None, end_date=None, @@ -1606,7 +1614,7 @@ class DAG(LoggingMixin): return result elif result: # We've been asked for objects, lets combine it all back in to a result set - tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id) + tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) tis = session.query(TI).filter(TI.filter_for_tis(result)) elif exclude_task_ids: @@ -1619,6 +1627,7 @@ class DAG(LoggingMixin): self, *, task_id: str, + map_indexes: Optional[Iterable[int]] = None, execution_date: Optional[datetime] = None, run_id: Optional[str] = None, state: TaskInstanceState, @@ -1634,6 +1643,7 @@ class DAG(LoggingMixin): in failed or upstream_failed state. 
:param task_id: Task ID of the TaskInstance + :param map_indexes: Task instance map_index to set the state of :param execution_date: Execution date of the TaskInstance :param run_id: The run_id of the TaskInstance :param state: State to set the TaskInstance to @@ -1661,6 +1671,7 @@ class DAG(LoggingMixin): altered = set_state( tasks=[task], + map_indexes=map_indexes, execution_date=execution_date, run_id=run_id, upstream=upstream, @@ -1754,6 +1765,7 @@ class DAG(LoggingMixin): def clear( self, task_ids=None, + map_indexes: Optional[Iterable[int]] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, only_failed: bool = False, @@ -1775,6 +1787,7 @@ class DAG(LoggingMixin): a specified date range. :param task_ids: List of task ids to clear + :param map_indexes: List of map_indexes to clear :param start_date: The minimum execution_date to clear :param end_date: The maximum execution_date to clear :param only_failed: Only clear failed tasks @@ -1820,6 +1833,7 @@ class DAG(LoggingMixin): tis = self._get_task_instances( task_ids=task_ids, + map_indexes=map_indexes, start_date=start_date, end_date=end_date, run_id=None, diff --git a/airflow/www/views.py b/airflow/www/views.py index 953164e94a..ad9378f462 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1980,6 +1980,7 @@ class Airflow(AirflowBaseView): start_date, end_date, origin, + map_indexes=None, recursive=False, confirmed=False, only_failed=False, @@ -1988,6 +1989,7 @@ class Airflow(AirflowBaseView): count = dag.clear( start_date=start_date, end_date=end_date, + map_indexes=map_indexes, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -2000,6 +2002,7 @@ class Airflow(AirflowBaseView): tis = dag.clear( start_date=start_date, end_date=end_date, + map_indexes=map_indexes, include_subdags=recursive, include_parentdag=recursive, only_failed=only_failed, @@ -2041,6 +2044,9 @@ class Airflow(AirflowBaseView): task_id = 
request.form.get('task_id') origin = get_safe_url(request.form.get('origin')) dag = current_app.dag_bag.get_dag(dag_id) + map_indexes = request.form.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) execution_date = request.form.get('execution_date') execution_date = timezone.parse(execution_date) @@ -2065,6 +2071,7 @@ class Airflow(AirflowBaseView): start_date, end_date, origin, + map_indexes=map_indexes, recursive=recursive, confirmed=confirmed, only_failed=only_failed, @@ -2083,6 +2090,9 @@ class Airflow(AirflowBaseView): dag_id = request.form.get('dag_id') dag_run_id = request.form.get('dag_run_id') confirmed = request.form.get('confirmed') == "true" + map_indexes = request.form.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) dag = current_app.dag_bag.get_dag(dag_id) dr = dag.get_dagrun(run_id=dag_run_id) @@ -2093,6 +2103,7 @@ class Airflow(AirflowBaseView): dag, start_date, end_date, + map_indexes=map_indexes, origin=None, recursive=True, confirmed=confirmed, @@ -2299,6 +2310,7 @@ class Airflow(AirflowBaseView): self, dag_id, task_id, + map_indexes, origin, dag_run_id, upstream, @@ -2316,6 +2328,7 @@ class Airflow(AirflowBaseView): altered = dag.set_task_instance_state( task_id=task_id, + map_indexes=map_indexes, run_id=dag_run_id, state=state, upstream=upstream, @@ -2418,6 +2431,9 @@ class Airflow(AirflowBaseView): task_id = args.get('task_id') origin = get_safe_url(args.get('origin')) dag_run_id = args.get('dag_run_id') + map_indexes = args.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2427,6 +2443,7 @@ class Airflow(AirflowBaseView): return self._mark_task_instance_state( dag_id, task_id, + map_indexes, origin, dag_run_id, upstream, @@ -2451,6 +2468,9 @@ class Airflow(AirflowBaseView): 
task_id = args.get('task_id') origin = get_safe_url(args.get('origin')) dag_run_id = args.get('dag_run_id') + map_indexes = args.get('map_indexes') + if map_indexes and not isinstance(map_indexes, list): + map_indexes = list(map_indexes) upstream = to_boolean(args.get('upstream')) downstream = to_boolean(args.get('downstream')) @@ -2460,6 +2480,7 @@ class Airflow(AirflowBaseView): return self._mark_task_instance_state( dag_id, task_id, + map_indexes, origin, dag_run_id, upstream,
