This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 6fc6edf  Make `airflow dags test` be able to execute Mapped Tasks (#21210)
6fc6edf is described below

commit 6fc6edf6af7f676bfa54ff3a2e6e6d2edb938f2e
Author: Ash Berlin-Taylor <ash_git...@firemirror.com>
AuthorDate: Fri Feb 4 14:24:32 2022 +0000

    Make `airflow dags test` be able to execute Mapped Tasks (#21210)

    * Make `airflow dags test` be able to execute Mapped Tasks

      In order to do this there were two steps required:

      - The BackfillJob needs to know about mapped tasks, both to expand
        them, and in order to update its TI tracking
      - The DebugExecutor needed to "unmap" the mapped task to get the real
        operator back

      I was testing this with the following dag:

      ```
      from airflow import DAG
      from airflow.decorators import task
      from airflow.operators.python import PythonOperator
      import pendulum

      @task
      def make_list():
          return list(map(lambda a: f'echo "{a!r}"', [1, 2, {'a': 'b'}]))

      def consumer(*args):
          print(repr(args))

      with DAG(dag_id='maptest', start_date=pendulum.DateTime(2022, 1, 18)) as dag:
          PythonOperator(task_id='consumer', python_callable=consumer).map(op_args=make_list())
      ```

      It can't "unmap" decorated operators successfully yet, so we're using
      the old-school PythonOperator. We also just pass the whole value to
      the operator, not just the current mapping value(s).

    * Always have a `task_group` property on DAGNodes

      And since TaskGroup is a DAGNode, we don't need to store the parent
      group directly anymore -- it'll already be stored.

    * Add "integration" tests for running mapped tasks via BackfillJob

    * Only show "Map Index" in Backfill report when relevant

    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 airflow/cli/commands/task_command.py            |   2 +
 airflow/executors/debug_executor.py             |   2 +
 airflow/executors/kubernetes_executor.py        |   2 +-
 airflow/jobs/backfill_job.py                    | 117 ++++++++++--------
 airflow/jobs/local_task_job.py                  |   6 +
 airflow/jobs/scheduler_job.py                   |   2 +-
 airflow/models/baseoperator.py                  | 134 ++++++++++++---------
 airflow/models/taskinstance.py                  |  51 +++++---
 airflow/models/taskmixin.py                     |  52 +++++++-
 airflow/serialization/serialized_objects.py     |  22 ++--
 .../ti_deps/deps/mapped_task_expanded.py        |  16 ++-
 airflow/utils/task_group.py                     |  33 ++---
 .../__init__.py => dags/test_mapped_classic.py} |  20 ++-
 tests/executors/test_kubernetes_executor.py     |   5 +-
 tests/jobs/test_backfill_job.py                 |  42 +++++--
 tests/models/__init__.py                        |   4 +-
 tests/models/test_baseoperator.py               |  20 ++-
 tests/models/test_dag.py                        |   2 +-
 tests/models/test_taskinstance.py               |   2 +-
 tests/serialization/test_dag_serialization.py   |   5 +
 tests/test_utils/mock_executor.py               |   4 +-
 21 files changed, 366 insertions(+), 177 deletions(-)
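The "unmap" step described in the commit message is easiest to see in miniature. A minimal standalone sketch of the idea (simplified; the real `MappedOperator.unmap` in the diff below also detaches and re-creates the task on its DAG):

```
def unmap_sketch(operator_class, task_id, partial_kwargs, mapped_kwargs):
    # Rebuild the concrete operator: the pinned partial kwargs are applied
    # first, then the kwargs that were supplied to .map() on top of them.
    kwargs = {**partial_kwargs, **mapped_kwargs}
    return operator_class(task_id=task_id, **kwargs)
```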
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 537fab0..1b5208f 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -224,6 +224,8 @@ RAW_TASK_UNSUPPORTED_OPTION = [ def _run_raw_task(args, ti: TaskInstance) -> None: """Runs the main task handling code""" + if ti.task.is_mapped: + ti.task = ti.task.unmap() ti._run_raw_task( mark_success=args.mark_success, job_id=args.job_id, diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index 865186d..0ab5f35 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -76,6 +76,8 @@ class DebugExecutor(BaseExecutor): key = ti.key try: params = self.tasks_params.pop(ti.key, {}) + if ti.task.is_mapped: + ti.task = ti.task.unmap() ti._run_raw_task(job_id=ti.job_id, **params) self.change_state(key, State.SUCCESS) ti._run_finished_callback() diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 1071a3a..ef671eb 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -296,7 +296,7 @@ class AirflowKubernetesScheduler(LoggingMixin): """ self.log.info('Kubernetes job is %s', str(next_job).replace("\n", " ")) key, command, kube_executor_config, pod_template_file = next_job - dag_id, task_id, run_id, try_number = key + dag_id, task_id, run_id, try_number, _ = key if command[0:3] != ["airflow", "tasks", "run"]: raise ValueError('The command must start with ["airflow", "tasks", "run"].') diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 406c2ea..10a5d08 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -18,9 +18,9 @@ # import time -from collections import OrderedDict -from typing import Optional, Set +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple +import attr import pendulum from sqlalchemy.orm.session import Session, make_transient from tabulate import tabulate @@ -48,6 +48,9 @@ from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType +if TYPE_CHECKING: + from airflow.models.baseoperator import MappedOperator + class BackfillJob(BaseJob): """ @@ -60,6 +63,7 @@ class BackfillJob(BaseJob): __mapper_args__ = {'polymorphic_identity': 'BackfillJob'} + @attr.define class _DagRunTaskStatus: """ Internal status of the backfill job. This class is intended to be instantiated @@ -83,32 +87,17 @@ class BackfillJob(BaseJob): :param total_runs: Number of total dag runs able to run """ - # TODO(edgarRd): AIRFLOW-1444: Add consistency check on counts - def __init__( - self, - to_run=None, - running=None, - skipped=None, - succeeded=None, - failed=None, - not_ready=None, - deadlocked=None, - active_runs=None, - executed_dag_run_dates=None, - finished_runs=0, - total_runs=0, - ): - self.to_run = to_run or OrderedDict() - self.running = running or {} - self.skipped = skipped or set() - self.succeeded = succeeded or set() - self.failed = failed or set() - self.not_ready = not_ready or set() - self.deadlocked = deadlocked or set() - self.active_runs = active_runs or [] - self.executed_dag_run_dates = executed_dag_run_dates or set() - self.finished_runs = finished_runs - self.total_runs = total_runs + to_run: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict) + running: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict) + skipped: Set[TaskInstanceKey] = attr.ib(factory=set) + succeeded: Set[TaskInstanceKey] = attr.ib(factory=set) + failed: Set[TaskInstanceKey] = attr.ib(factory=set) + not_ready: Set[TaskInstanceKey] = attr.ib(factory=set) + deadlocked: Set[TaskInstance] = attr.ib(factory=set) + active_runs: List[DagRun] = attr.ib(factory=list) + executed_dag_run_dates: Set[pendulum.DateTime] = attr.ib(factory=set) + finished_runs: int = 0 + total_runs: int = 0 def __init__( self, @@ -167,7 +156,6 @@ class BackfillJob(BaseJob): self.run_at_least_once = run_at_least_once super().__init__(*args, **kwargs) - @provide_session def _update_counters(self, ti_status, session=None): """ Updates the counters per state of the tasks that were running. Can re-add
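The `_DagRunTaskStatus` rewrite above replaces a hand-written `__init__` with `attrs` and `factory` defaults. A minimal standalone sketch of why the factories matter (fresh containers per instance instead of a shared mutable default):

```
import attr
from typing import Dict, Set

@attr.define
class Status:
    to_run: Dict[str, str] = attr.ib(factory=dict)
    succeeded: Set[str] = attr.ib(factory=set)
    finished_runs: int = 0

a, b = Status(), Status()
a.to_run["key"] = "ti"
assert b.to_run == {}  # each instance gets its own dict
```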
@@ -234,14 +222,22 @@ class BackfillJob(BaseJob): session.query(TI).filter(filter_for_tis).update( values={TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False ) + session.flush() - def _manage_executor_state(self, running): + def _manage_executor_state( + self, running, session + ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance]]]: """ Checks if the executor agrees with the state of task instances - that are running + that are running. + + Expands downstream mapped tasks when necessary :param running: dict of key, task to verify + :return: An iterable of expanded TaskInstances per MappedOperator """ + from airflow.models.baseoperator import MappedOperator + executor = self.executor # TODO: query all instead of refresh from db @@ -266,6 +262,11 @@ class BackfillJob(BaseJob): ) self.log.error(msg) ti.handle_failure_with_callback(error=msg) + continue + if ti.state not in self.STATES_COUNT_AS_RUNNING: + for node in ti.task.mapped_dependants(): + assert isinstance(node, MappedOperator) + yield node, ti.run_id, node.expand_mapped_task(ti, session) @provide_session def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = None): @@ -409,7 +410,6 @@ class BackfillJob(BaseJob): # or leaf to root, as otherwise tasks might be # determined deadlocked while they are actually # waiting for their upstream to finish - @provide_session def _per_task_process(key, ti: TaskInstance, session=None): ti.refresh_from_db(lock_for_update=True, session=session) @@ -577,7 +577,8 @@ class BackfillJob(BaseJob): "Not scheduling since Task concurrency limit is reached." ) - _per_task_process(key, ti) + _per_task_process(key, ti, session) + session.commit() except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e: self.log.debug(e) @@ -597,11 +598,23 @@ class BackfillJob(BaseJob): ti_status.deadlocked.update(ti_status.to_run.values()) ti_status.to_run.clear() - # check executor state - self._manage_executor_state(ti_status.running) + # check executor state -- and expand any mapped TIs + for node, run_id, mapped_tis in self._manage_executor_state(ti_status.running, session): + + def to_keep(key: TaskInstanceKey) -> bool: + if key.dag_id != node.dag_id or key.task_id != node.task_id or key.run_id != run_id: + # For another Dag/Task/Run -- don't remove + return True + return False + + # remove the old unmapped TIs for node -- they have been replaced with the mapped TIs + ti_status.to_run = {key: ti for (key, ti) in ti_status.to_run.items() if to_keep(key)} + + ti_status.to_run.update({ti.key: ti for ti in mapped_tis}) # update the task counters - self._update_counters(ti_status=ti_status) + self._update_counters(ti_status=ti_status, session=session) + session.commit() # update dag run state _dag_runs = ti_status.active_runs[:] @@ -613,25 +626,33 @@ class BackfillJob(BaseJob): executed_run_dates.append(run.execution_date) self._log_progress(ti_status) + session.commit() # return updated status return executed_run_dates @provide_session - def _collect_errors(self, ti_status, session=None): - def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str: + def _collect_errors(self, ti_status: _DagRunTaskStatus, session=None): + def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: # Sorting by execution date first - sorted_ti_keys = sorted( - set_ti_keys, - key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, ti_key.task_id, ti_key.try_number), + sorted_ti_keys: Any = sorted( + ti_keys, + key=lambda ti_key: ( +
ti_key.run_id, + ti_key.dag_id, + ti_key.task_id, + ti_key.map_index, + ti_key.try_number, + ), ) - return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Run ID", "Try number"]) - def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str: - # Sorting by execution date first - sorted_tis = sorted(set_tis, key=lambda ti: (ti.run_id, ti.dag_id, ti.task_id, ti.try_number)) - tis_values = ((ti.dag_id, ti.task_id, ti.run_id, ti.try_number) for ti in sorted_tis) - return tabulate(tis_values, headers=["DAG ID", "Task ID", "Run ID", "Try number"]) + if all(key.map_index == -1 for key in ti_keys): + headers = ["DAG ID", "Task ID", "Run ID", "Try number"] + sorted_ti_keys = map(lambda k: k[0:4], sorted_ti_keys) + else: + headers = ["DAG ID", "Task ID", "Run ID", "Map Index", "Try number"] + + return tabulate(sorted_ti_keys, headers=headers) err = '' if ti_status.failed: @@ -667,7 +688,7 @@ class BackfillJob(BaseJob): err += '\n\nThese tasks are skipped:\n' err += tabulate_ti_keys_set(ti_status.skipped) err += '\n\nThese tasks are deadlocked:\n' - err += tabulate_tis_set(ti_status.deadlocked) + err += tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked]) return err diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index c0255d7..05ee533 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -104,6 +104,12 @@ class LocalTaskJob(BaseJob): try: self.task_runner.start() + # Unmap the task _after_ it has forked/execed. (This is a bit of a kludge, but if we unmap before + # fork, then the "run_raw_task" command will see the mapping index and a non-mapped task and + # fail) + if self.task_instance.task.is_mapped: + self.task_instance.task = self.task_instance.task.unmap() + heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold') # task callback invocation happens either here or in diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index cbda16e..7a6e3ef 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -534,7 +534,7 @@ class SchedulerJob(BaseJob): """Respond to executor events.""" if not self.processor_agent: raise ValueError("Processor agent is not started.") - ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str], int] = {} + ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str, int], int] = {} event_buffer = self.executor.get_event_buffer() tis_with_right_state: List[TaskInstanceKey] = [] diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index d51dda8..34c8412 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -50,7 +50,7 @@ import attr import jinja2 import pendulum from dateutil.relativedelta import relativedelta -from sqlalchemy import or_ +from sqlalchemy import func, or_ from sqlalchemy.orm import Session from sqlalchemy.orm.exc import NoResultFound @@ -66,6 +66,7 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.serialization.enums import DagAttributeTypes from airflow.ti_deps.deps.base_ti_dep import BaseTIDep +from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep @@ -247,7 +248,12 @@ class BaseOperatorMeta(abc.ABCMeta): # Validate that the args we passed are known -- at call/DAG parse time, not run time!
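A standalone sketch of the parse-time check referenced by the comment above, in the spirit of `_validate_kwarg_names_for_mapping` (hypothetical helper; the real function also accounts for inherited constructor arguments):

```
import inspect

def validate_kwarg_names(operator_class, func_name, kwargs):
    # Reject unknown constructor arguments when the DAG file is parsed,
    # not when the task eventually runs.
    known = set(inspect.signature(operator_class.__init__).parameters)
    unknown = set(kwargs) - known
    if unknown:
        raise TypeError(
            f"{operator_class.__name__}.{func_name}() got unexpected arguments {sorted(unknown)}"
        )
```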
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs) return MappedOperator( - task_id=task_id, operator_class=operator_class, dag=dag, partial_kwargs=kwargs, mapped_kwargs={} + task_id=task_id, + operator_class=operator_class, + dag=dag, + partial_kwargs=kwargs, + mapped_kwargs={}, + deps=MappedOperator._deps(operator_class.deps), ) @@ -1459,9 +1465,7 @@ class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta): """Return if this operator can use smart service. Default False.""" return False - @property - def is_mapped(self) -> bool: - return False + is_mapped: ClassVar[bool] = False @property def inherits_from_dummy_operator(self): @@ -1491,38 +1495,10 @@ class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta): def map(self, **kwargs) -> "MappedOperator": return MappedOperator.from_operator(self, kwargs) - def has_mapped_dependants(self) -> bool: - """Whether any downstream dependencies depend on this task for mapping. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. - """ - from airflow.utils.task_group import MappedTaskGroup, TaskGroup - - if not self.has_dag(): - return False - - def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]: - """Recursively walk children in a task group. - - This yields all direct children (including both tasks and task - groups), and all children of any task groups. - """ - for key, child in group.children.items(): - yield key, child - if isinstance(child, TaskGroup): - yield from _walk_group(child) - - for key, child in _walk_group(self.dag.task_group): - if key == self.task_id: - continue - if not isinstance(child, (MappedOperator, MappedTaskGroup)): - continue - if self.task_id in child.upstream_task_ids: - return True - return False + def unmap(self) -> "BaseOperator": + """:meta private:""" + # Exists to make typing easier + raise TypeError("Internal code error: Do not call unmap on BaseOperator!") def _validate_kwarg_names_for_mapping( @@ -1591,7 +1567,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode): # Needed for SerializedBaseOperator _is_dummy: bool = attr.ib() - deps: Iterable[BaseTIDep] = attr.ib() + deps: Iterable[BaseTIDep] operator_extra_links: Iterable['BaseOperatorLink'] = () template_fields: Collection[str] = attr.ib() template_ext: Collection[str] = attr.ib() @@ -1602,16 +1578,16 @@ class MappedOperator(Operator, LoggingMixin, DAGNode): subdag: None = attr.ib(init=False) + DEFAULT_DEPS: ClassVar[FrozenSet[BaseTIDep]] = frozenset(BaseOperator.deps) | frozenset( + [MappedTaskIsExpanded()] + ) + @_is_dummy.default def _is_dummy_from_operator_class(self): from airflow.operators.dummy import DummyOperator return issubclass(self.operator_class, DummyOperator) - @deps.default - def _deps_from_operator_class(self): - return self.operator_class.deps - @template_fields.default def _template_fields_from_operator_class(self): return self.operator_class.template_fields @@ -1648,7 +1624,8 @@ class MappedOperator(Operator, LoggingMixin, DAGNode): @classmethod def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator": - dag: Optional["DAG"] = getattr(operator, '_dag', None) + dag = operator.get_dag() + task_group = operator.task_group if dag: # When
BaseOperator() was called within a DAG, it would have been added straight away, but now we # are mapped, we want to _remove_ that task from the dag @@ -1658,7 +1635,7 @@ return MappedOperator( operator_class=type(operator), task_id=operator.task_id, - task_group=operator.task_group, + task_group=task_group, dag=dag, upstream_task_ids=operator.upstream_task_ids, downstream_task_ids=operator.downstream_task_ids, @@ -1668,7 +1645,7 @@ mapped_kwargs=mapped_kwargs, owner=operator.owner, max_active_tis_per_dag=operator.max_active_tis_per_dag, - deps=operator.deps, + deps=cls._deps(operator.deps), ) @classmethod @@ -1695,12 +1672,20 @@ task_id=task_id, dag=dag, task_group=task_group, + deps=cls._deps(decorator.operator_class.deps), ) operator.mapped_kwargs.update(mapped_kwargs) for arg in mapped_kwargs.values(): XComArg.apply_upstream_relationship(operator, arg) return operator + @classmethod + def _deps(cls, deps: Iterable[BaseTIDep]): + if deps is BaseOperator.deps: + return cls.DEFAULT_DEPS + else: + return frozenset(deps) | {MappedTaskIsExpanded()} + def __attrs_post_init__(self): from airflow.models.xcom_arg import XComArg @@ -1756,9 +1741,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode): """Used to determine if an Operator is inherited from DummyOperator""" return self._is_dummy - @property - def is_mapped(self) -> bool: - return True + is_mapped: ClassVar[bool] = True # The _serialized_fields are lazily loaded when get_serialized_fields() method is called __serialized_fields: ClassVar[Optional[FrozenSet[str]]] = None @@ -1777,6 +1760,7 @@ 'operator_extra_links', 'upstream_task_ids', 'task_type', + 'task_group', # These are automatically populated from partial_kwargs. In # a perfect world, they should be properties like other # partial_kwargs-populated values e.g. 'queue' below, but we @@ -1826,8 +1810,14 @@ def depends_on_past(self) -> bool: return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream - def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None: - """Create the mapped TaskInstances for mapped task.""" + def expand_mapped_task( + self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION + ) -> Sequence[TaskInstance]: + """ + Create the mapped TaskInstances for mapped task. + + :return: The mapped TaskInstances + """ # TODO: support having multiple mapped upstreams? from airflow.models.taskmap import TaskMap from airflow.settings import task_instance_mutation_hook @@ -1846,6 +1836,7 @@ # TODO: What would lead to this? How can this be better handled? raise RuntimeError("mapped operator cannot be expanded; upstream not found") + state = None unmapped_ti: Optional[TaskInstance] = ( session.query(TaskInstance) .filter( @@ -1858,6 +1849,8 @@ .one_or_none() ) + ret: List[TaskInstance] = [] + if unmapped_ti: # The unmapped task instance still exists and is unfinished, i.e. we # haven't tried to run it before.
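The hunk that follows turns `expand_mapped_task` into an upsert. Its bookkeeping reduces to a small pure function; a simplified model with hypothetical names (no ORM, with map_index -1 standing in for the unmapped placeholder TI):

```
def plan_expansion(existing_indexes, new_length):
    """Return (indexes_to_create, indexes_to_mark_removed)."""
    if -1 in existing_indexes:
        start = 1  # the placeholder is converted in place to map_index 0
    else:
        start = max(existing_indexes, default=-1) + 1  # only create the "missing" ones
    to_create = list(range(start, new_length))
    to_remove = sorted(i for i in existing_indexes if i >= new_length)
    return to_create, to_remove

assert plan_expansion({-1}, 3) == ([1, 2], [])
assert plan_expansion({0, 1, 2, 3, 4}, 3) == ([], [3, 4])
```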
@@ -1867,20 +1860,34 @@ class MappedOperator(Operator, LoggingMixin, DAGNode): self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti) unmapped_ti.state = TaskInstanceState.SKIPPED session.flush() - return + return ret # Otherwise convert this into the first mapped index, and create # TaskInstance for other indexes. unmapped_ti.map_index = 0 + state = unmapped_ti.state + self.log.debug("Updated in place to become %s", unmapped_ti) + ret.append(unmapped_ti) indexes_to_map = range(1, task_map_info_length) else: - indexes_to_map = range(task_map_info_length) + # Only create "missing" ones. + current_max_mapping = ( + session.query(func.max(TaskInstance.map_index)) + .filter( + TaskInstance.dag_id == upstream_ti.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == upstream_ti.run_id, + ) + .scalar() + ) + indexes_to_map = range(current_max_mapping + 1, task_map_info_length) for index in indexes_to_map: # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator - ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index) # type: ignore + ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index, state=state) # type: ignore + self.log.debug("Expanding TIs upserted %s", ti) task_instance_mutation_hook(ti) - session.merge(ti) + ret.append(session.merge(ti)) # Set to "REMOVED" any (old) TaskInstances with map indices greater # than the current map value @@ -1893,6 +1900,25 @@ session.flush() + return ret + + def unmap(self) -> BaseOperator: + """Get the "normal" Operator after applying the current mapping""" + assert not isinstance(self.operator_class, str) + + dag = self.get_dag() + if not dag: + raise RuntimeError("Cannot unmap a task unless it has a dag") + + args = { + **self.partial_kwargs, + **self.mapped_kwargs, + } + dag._remove_task(self.task_id) + task = self.operator_class(task_id=self.task_id, dag=self.dag, **args) + + return task + # TODO: Deprecate for Airflow 3.0 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]] diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 0528dbb..7f151f4d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -303,20 +303,23 @@ class TaskInstanceKey(NamedTuple): task_id: str run_id: str try_number: int = 1 + map_index: int = -1 @property - def primary(self) -> Tuple[str, str, str]: + def primary(self) -> Tuple[str, str, str, int]: """Return task instance primary key part of the key""" - return self.dag_id, self.task_id, self.run_id + return self.dag_id, self.task_id, self.run_id, self.map_index @property def reduced(self) -> 'TaskInstanceKey': """Remake the key by subtracting 1 from try number to match in memory information""" - return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1)) + return TaskInstanceKey( + self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index + ) def with_try_number(self, try_number: int) -> 'TaskInstanceKey': """Returns TaskInstanceKey with provided ``try_number``""" - return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number) + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index) @property def key(self) -> "TaskInstanceKey": @@ -795,8 +798,6 @@ class TaskInstance(Base, LoggingMixin): else: self.state = None - self.log.debug("Refreshed
TaskInstance %s", self) - def refresh_from_task(self, task: "BaseOperator", pool_override=None): """ Copy common attributes from the given task. @@ -829,12 +830,11 @@ class TaskInstance(Base, LoggingMixin): execution_date=self.execution_date, session=session, ) - self.log.debug("XCom data cleared") @property def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely""" - return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number) + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) @provide_session def set_state(self, state: Optional[str], session=NEW_SESSION): @@ -1068,7 +1068,10 @@ class TaskInstance(Base, LoggingMixin): yield dep_status def __repr__(self): - return f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} [{self.state}]>" + prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} " + if self.map_index != -1: + prefix += f"map_index={self.map_index} " + return prefix + f"[{self.state}]>" def next_retry_datetime(self): """ @@ -1312,6 +1315,11 @@ class TaskInstance(Base, LoggingMixin): :param pool: specifies the pool to use to run the task instance :param session: SQLAlchemy ORM Session """ + if self.task.is_mapped: + raise RuntimeError( + f'task property of {self.task_id!r} was still a MappedOperator -- it should have been ' + 'expanded already!' + ) self.test_mode = test_mode self.refresh_from_task(self.task, pool_override=pool) self.refresh_from_db(session=session) @@ -1719,6 +1727,8 @@ class TaskInstance(Base, LoggingMixin): self.refresh_from_db(session) task = self.task + if task.is_mapped: + task = task.unmap() self.end_date = timezone.utcnow() self.set_duration() Stats.incr(f'operator_failures_{task.task_type}', 1, 1) @@ -2252,19 +2262,29 @@ class TaskInstance(Base, LoggingMixin): dag_id = first.dag_id run_id = first.run_id + map_index = first.map_index first_task_id = first.task_id # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id - # and task_id -- this can be over 150x for huge numbers of TIs (20k+) - if all(t.dag_id == dag_id and t.run_id == run_id for t in tis): + # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+) + if all(t.dag_id == dag_id and t.run_id == run_id and t.map_index == map_index for t in tis): return and_( TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id, + TaskInstance.map_index == map_index, TaskInstance.task_id.in_(t.task_id for t in tis), ) - if all(t.dag_id == dag_id and t.task_id == first_task_id for t in tis): + if all(t.dag_id == dag_id and t.task_id == first_task_id and t.map_index == map_index for t in tis): return and_( TaskInstance.dag_id == dag_id, TaskInstance.run_id.in_(t.run_id for t in tis), + TaskInstance.map_index == map_index, + TaskInstance.task_id == first_task_id, + ) + if all(t.dag_id == dag_id and t.run_id == run_id and t.task_id == first_task_id for t in tis): + return and_( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index.in_(t.map_index for t in tis), TaskInstance.task_id == first_task_id, ) @@ -2274,13 +2294,14 @@ class TaskInstance(Base, LoggingMixin): TaskInstance.dag_id == ti.dag_id, TaskInstance.task_id == ti.task_id, TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) for ti in tis ) else: - return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id).in_( - [ti.key.primary for ti in tis] - ) + return tuple_( + TaskInstance.dag_id, 
TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index + ).in_([ti.key.primary for ti in tis]) # State of the task instance. diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 4fc9566..7c06155 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -17,7 +17,7 @@ import warnings from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union import pendulum @@ -109,6 +109,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): """ dag: Optional["DAG"] = None + task_group: Optional["TaskGroup"] = None + """The task_group that contains this node""" @property @abstractmethod @@ -117,15 +119,12 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): @property def label(self) -> Optional[str]: - tg: Optional["TaskGroup"] = getattr(self, 'task_group', None) + tg = self.task_group if tg and tg.node_id and tg.prefix_group_id: # "task_group_id.task_id" -> "task_id" return self.node_id[len(tg.node_id) + 1 :] return self.node_id - task_group: Optional["TaskGroup"] - """The task_group that contains this node""" - start_date: Optional[pendulum.DateTime] end_date: Optional[pendulum.DateTime] upstream_task_ids: Set[str] @@ -268,3 +267,46 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: """This is used by SerializedTaskGroup to serialize a task group's content.""" raise NotImplementedError() + + def mapped_dependants(self) -> Iterator["DAGNode"]: + """Return any mapped nodes that are direct dependants of the current task + + For now, this walks the entire DAG to find mapped nodes that have this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record all of a DAG node's downstream nodes instead. + """ + from airflow.models.baseoperator import MappedOperator + from airflow.utils.task_group import MappedTaskGroup, TaskGroup + + def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]: + """Recursively walk children in a task group. + + This yields all direct children (including both tasks and task + groups), and all children of any task groups. + """ + for key, child in group.children.items(): + yield key, child + if isinstance(child, TaskGroup): + yield from _walk_group(child) + + tg = self.task_group + if not tg: + raise RuntimeError("Cannot check for mapped_dependants when not attached to a DAG") + for key, child in _walk_group(tg): + if key == self.node_id: + continue + if not isinstance(child, (MappedOperator, MappedTaskGroup)): + continue + if self.node_id in child.upstream_task_ids: + yield child + + def has_mapped_dependants(self) -> bool: + """Whether any downstream dependencies depend on this task for mapping. + + For now, this walks the entire DAG to find mapped nodes that have this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record all of a DAG node's downstream nodes instead.
+ """ + return any(self.mapped_dependants()) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 63820ff..42fa314 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -19,6 +19,7 @@ import datetime import enum import logging +import weakref from dataclasses import dataclass from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Type, Union @@ -567,7 +568,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): @classmethod def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - serialize_op = cls._serialize_node(op) + + stock_deps = op.deps is MappedOperator.DEFAULT_DEPS + serialize_op = cls._serialize_node(op, include_deps=not stock_deps) # It must be a class at this point for it to work, not a string assert isinstance(op.operator_class, type) serialize_op['_task_type'] = op.operator_class.__name__ @@ -577,10 +580,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): @classmethod def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: - return cls._serialize_node(op) + return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps) @classmethod - def _serialize_node(cls, op: Union[BaseOperator, MappedOperator]) -> Dict[str, Any]: + def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps: bool) -> Dict[str, Any]: """Serializes operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) serialize_op['_task_type'] = type(op).__name__ @@ -594,8 +597,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): op.operator_extra_links ) - if op.deps is not BaseOperator.deps: - # Are the deps different to BaseOperator, if so serialize the class names! + if include_deps: + # Are the deps different to "stock", if so serialize the class names! # For Airflow 2.0 expediency we _only_ allow built in Dep classes. # Fix this for 2.0.x or 2.1 deps = [] @@ -641,7 +644,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): # These are all re-set later partial_kwargs={}, mapped_kwargs={}, - deps=tuple(), + deps=MappedOperator.DEFAULT_DEPS, is_dummy=False, template_fields=(), template_ext=(), @@ -1084,8 +1087,13 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization): for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"] } group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs) + + def set_ref(task: BaseOperator) -> BaseOperator: + task.task_group = weakref.proxy(group) + return task + group.children = { - label: task_dict[val] # type: ignore + label: set_ref(task_dict[val]) # type: ignore if _type == DAT.OP # type: ignore else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) for label, (_type, val) in encoded_group["children"].items() diff --git a/tests/models/__init__.py b/airflow/ti_deps/deps/mapped_task_expanded.py similarity index 60% copy from tests/models/__init__.py copy to airflow/ti_deps/deps/mapped_task_expanded.py index 2d4a0d9..03cf07d 100644 --- a/tests/models/__init__.py +++ b/airflow/ti_deps/deps/mapped_task_expanded.py @@ -15,10 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep -import os -from airflow.utils import timezone +class MappedTaskIsExpanded(BaseTIDep): + """Checks that a mapped task has been expanded before its TaskInstance can run.""" -DEFAULT_DATE = timezone.datetime(2016, 1, 1) -TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags')) + NAME = "Task has been mapped" + IGNORABLE = False + IS_TASK_DEP = False + + def _get_dep_statuses(self, ti, session, dep_context): + if ti.map_index == -1: + yield self._failing_status(reason="The task has yet to be mapped!") + return + yield self._passing_status(reason="The task has been mapped")
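Both the deserialization path above (`set_ref`) and `TaskGroup.add` below attach the group back-reference through `weakref.proxy`. A toy illustration of the pattern (simplified stand-in classes, not Airflow's):

```
import weakref

class Group:
    def __init__(self):
        self.children = {}

    def add(self, node):
        # Weak back-reference: the child does not keep the group alive, so
        # discarded group structures can be garbage-collected promptly.
        node.group = weakref.proxy(self)
        self.children[node.name] = node

class Node:
    def __init__(self, name):
        self.name = name
        self.group = None

g, n = Group(), Node("t1")
g.add(n)
assert n.group.children["t1"] is n  # the proxy behaves like the group itself
```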
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 8f193f4..88b956e 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -92,7 +92,6 @@ class TaskGroup(DAGNode): # used_group_ids is shared across all TaskGroups in the same DAG to keep track # of used group_id to avoid duplication. self.used_group_ids = set() - self._parent_group = None self.dag = dag else: if prefix_group_id: @@ -108,28 +107,29 @@ if not parent_group and not dag: raise AirflowException("TaskGroup can only be used inside a dag") - self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag) - if not self._parent_group: + parent_group = parent_group or TaskGroupContext.get_current_task_group(dag) + if not parent_group: raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup") - if dag is not self._parent_group.dag: + if dag is not parent_group.dag: raise RuntimeError( - "Cannot mix TaskGroups from different DAGs: %s and %s", dag, self._parent_group.dag + "Cannot mix TaskGroups from different DAGs: %s and %s", dag, parent_group.dag ) - self.used_group_ids = self._parent_group.used_group_ids + self.used_group_ids = parent_group.used_group_ids # if given group_id already used assign suffix by incrementing largest used suffix integer # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3 self._group_id = group_id self._check_for_group_id_collisions(add_suffix_on_collision) + self.children: Dict[str, DAGNode] = {} + if parent_group: + parent_group.add(self) + self.used_group_ids.add(self.group_id) if self.group_id: self.used_group_ids.add(self.downstream_join_id) self.used_group_ids.add(self.upstream_join_id) - self.children: Dict[str, DAGNode] = {} - if self._parent_group: - self._parent_group.add(self) self.tooltip = tooltip self.ui_color = ui_color @@ -175,6 +175,10 @@ """Returns True if this TaskGroup is the root TaskGroup. Otherwise False""" return not self.group_id + @property + def parent_group(self) -> Optional["TaskGroup"]: + return self.task_group + def __iter__(self): for child in self.children.values(): if isinstance(child, TaskGroup): @@ -184,6 +188,8 @@ def add(self, task: DAGNode) -> None: """Add a task to this TaskGroup.""" + # Set the TG first, as setting it might change the return value of node_id! + task.task_group = weakref.proxy(self) key = task.node_id if key in self.children: @@ -201,7 +207,6 @@ raise AirflowException("Cannot add a non-empty TaskGroup") self.children[key] = task - task.task_group = weakref.proxy(self) def _remove(self, task: DAGNode) -> None: key = task.node_id @@ -216,8 +221,8 @@ @property def group_id(self) -> Optional[str]: """group_id of this TaskGroup.""" - if self._parent_group and self._parent_group.prefix_group_id and self._parent_group.group_id: - return self._parent_group.child_id(self._group_id) + if self.task_group and self.task_group.prefix_group_id and self.task_group.group_id: + return self.task_group.child_id(self._group_id) return self._group_id @@ -380,8 +385,8 @@ raise RuntimeError("Cannot map a TaskGroup that already has children") if not self.group_id: raise RuntimeError("Cannot map a TaskGroup before it has a group_id") - if self._parent_group: - self._parent_group._remove(self) + if self.task_group: + self.task_group._remove(self) return MappedTaskGroup(group_id=self._group_id, dag=self.dag, mapped_arg=arg) diff --git a/tests/models/__init__.py b/tests/dags/test_mapped_classic.py similarity index 65% copy from tests/models/__init__.py copy to tests/dags/test_mapped_classic.py index 2d4a0d9..14f2c1f 100644 --- a/tests/models/__init__.py +++ b/tests/dags/test_mapped_classic.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,9 +15,20 @@ # specific language governing permissions and limitations # under the License. -import os +from airflow import DAG +from airflow.decorators import task +from airflow.operators.python import PythonOperator +from airflow.utils.dates import days_ago + + +@task +def make_list(): + return [1, 2, {'a': 'b'}] + + +def consumer(*args): + print(repr(args)) -from airflow.utils import timezone -DEFAULT_DATE = timezone.datetime(2016, 1, 1) -TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags')) +with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag: + PythonOperator(task_id='consumer', python_callable=consumer).map(op_args=make_list()) diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 340e67b..af79b16 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -29,6 +29,7 @@ from kubernetes.client.rest import ApiException from urllib3 import HTTPResponse from airflow import AirflowException +from airflow.models.taskinstance import TaskInstanceKey from airflow.utils import timezone from tests.test_utils.config import conf_vars @@ -244,7 +245,7 @@ class TestKubernetesExecutor: kubernetes_executor.start() # Execute a task while the Api Throws errors try_number = 1 - task_instance_key = ('dag', 'task', 'run_id', try_number) + task_instance_key = TaskInstanceKey('dag', 'task', 'run_id', try_number) kubernetes_executor.execute_async( key=task_instance_key, queue=None, @@ -326,7 +327,7 @@ class TestKubernetesExecutor: assert executor.task_queue.empty() executor.execute_async( - key=('dag', 'task', 'run_id', 1), + key=TaskInstanceKey('dag', 'task', 'run_id', 1), queue=None, command=['airflow', 'tasks', 'run', 'true', 'some_parameter'], executor_config={ diff --git a/tests/jobs/test_backfill_job.py
b/tests/jobs/test_backfill_job.py index 37b5acc..0878f63 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -41,10 +41,12 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstanceKey from airflow.operators.dummy import DummyOperator from airflow.utils import timezone +from airflow.utils.dates import days_ago from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.types import DagRunType +from tests.models import TEST_DAGS_FOLDER from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots from tests.test_utils.mock_executor import MockExecutor from tests.test_utils.timetables import cron_timetable @@ -190,7 +192,7 @@ class TestBackfillJob: ("run_this_last", end_date), ] assert [ - ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1), (State.SUCCESS, None)) + ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1, -1), (State.SUCCESS, None)) for (task_id, when) in expected_execution_order ] == executor.sorted_tasks @@ -267,7 +269,7 @@ class TestBackfillJob: job.run() assert [ - ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1), (State.SUCCESS, None)) + ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1, -1), (State.SUCCESS, None)) for task_id in expected_execution_order ] == executor.sorted_tasks @@ -1230,12 +1232,11 @@ class TestBackfillJob: subdag.clear() dag.clear() - def test_update_counters(self, dag_maker): - with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) as dag: + def test_update_counters(self, dag_maker, session): + with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE, session=session) as dag: task1 = DummyOperator(task_id='dummy', owner='airflow') dr = dag_maker.create_dagrun() job = BackfillJob(dag=dag) - session = settings.Session() ti = TI(task1, dr.execution_date) ti.refresh_from_db() @@ -1245,7 +1246,7 @@ class TestBackfillJob: # test for success ti.set_state(State.SUCCESS, session) ti_status.running[ti.key] = ti - job._update_counters(ti_status=ti_status) + job._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 1 assert len(ti_status.skipped) == 0 @@ -1257,7 +1258,7 @@ class TestBackfillJob: # test for skipped ti.set_state(State.SKIPPED, session) ti_status.running[ti.key] = ti - job._update_counters(ti_status=ti_status) + job._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 1 @@ -1269,7 +1270,7 @@ class TestBackfillJob: # test for failed ti.set_state(State.FAILED, session) ti_status.running[ti.key] = ti - job._update_counters(ti_status=ti_status) + job._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1281,7 +1282,7 @@ class TestBackfillJob: # test for retry ti.set_state(State.UP_FOR_RETRY, session) ti_status.running[ti.key] = ti - job._update_counters(ti_status=ti_status) + job._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1297,7 +1298,7 @@ class TestBackfillJob: ti.set_state(State.UP_FOR_RESCHEDULE, session) assert ti.try_number == 3 # see ti.try_number property in 
taskinstance module ti_status.running[ti.key] = ti - job._update_counters(ti_status=ti_status) + job._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1315,7 +1316,7 @@ class TestBackfillJob: session.merge(ti) session.commit() ti_status.running[ti.key] = ti - job._update_counters(ti_status=ti_status) + job._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1510,3 +1511,22 @@ class TestBackfillJob: ) job.run() assert executor.job_id is not None + + def test_mapped_dag(self, dag_maker): + """End-to-end test of a simple mapped dag""" + # Use SequentialExecutor for more predictable test behaviour + from airflow.executors.sequential_executor import SequentialExecutor + + self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py')) + dag = self.dagbag.get_dag('test_mapped_classic') + + # This needs a real executor to run, so that the `make_list` task can write out the TaskMap + + job = BackfillJob( + dag=dag, + start_date=days_ago(1), + end_date=days_ago(1), + donot_pickle=True, + executor=SequentialExecutor(), + ) + job.run() diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 2d4a0d9..c1cbabd 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -16,9 +16,9 @@ # specific language governing permissions and limitations # under the License. -import os +import pathlib from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2016, 1, 1) -TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags')) +TEST_DAGS_FOLDER = pathlib.Path(__file__).parent.with_name('dags') diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index ffde6ee..9f3e8ad 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -718,6 +718,7 @@ def test_task_mapping_with_dag(): assert task1.downstream_list == [mapped] assert mapped in dag.tasks + assert mapped.task_group == dag.task_group # At parse time there should only be three tasks! 
assert len(dag.tasks) == 3 @@ -799,11 +800,21 @@ def test_partial_on_class_invalid_ctor_args() -> None: ["num_existing_tis", "expected"], ( pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'), - pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'), + pytest.param( + 3, + [(0, 'success'), (1, 'success'), (2, 'success')], + id='all-tis-exist', + ), pytest.param( 5, - [(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)], - id="tis-to-be-remove", + [ + (0, 'success'), + (1, 'success'), + (2, 'success'), + (3, TaskInstanceState.REMOVED), + (4, TaskInstanceState.REMOVED), + ], + id="tis-to-be-removed", ), ), ) @@ -836,7 +847,8 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec ).delete() for index in range(num_existing_tis): - ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index) # type: ignore + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) session.add(ti) session.flush() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index f7e40ef..1eb3328 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -2255,7 +2255,7 @@ def test_set_task_instance_state(run_id, execution_date, session, dag_maker): # dagrun should be set to QUEUED assert dagrun.get_state() == State.QUEUED - assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1)} + assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1, -1)} @pytest.mark.parametrize( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index ff4207a..bbf05c3 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1947,7 +1947,7 @@ class TestTaskInstance: with dag_maker('test-dag', session=session) as dag: task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}") - dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py' + dag.fileloc = TEST_DAGS_FOLDER / 'test_get_k8s_pod_yaml.py' ti = dag_maker.create_dagrun().task_instances[0] ti.task = task diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 35d0d68..447b173 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -488,6 +488,9 @@ class TestStringifiedDAGs: assert not isinstance(task, SerializedBaseOperator) assert isinstance(task, BaseOperator) + # Every task should have a task_group property -- even if it's the DAG's root task group + assert serialized_task.task_group + fields_to_check = task.get_serialized_fields() - { # Checked separately '_task_type', @@ -1608,6 +1611,7 @@ def test_mapped_operator_serde(): op = SerializedBaseOperator.deserialize_operator(serialized) assert isinstance(op, MappedOperator) + assert op.deps is MappedOperator.DEFAULT_DEPS assert op.operator_class == "airflow.operators.bash.BashOperator" assert op.mapped_kwargs['bash_command'] == literal @@ -1637,6 +1641,7 @@ def test_mapped_operator_xcomarg_serde(): } op = SerializedBaseOperator.deserialize_operator(serialized) + assert op.deps is MappedOperator.DEFAULT_DEPS arg = op.mapped_kwargs['arg2'] assert arg.task_id == 'op1' diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py index 37f49cf..23d32b6 100644 --- a/tests/test_utils/mock_executor.py +++ 
b/tests/test_utils/mock_executor.py @@ -59,10 +59,10 @@ class MockExecutor(BaseExecutor): # for tests! def sort_by(item): key, val = item - (dag_id, task_id, date, try_number) = key + (dag_id, task_id, date, try_number, map_index) = key (_, prio, _, _) = val # Sort by priority (DESC), then date,task, try - return -prio, date, dag_id, task_id, try_number + return -prio, date, dag_id, task_id, map_index, try_number open_slots = self.parallelism - len(self.running) sorted_queue = sorted(self.queued_tasks.items(), key=sort_by)
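The common thread through these hunks is that `TaskInstanceKey` gained a fifth field, `map_index` (defaulting to -1 for unmapped tasks), which is why executors and tests now unpack five values. A standalone sketch of the shape, and of the field-by-field tuple ordering the `MockExecutor` sort above relies on:

```
from typing import NamedTuple

class Key(NamedTuple):
    dag_id: str
    task_id: str
    run_id: str
    try_number: int = 1
    map_index: int = -1

keys = [Key('d', 't', 'r', 1, 2), Key('d', 't', 'r', 1, 0), Key('d', 't', 'r', 1, 1)]
assert [k.map_index for k in sorted(keys)] == [0, 1, 2]
```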