This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 12638d2310 Refactor `DagRun.verify_integrity` (#24114)
12638d2310 is described below
commit 12638d2310d962986b43af8f1584a405e280badf
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Jun 10 14:44:19 2022 +0100
Refactor `DagRun.verify_integrity` (#24114)
This refactoring became necessary because additional code must be added
to the existing code to handle mapped task immutability during a run.
That additional code would make this method difficult to read.
Refactoring the code will aid understanding and help in debugging.
---
airflow/models/dagrun.py | 102 ++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 88 insertions(+), 14 deletions(-)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index c66e24c536..216272c79a 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -23,6 +23,7 @@ from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
+ Callable,
Dict,
Generator,
Iterable,
@@ -30,6 +31,7 @@ from typing import (
NamedTuple,
Optional,
Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -818,13 +820,50 @@ class DagRun(Base, LoggingMixin):
"""
from airflow.settings import task_instance_mutation_hook
+ # Set for the empty default in airflow.settings -- if it's not set
this means it has been changed
+ hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
+
dag = self.get_dag()
+ task_ids = self._check_for_removed_or_restored_tasks(
+ dag, task_instance_mutation_hook, session=session
+ )
+
+ def task_filter(task: "Operator") -> bool:
+ return task.task_id not in task_ids and (
+ self.is_backfill
+ or task.start_date <= self.execution_date
+ and (task.end_date is None or self.execution_date <=
task.end_date)
+ )
+
+ created_counts: Dict[str, int] = defaultdict(int)
+
+ # Get task creator function
+ task_creator = self._get_task_creator(created_counts,
task_instance_mutation_hook, hook_is_noop)
+
+ # Create the missing tasks, including mapped tasks
+ tasks = self._create_missing_tasks(dag, task_creator, task_filter,
session=session)
+
+ self._create_task_instances(dag.dag_id, tasks, created_counts,
hook_is_noop, session=session)
+
+ def _check_for_removed_or_restored_tasks(
+ self, dag: "DAG", ti_mutation_hook, *, session: Session
+ ) -> Set[str]:
+ """
+ Check for removed tasks/restored tasks.
+
+ :param dag: DAG object corresponding to the dagrun
+ :param ti_mutation_hook: task_instance_mutation_hook function
+ :param session: Sqlalchemy ORM Session
+
+ :return: List of task_ids in the dagrun
+
+ """
tis = self.get_task_instances(session=session)
# check for removed or restored tasks
task_ids = set()
for ti in tis:
- task_instance_mutation_hook(ti)
+ ti_mutation_hook(ti)
task_ids.add(ti.task_id)
task = None
try:
@@ -885,19 +924,21 @@ class DagRun(Base, LoggingMixin):
)
ti.state = State.REMOVED
...
+ return task_ids
- def task_filter(task: "Operator") -> bool:
- return task.task_id not in task_ids and (
- self.is_backfill
- or task.start_date <= self.execution_date
- and (task.end_date is None or self.execution_date <=
task.end_date)
- )
+ def _get_task_creator(
+ self, created_counts: Dict[str, int], ti_mutation_hook: Callable,
hook_is_noop: bool
+ ) -> Callable:
+ """
+ Get the task creator function.
- created_counts: Dict[str, int] = defaultdict(int)
+ This function also updates the created_counts dictionary with the
number of tasks created.
- # Set for the empty default in airflow.settings -- if it's not set
this means it has been changed
- hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
+ :param created_counts: Dictionary of task_type -> count of created TIs
+ :param ti_mutation_hook: task_instance_mutation_hook function
+ :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
+ """
if hook_is_noop:
def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...])
-> Generator:
@@ -912,13 +953,25 @@ class DagRun(Base, LoggingMixin):
def create_ti(task: "Operator", indexes: Tuple[int, ...]) ->
Generator:
for map_index in indexes:
ti = TI(task, run_id=self.run_id, map_index=map_index)
- task_instance_mutation_hook(ti)
+ ti_mutation_hook(ti)
created_counts[ti.operator] += 1
yield ti
creator = create_ti
+ return creator
+
+ def _create_missing_tasks(
+ self, dag: "DAG", task_creator: Callable, task_filter: Callable, *,
session: Session
+ ) -> Iterable["Operator"]:
+ """
+ Create missing tasks -- and expand any MappedOperator that _only_ have
literals as input
+
+ :param dag: DAG object corresponding to the dagrun
+ :param task_creator: a function that creates tasks
+ :param task_filter: a function that filters tasks to create
+ :param session: the session to use
+ """
- # Create missing tasks -- and expand any MappedOperator that _only_
have literals as input
def expand_mapped_literals(task: "Operator") -> Tuple["Operator",
Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
@@ -931,8 +984,29 @@ class DagRun(Base, LoggingMixin):
return (task, range(count))
tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter,
dag.task_dict.values()))
- tasks = itertools.chain.from_iterable(itertools.starmap(creator,
tasks_and_map_idxs))
+ tasks = itertools.chain.from_iterable(itertools.starmap(task_creator,
tasks_and_map_idxs))
+ return tasks
+
+ def _create_task_instances(
+ self,
+ dag_id: str,
+ tasks: Iterable["Operator"],
+ created_counts: Dict[str, int],
+ hook_is_noop: bool,
+ *,
+ session: Session,
+ ) -> None:
+ """
+ Create the necessary task instances from the given tasks.
+
+ :param dag_id: DAG ID associated with the dagrun
+ :param tasks: the tasks to create the task instances from
+ :param created_counts: a dictionary of number of tasks -> total ti
created by the task creator
+ :param hook_is_noop: whether the task_instance_mutation_hook is noop
+ :param session: the session to use
+
+ """
try:
if hook_is_noop:
session.bulk_insert_mappings(TI, tasks)
@@ -945,7 +1019,7 @@ class DagRun(Base, LoggingMixin):
except IntegrityError:
self.log.info(
'Hit IntegrityError while creating the TIs for %s- %s',
- dag.dag_id,
+ dag_id,
self.run_id,
exc_info=True,
)