tquazi commented on code in PR #30478:
URL: https://github.com/apache/airflow/pull/30478#discussion_r1167617392
##########
airflow/models/dag.py:
##########
@@ -1915,9 +1914,118 @@ def set_task_instance_state(
only_failed=True,
session=session,
# Exclude the task itself from being cleared
- exclude_task_ids={task_id},
+ exclude_task_ids=frozenset({task_id}),
+ )
+
+ return altered
+
+ @provide_session
+ def set_task_group_state(
+ self,
+ *,
+ group_id: str,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
+ state: TaskInstanceState,
+ upstream: bool = False,
+ downstream: bool = False,
+ future: bool = False,
+ past: bool = False,
+ commit: bool = True,
+ session: Session = NEW_SESSION,
+ ) -> list[TaskInstance]:
+ """
+ Set the state of the TaskGroup to the given state, and clear its downstream tasks that are
+ in failed or upstream_failed state.
+
+ :param group_id: The group_id of the TaskGroup
+ :param execution_date: Execution date of the TaskInstance
+ :param run_id: The run_id of the TaskInstance
+ :param state: State to set the TaskInstance to
+ :param upstream: Include all upstream tasks of the given task_id
+ :param downstream: Include all downstream tasks of the given task_id
+ :param future: Include all future TaskInstances of the given task_id
+ :param commit: Commit changes
+ :param past: Include all past TaskInstances of the given task_id
+ :param session: new session
+ """
+ from airflow.api.common.mark_tasks import set_state
+
+ if not exactly_one(execution_date, run_id):
+ raise ValueError("Exactly one of execution_date or run_id must be
provided")
+
+ tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
+ task_ids: list[str] = []
+ locked_dag_run_ids: list[int] = []
+
+ if execution_date is None:
+ dag_run = (
+ session.query(DagRun).filter(DagRun.run_id == run_id,
DagRun.dag_id == self.dag_id).one()
+ ) # Raises an error if not found
+ resolve_execution_date = dag_run.execution_date
+ else:
+ resolve_execution_date = execution_date
+
+ end_date = resolve_execution_date if not future else None
+ start_date = resolve_execution_date if not past else None
+
+ task_group_dict = self.task_group.get_task_group_dict()
+ task_group = task_group_dict.get(group_id)
+ if task_group is None:
+ raise ValueError("TaskGroup {group_id} could not be found")
+ tasks_to_set_state = [task for task in task_group.iter_tasks() if
isinstance(task, BaseOperator)]
+ task_ids = [task.task_id for task in task_group.iter_tasks()]
+ dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id ==
self.dag_id).with_for_update()
+
+ if start_date is None and end_date is None:
+ dag_runs_query = dag_runs_query.filter(DagRun.execution_date ==
start_date)
+ else:
+ if start_date is not None:
+ dag_runs_query = dag_runs_query.filter(DagRun.execution_date
>= start_date)
+
+ if end_date is not None:
+ dag_runs_query = dag_runs_query.filter(DagRun.execution_date
<= end_date)
+
+ locked_dag_run_ids = dag_runs_query.all()
+
+ altered = set_state(
+ tasks=tasks_to_set_state,
+ execution_date=execution_date,
+ run_id=run_id,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=state,
+ commit=commit,
+ session=session,
+ )
+
+ if not commit:
+ del locked_dag_run_ids
+ return altered
+
+ # Clear downstream tasks that are in failed/upstream_failed state to resume them.
+ # Flush the session so that the tasks marked success are reflected in the db.
+ session.flush()
+ subdag = self.partial_subset(
Review Comment:
Changed the name of the variable :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]