This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1c7a4acbb1 Move mapped kwargs introspection to separate type (#24971)
1c7a4acbb1 is described below
commit 1c7a4acbb1fe7882247b3329de87002ff938657d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jul 12 15:51:47 2022 +0800
Move mapped kwargs introspection to separate type (#24971)
---
airflow/decorators/base.py | 130 ++++---
airflow/models/baseoperator.py | 2 +-
airflow/models/expandinput.py | 201 +++++++++++
airflow/models/mappedoperator.py | 380 ++++++++-------------
airflow/models/taskinstance.py | 34 +-
airflow/serialization/schema.json | 8 +-
airflow/serialization/serialized_objects.py | 61 +++-
airflow/www/views.py | 2 +-
docs/spelling_wordlist.txt | 1 +
.../pre_commit_base_operator_partial_arguments.py | 3 +-
.../api_connexion/endpoints/test_task_endpoint.py | 3 +-
tests/decorators/test_python.py | 11 +-
tests/models/test_dagrun.py | 2 +-
tests/models/test_mappedoperator.py | 4 +-
tests/models/test_taskinstance.py | 11 +-
tests/serialization/test_dag_serialization.py | 74 ++--
16 files changed, 541 insertions(+), 386 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 2a2ce2da96..9598571240 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -25,11 +25,13 @@ from typing import (
Collection,
Dict,
Generic,
+ Iterable,
Iterator,
Mapping,
Optional,
Sequence,
Set,
+ Tuple,
Type,
TypeVar,
cast,
@@ -38,6 +40,7 @@ from typing import (
import attr
import typing_extensions
+from sqlalchemy.orm import Session
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
@@ -50,6 +53,7 @@ from airflow.models.baseoperator import (
parse_retries,
)
from airflow.models.dag import DAG, DagContext
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY, DictOfListsExpandInput, ExpandInput
from airflow.models.mappedoperator import (
MappedOperator,
ValidationSource,
@@ -67,7 +71,6 @@ from airflow.utils.types import NOTSET
if TYPE_CHECKING:
import jinja2 # Slow import.
- from sqlalchemy.orm import Session
from airflow.models.mappedoperator import Mappable
@@ -314,10 +317,14 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
def expand(self, **map_kwargs: "Mappable") -> XComArg:
if not map_kwargs:
raise TypeError("no arguments to expand against")
-
self._validate_arg_names("expand", map_kwargs)
prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
- ensure_xcomarg_return_value(map_kwargs)
+ # Since the input is already checked at parse time, we can set strict
+ # to False to skip the checks on execution.
+ return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
+
+ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
+ ensure_xcomarg_return_value(expand_input.value)
task_kwargs = self.kwargs.copy()
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
@@ -363,7 +370,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
_MappedOperator = cast(Any, DecoratedMappedOperator)
operator = _MappedOperator(
operator_class=self.operator_class,
- mapped_kwargs={},
+ expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input.
partial_kwargs=partial_kwargs,
task_id=task_id,
params=params,
@@ -383,36 +390,26 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
end_date=end_date,
multiple_outputs=self.multiple_outputs,
python_callable=self.function,
- mapped_op_kwargs=map_kwargs,
+ op_kwargs_expand_input=expand_input,
+ disallow_kwargs_override=strict,
# Different from classic operators, kwargs passed to a taskflow
# task's expand() contribute to the op_kwargs operator argument, not
# the operator arguments themselves, and should expand against it.
- expansion_kwargs_attr="mapped_op_kwargs",
+ expand_input_attr="op_kwargs_expand_input",
)
return XComArg(operator=operator)
- def partial(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
self._validate_arg_names("partial", kwargs)
+ old_kwargs = self.kwargs.get("op_kwargs", {})
+ prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
+ kwargs.update(old_kwargs)
+ return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
- op_kwargs = self.kwargs.get("op_kwargs", {})
- op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")
-
- return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": op_kwargs})
-
- def override(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
-def _merge_kwargs(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> Dict[str, Any]:
- duplicated_keys = set(kwargs1).intersection(kwargs2)
- if len(duplicated_keys) == 1:
- raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
- elif duplicated_keys:
- duplicated_keys_display = ", ".join(sorted(duplicated_keys))
- raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
- return {**kwargs1, **kwargs2}
-
-
@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
"""MappedOperator implementation for @task-decorated task function."""
@@ -420,9 +417,9 @@ class DecoratedMappedOperator(MappedOperator):
multiple_outputs: bool
python_callable: Callable
- # We can't save these in mapped_kwargs because op_kwargs need to be present
+ # We can't save these in expand_input because op_kwargs need to be present
# in partial_kwargs, and MappedOperator prevents duplication.
- mapped_op_kwargs: Dict[str, "Mappable"]
+ op_kwargs_expand_input: ExpandInput
def __hash__(self):
return id(self)
@@ -431,40 +428,38 @@ class DecoratedMappedOperator(MappedOperator):
# The magic super() doesn't work here, so we use the explicit form.
# Not using super(..., self) to work around pyupgrade bug.
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
- XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs)
-
- def _get_unmap_kwargs(self) -> Dict[str, Any]:
- partial_kwargs = self.partial_kwargs.copy()
- op_kwargs = _merge_kwargs(
- partial_kwargs.pop("op_kwargs"),
- self.mapped_op_kwargs,
- fail_reason="mapping already partial",
- )
- self._combined_op_kwargs = op_kwargs
- return {
- "dag": self.dag,
- "task_group": self.task_group,
- "task_id": self.task_id,
- "op_kwargs": op_kwargs,
+ XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
+
+ def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Dict[str, Any]:
+ # We only use op_kwargs_expand_input so this must always be empty.
+ assert self.expand_input is EXPAND_INPUT_EMPTY
+ return {"op_kwargs": super()._expand_mapped_kwargs(resolve)}
+
+ def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+ if strict:
+ prevent_duplicates(
+ self.partial_kwargs["op_kwargs"],
+ mapped_kwargs["op_kwargs"],
+ fail_reason="mapping already partial",
+ )
+ self._combined_op_kwargs = {**self.partial_kwargs["op_kwargs"], **mapped_kwargs["op_kwargs"]}
+ self._already_resolved_op_kwargs = {
+ k for k, v in self.op_kwargs_expand_input.value.items() if isinstance(v, XComArg)
+ }
+ kwargs = {
"multiple_outputs": self.multiple_outputs,
"python_callable": self.python_callable,
- **partial_kwargs,
- **self.mapped_kwargs,
+ "op_kwargs": self._combined_op_kwargs,
}
+ return super()._get_unmap_kwargs(kwargs, strict=False)
- def _resolve_expansion_kwargs(
- self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: "Session"
- ) -> None:
- expansion_kwargs = self._get_expansion_kwargs()
-
- self._already_resolved_op_kwargs = set()
- for k, v in expansion_kwargs.items():
- if isinstance(v, XComArg):
- self._already_resolved_op_kwargs.add(k)
- v = v.resolve(context, session=session)
- v = self._expand_mapped_field(k, v, context, session=session)
- kwargs['op_kwargs'][k] = v
- template_fields.discard(k)
+ def _get_template_fields_to_render(self, expanded: Iterable[str]) -> Iterable[str]:
+ # Different from a regular MappedOperator, we still want to render
+ # some fields in op_kwargs (those that are NOT passed as XComArg from
+ # upstream). Already-rendered op_kwargs keys are detected in a different
+ # way (see render_template below and _get_unmap_kwargs above).
+ assert list(expanded) == ["op_kwargs"]
+ return self.template_fields
def render_template(
self,
@@ -473,17 +468,20 @@ class DecoratedMappedOperator(MappedOperator):
jinja_env: Optional["jinja2.Environment"] = None,
seen_oids: Optional[Set] = None,
) -> Any:
- if hasattr(self, '_combined_op_kwargs') and value is self._combined_op_kwargs:
- # Avoid rendering values that came out of resolved XComArgs
- return {
- k: v
- if k in self._already_resolved_op_kwargs
- else super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
- self, v, context, jinja_env=jinja_env, seen_oids=seen_oids
- )
- for k, v in value.items()
- }
- return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)
+ if value is not getattr(self, "_combined_op_kwargs", object()):
+ return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)
+
+ def _render_if_not_already_resolved(key: str, value: Any):
+ if key in self._already_resolved_op_kwargs:
+ return value
+ # The magic super() doesn't work here, so we use the explicit form.
+ # Not using super(..., self) to work around pyupgrade bug.
+ return super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
+ self, value, context=context, jinja_env=jinja_env, seen_oids=seen_oids
+ )
+
+ # Avoid rendering values that came out of resolved XComArgs.
+ return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()}
class Task(Generic[Function]):
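
For illustration, the duplicate-key check that expand() and partial() now run once
at parse time (which is what lets execution pass strict=False) behaves much like
the removed _merge_kwargs helper above. A minimal standalone sketch, with a
hypothetical stand-in name rather than the exact Airflow signature:

    def check_duplicates(kwargs1, kwargs2, *, fail_reason):
        # Reject any key present in both mappings.
        duplicated = set(kwargs1).intersection(kwargs2)
        if duplicated:
            raise TypeError(f"{fail_reason} argument(s): {', '.join(sorted(duplicated))}")

    check_duplicates({"multiple": 2}, {"number": [1, 2]}, fail_reason="mapping already partial")  # OK
    # check_duplicates({"number": 2}, {"number": [1, 2]}, ...) raises TypeError.
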
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index ab9543c127..05ef24a124 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1511,7 +1511,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
if cls.mapped_arguments_validated_by_init:
cls(**kwargs, _airflow_from_mapped=True)
- def unmap(self) -> "BaseOperator":
+ def unmap(self, resolve: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
""":meta private:"""
return self
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
new file mode 100644
index 0000000000..86623b41a1
--- /dev/null
+++ b/airflow/models/expandinput.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import collections
+import collections.abc
+import functools
+import operator
+from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from airflow.exceptions import UnmappableXComTypePushed
+from airflow.utils.context import Context
+
+if TYPE_CHECKING:
+ from airflow.models.xcom_arg import XComArg
+
+# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
+# mapping since we need the value to be ordered).
+Mappable = Union["XComArg", Sequence, dict]
+
+MAPPABLE_LITERAL_TYPES = (dict, list)
+
+
+class NotFullyPopulated(RuntimeError):
+ """Raise when ``get_map_lengths`` cannot populate all mapping metadata.
+ This is generally because not all upstream tasks have finished when the
+ function is called.
+ """
+
+ def __init__(self, missing: set[str]) -> None:
+ self.missing = missing
+
+ def __str__(self) -> str:
+ keys = ", ".join(repr(k) for k in sorted(self.missing))
+ return f"Failed to populate all mapping metadata; missing: {keys}"
+
+
+class DictOfListsExpandInput(NamedTuple):
+ """Storage type of a mapped operator's mapped kwargs.
+
+ This is created from ``expand(**kwargs)``.
+ """
+
+ value: dict[str, Mappable]
+
+ @staticmethod
+ def validate_xcom(value: Any) -> None:
+ if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
+ raise UnmappableXComTypePushed(value)
+
+ def get_parse_time_mapped_ti_count(self) -> int | None:
+ if not self.value:
+ return 0
+ literal_values = [len(v) for v in self.value.values() if isinstance(v, MAPPABLE_LITERAL_TYPES)]
+ if len(literal_values) != len(self.value):
+ return None # Non-literal type encountered, so give up.
+ return functools.reduce(operator.mul, literal_values, 1)
+
+ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
+ """Return dict of argument name to map length.
+
+ If any arguments are not known right now (upstream task not finished),
+ they will not be present in the dict.
+ """
+ from airflow.models.taskmap import TaskMap
+ from airflow.models.xcom import XCom
+ from airflow.models.xcom_arg import XComArg
+
+ # Populate literal mapped arguments first.
+ map_lengths: dict[str, int] = collections.defaultdict(int)
+ map_lengths.update((k, len(v)) for k, v in self.value.items() if not isinstance(v, XComArg))
+
+ try:
+ dag_id = next(v.operator.dag_id for v in self.value.values() if isinstance(v, XComArg))
+ except StopIteration: # All mapped arguments are literal. We're done.
+ return map_lengths
+
+ # Build a reverse mapping of what arguments each task contributes to.
+ mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
+ non_mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
+ for k, v in self.value.items():
+ if not isinstance(v, XComArg):
+ continue
+ assert v.operator.dag_id == dag_id
+ if v.operator.is_mapped:
+ mapped_dep_keys[v.operator.task_id].add(k)
+ else:
+ non_mapped_dep_keys[v.operator.task_id].add(k)
+ # TODO: It's not possible now, but in the future we may support
+ # depending on one single mapped task instance. When that happens,
+ # we need to further analyze the mapped case to contain only tasks
+ # we depend on "as a whole", and put those we only depend on
+ # individually to the non-mapped lookup.
+
+ # Collect lengths from unmapped upstreams.
+ taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
+ TaskMap.dag_id == dag_id,
+ TaskMap.run_id == run_id,
+ TaskMap.task_id.in_(non_mapped_dep_keys),
+ TaskMap.map_index < 0,
+ )
+ for task_id, length in taskmap_query:
+ for mapped_arg_name in non_mapped_dep_keys[task_id]:
+ map_lengths[mapped_arg_name] += length
+
+ # Collect lengths from mapped upstreams.
+ xcom_query = (
+ session.query(XCom.task_id, func.count(XCom.map_index))
+ .group_by(XCom.task_id)
+ .filter(
+ XCom.dag_id == dag_id,
+ XCom.run_id == run_id,
+ XCom.task_id.in_(mapped_dep_keys),
+ XCom.map_index >= 0,
+ )
+ )
+ for task_id, length in xcom_query:
+ for mapped_arg_name in mapped_dep_keys[task_id]:
+ map_lengths[mapped_arg_name] += length
+
+ if len(map_lengths) < len(self.value):
+ raise NotFullyPopulated(set(self.value).difference(map_lengths))
+ return map_lengths
+
+ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
+ if not self.value:
+ return 0
+ lengths = self._get_map_lengths(run_id, session=session)
+ return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
+
+ def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
+ from airflow.models.xcom_arg import XComArg
+
+ if isinstance(value, XComArg):
+ value = value.resolve(context, session=session)
+ map_index = context["ti"].map_index
+ if map_index < 0:
+ raise RuntimeError("can't resolve task-mapping argument without
expanding")
+ all_lengths = self._get_map_lengths(context["run_id"], session=session)
+
+ def _find_index_for_this_field(index: int) -> int:
+ # Need to use the original user input to retain argument order.
+ for mapped_key in reversed(list(self.value)):
+ mapped_length = all_lengths[mapped_key]
+ if mapped_length < 1:
+ raise RuntimeError(f"cannot expand field mapped to length
{mapped_length!r}")
+ if mapped_key == key:
+ return index % mapped_length
+ index //= mapped_length
+ return -1
+
+ found_index = _find_index_for_this_field(map_index)
+ if found_index < 0:
+ return value
+ if isinstance(value, collections.abc.Sequence):
+ return value[found_index]
+ if not isinstance(value, dict):
+ raise TypeError(f"can't map over value of type {type(value)}")
+ for i, (k, v) in enumerate(value.items()):
+ if i == found_index:
+ return k, v
+ raise IndexError(f"index {map_index} is over mapped length")
+
+ def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+ return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
+
+
+ExpandInput = DictOfListsExpandInput
+
+EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
+
+_EXPAND_INPUT_TYPES = {
+ "dict-of-lists": DictOfListsExpandInput,
+}
+
+
+def get_map_type_key(expand_input: ExpandInput) -> str:
+ return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input))
+
+
+def create_expand_input(kind: str, value: Any) -> ExpandInput:
+ return _EXPAND_INPUT_TYPES[kind](value)
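
The index arithmetic in _expand_mapped_field above treats a task instance's flat
map_index as a mixed-radix number over the mapped arguments, with the last
argument varying fastest. A standalone sketch of the same decomposition (literal
lengths only; decompose is a hypothetical name, not part of the commit):

    def decompose(map_index, lengths):
        # lengths: mapping of argument name to mapped length, in the order
        # the user passed the arguments to expand().
        indices = {}
        for key in reversed(list(lengths)):
            indices[key] = map_index % lengths[key]
            map_index //= lengths[key]
        return indices

    # Arguments of lengths 2 and 3 yield 2 * 3 = 6 task instances; instance 5
    # (the last one) gets x[1] and y[2].
    assert decompose(5, {"x": 2, "y": 3}) == {"x": 1, "y": 2}
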
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index e34b1501bc..9e75e9f4aa 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -19,12 +19,11 @@
import collections
import collections.abc
import datetime
-import functools
-import operator
import warnings
from typing import (
TYPE_CHECKING,
Any,
+ Callable,
ClassVar,
Collection,
Dict,
@@ -61,6 +60,13 @@ from airflow.models.abstractoperator import (
AbstractOperator,
TaskStateChangeCallback,
)
+from airflow.models.expandinput import (
+ MAPPABLE_LITERAL_TYPES,
+ DictOfListsExpandInput,
+ ExpandInput,
+ Mappable,
+ NotFullyPopulated,
+)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -80,19 +86,11 @@ if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
- from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup
- # BaseOperator.expand() can be called on an XComArg, sequence, or dict (not
- # any mapping since we need the value to be ordered).
- Mappable = Union[XComArg, Sequence, dict]
-
ValidationSource = Union[Literal["expand"], Literal["partial"]]
-MAPPABLE_LITERAL_TYPES = (dict, list)
-
-
# For isinstance() check.
@cache
def get_mappable_types() -> Tuple[type, ...]:
@@ -194,16 +192,17 @@ class OperatorPartial:
def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
- return self._expand(**mapped_kwargs)
-
- def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
- self._expand_called = True
+ validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
+ prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
+ # Since the input is already checked at parse time, we can set strict
+ # to False to skip the checks on execution.
+ return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
+ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
from airflow.operators.empty import EmptyOperator
- validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
- prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
- ensure_xcomarg_return_value(mapped_kwargs)
+ self._expand_called = True
+ ensure_xcomarg_return_value(expand_input.value)
partial_kwargs = self.kwargs.copy()
task_id = partial_kwargs.pop("task_id")
@@ -215,7 +214,7 @@ class OperatorPartial:
op = MappedOperator(
operator_class=self.operator_class,
- mapped_kwargs=mapped_kwargs,
+ expand_input=expand_input,
partial_kwargs=partial_kwargs,
task_id=task_id,
params=params,
@@ -233,9 +232,10 @@ class OperatorPartial:
task_group=task_group,
start_date=start_date,
end_date=end_date,
- # For classic operators, this points to mapped_kwargs because kwargs
+ disallow_kwargs_override=strict,
+ # For classic operators, this points to expand_input because kwargs
# to BaseOperator.expand() contribute to operator arguments.
- expansion_kwargs_attr="mapped_kwargs",
+ expand_input_attr="expand_input",
)
return op
@@ -261,7 +261,7 @@ class MappedOperator(AbstractOperator):
# that can be used to unmap this into a SerializedBaseOperator.
operator_class: Union[Type["BaseOperator"], Dict[str, Any]]
- mapped_kwargs: Dict[str, "Mappable"]
+ expand_input: ExpandInput
partial_kwargs: Dict[str, Any]
# Needed for serialization.
@@ -285,7 +285,14 @@ class MappedOperator(AbstractOperator):
upstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
downstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
- _expansion_kwargs_attr: str
+ _disallow_kwargs_override: bool
+ """Whether execution fails if ``expand_input`` has duplicates to
``partial_kwargs``.
+
+ If *False*, values from ``expand_input`` under duplicate keys override
those
+ under corresponding keys in ``partial_kwargs``.
+ """
+
+ _expand_input_attr: str
"""Where to get kwargs to calculate expansion length against.
This should be a name to call ``getattr()`` on.
@@ -315,8 +322,7 @@ class MappedOperator(AbstractOperator):
self.task_group.add(self)
if self.dag:
self.dag.add_task(self)
- for k, v in self.mapped_kwargs.items():
- XComArg.apply_upstream_relationship(self, v)
+ XComArg.apply_upstream_relationship(self, self.expand_input.value)
for k, v in self.partial_kwargs.items():
if k in self.template_fields:
XComArg.apply_upstream_relationship(self, v)
@@ -334,7 +340,7 @@ class MappedOperator(AbstractOperator):
"dag",
"deps",
"is_mapped",
- "mapped_kwargs", # This is needed to be able to accept XComArg.
+ "expand_input", # This is needed to be able to accept XComArg.
"subdag",
"task_group",
"upstream_task_ids",
@@ -361,7 +367,9 @@ class MappedOperator(AbstractOperator):
"""
if not isinstance(self.operator_class, type):
return # No need to validate deserialized operator.
- self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs())
+ kwargs = self._expand_mapped_kwargs(None)
+ kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+ self.operator_class.validate_mapped_arguments(**kwargs)
@property
def task_type(self) -> str:
@@ -520,7 +528,38 @@ class MappedOperator(AbstractOperator):
"""Implementing DAGNode."""
return DagAttributeTypes.OP, self.task_id
- def _get_unmap_kwargs(self) -> Dict[str, Any]:
+ def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Dict[str, Any]:
+ """Get the kwargs to create the unmapped operator.
+
+ If *resolve* is not *None*, it must be a two-tuple to provide context to
+ resolve XComArgs (a templating context, and a database session).
+
+ When resolving is not possible (e.g. to perform parse-time validation),
+ *resolve* can be set to *None*. This will cause the dict-of-lists
+ variant to simply return a dict of XComArgs corresponding to each kwargs
+ to pass to the unmapped operator. Since it is impossible to perform any
+ operation on the list-of-dicts variant before execution time, an empty
+ dict will be returned for this case.
+ """
+ kwargs = self._get_specified_expand_input()
+ if resolve is not None:
+ return kwargs.resolve(*resolve)
+ if isinstance(kwargs, DictOfListsExpandInput):
+ return kwargs.value
+ return {}
+
+ def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+ """Get init kwargs to unmap the underlying operator class.
+
+ :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
+ """
+ if strict:
+ prevent_duplicates(
+ self.partial_kwargs,
+ mapped_kwargs,
+ fail_reason="unmappable or already specified",
+ )
+ # Ordering is significant; mapped kwargs should override partial ones.
return {
"task_id": self.task_id,
"dag": self.dag,
@@ -529,29 +568,35 @@ class MappedOperator(AbstractOperator):
"start_date": self.start_date,
"end_date": self.end_date,
**self.partial_kwargs,
- **self.mapped_kwargs,
+ **mapped_kwargs,
}
- def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator":
- """
- Get the "normal" Operator after applying the current mapping.
+ def unmap(self, resolve: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
+ """Get the "normal" Operator after applying the current mapping.
- If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a
- SerializedBaseOperator that aims to "look like" the real operator.
+ If ``operator_class`` is not a class (i.e. this DAG has been
+ deserialized), this returns a SerializedBaseOperator that aims to
+ "look like" the actual unmapping result.
- :param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when
- ``operator_class`` is still an actual class.
+ :param resolve: Only used if ``operator_class`` is a real class. If this
+ is a two-tuple (context, session), the information is used to
+ resolve the mapped arguments into init arguments. If this is a
+ mapping, no resolving happens; the mapping directly provides those
+ init arguments resolved from mapped kwargs.
:meta private:
"""
if isinstance(self.operator_class, type):
- # We can't simply specify task_id here because BaseOperator further
+ if isinstance(resolve, collections.abc.Mapping):
+ kwargs = resolve
+ else:
+ kwargs = self._expand_mapped_kwargs(resolve)
+ kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+ op = self.operator_class(**kwargs, _airflow_from_mapped=True)
+ # We need to overwrite task_id here because BaseOperator further
# mangles the task_id based on the task hierarchy (namely, group_id
- # is prepended, and '__N' appended to deduplicate). Instead of
- # recreating the whole logic here, we just overwrite task_id later.
- if unmap_kwargs is None:
- unmap_kwargs = self._get_unmap_kwargs()
- op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
+ # is prepended, and '__N' appended to deduplicate). This is hacky,
+ # but better than duplicating the whole mangling logic.
op.task_id = self.task_id
return op
@@ -565,80 +610,23 @@ class MappedOperator(AbstractOperator):
SerializedBaseOperator.populate_operator(op, self.operator_class)
return op
- def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
- """The kwargs to calculate expansion length against."""
- return getattr(self, self._expansion_kwargs_attr)
-
- def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
- """Return dict of argument name to map length.
-
- If any arguments are not known right now (upstream task not finished) they will not be present in the
- dict.
- """
- # TODO: Find a way to cache this.
- from airflow.models.taskmap import TaskMap
- from airflow.models.xcom import XCom
- from airflow.models.xcom_arg import XComArg
+ def _get_specified_expand_input(self) -> ExpandInput:
+ """Input received from the expand call on the operator."""
+ return getattr(self, self._expand_input_attr)
- expansion_kwargs = self._get_expansion_kwargs()
-
- # Populate literal mapped arguments first.
- map_lengths: Dict[str, int] = collections.defaultdict(int)
- map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))
-
- # Build a reverse mapping of what arguments each task contributes to.
- mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
- non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
- for k, v in expansion_kwargs.items():
- if not isinstance(v, XComArg):
- continue
- if v.operator.is_mapped:
- mapped_dep_keys[v.operator.task_id].add(k)
- else:
- non_mapped_dep_keys[v.operator.task_id].add(k)
- # TODO: It's not possible now, but in the future (AIP-42 Phase 2)
- # we will add support for depending on one single mapped task
- # instance. When that happens, we need to further analyze the mapped
- # case to contain only tasks we depend on "as a whole", and put
- # those we only depend on individually to the non-mapped lookup.
-
- # Collect lengths from unmapped upstreams.
- taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
- TaskMap.dag_id == self.dag_id,
- TaskMap.run_id == run_id,
- TaskMap.task_id.in_(non_mapped_dep_keys),
- TaskMap.map_index < 0,
- )
- for task_id, length in taskmap_query:
- for mapped_arg_name in non_mapped_dep_keys[task_id]:
- map_lengths[mapped_arg_name] += length
-
- # Collect lengths from mapped upstreams.
- xcom_query = (
- session.query(XCom.task_id, func.count(XCom.map_index))
- .group_by(XCom.task_id)
- .filter(
- XCom.dag_id == self.dag_id,
- XCom.run_id == run_id,
- XCom.task_id.in_(mapped_dep_keys),
- XCom.map_index >= 0,
- )
- )
- for task_id, length in xcom_query:
- for mapped_arg_name in mapped_dep_keys[task_id]:
- map_lengths[mapped_arg_name] += length
- return map_lengths
+ @property
+ def validate_upstream_return_value(self) -> Callable[[Any], None]:
+ """Validate an upstream's return value satisfies this task's needs.
- @cache
- def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
- """Return dict of argument name to map length, or throw if some are
not resolvable"""
- expansion_kwargs = self._get_expansion_kwargs()
- map_lengths = self._get_map_lengths(run_id, session=session)
- if len(map_lengths) < len(expansion_kwargs):
- keys = ", ".join(repr(k) for k in
sorted(set(expansion_kwargs).difference(map_lengths)))
- raise RuntimeError(f"Failed to populate all mapping metadata;
missing: {keys}")
+ This is implemented as a property (instead of a function calling
+ ``validate_xcom``) so the call site in TaskInstance can de-duplicate
+ validation functions. If this is an instance method, each
+ ``validate_upstream_return_value`` would be a different object (due to
+ how Python handles bound methods), and de-duplication won't work.
- return map_lengths
+ :meta private:
+ """
+ return self._get_specified_expand_input().validate_xcom
def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
"""Create the mapped task instances for mapped task.
@@ -649,9 +637,7 @@ class MappedOperator(AbstractOperator):
from airflow.models.taskinstance import TaskInstance
from airflow.settings import task_instance_mutation_hook
- total_length = functools.reduce(
- operator.mul, self._resolve_map_lengths(run_id, session=session).values()
- )
+ total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
state: Optional[TaskInstanceState] = None
unmapped_ti: Optional[TaskInstance] = (
@@ -728,143 +714,63 @@ class MappedOperator(AbstractOperator):
# we don't need to create a copy of the MappedOperator here.
return self
- def render_template_fields(
- self,
- context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- ) -> Optional["BaseOperator"]:
- """Template all attributes listed in template_fields.
-
- Different from the BaseOperator implementation, this renders the
- template fields on the *unmapped* BaseOperator.
-
- :param context: Dict with values to apply on content
- :param jinja_env: Jinja environment
- :return: The unmapped, populated BaseOperator
- """
- if not jinja_env:
- jinja_env = self.get_template_env()
- # Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor
- # could be called with an XComArg, rather than the value it resolves to.
- #
- # We also need to resolve _all_ mapped arguments, even if they aren't marked as templated
- kwargs = self._get_unmap_kwargs()
-
- template_fields = set(self.template_fields)
-
- # Ideally we'd like to pass in session as an argument to this function, but since operators _could_
- # override this we can't easily change this function signature.
- # We can't use @provide_session, as that closes and expunges everything, which we don't want to do
- # when we are so "deep" in the weeds here.
- #
- # Nor do we want to close the session -- that would expunge all the things from the internal cache
- # which we don't want to do either
- session = settings.Session()
- self._resolve_expansion_kwargs(kwargs, template_fields, context, session)
-
- unmapped_task = self.unmap(unmap_kwargs=kwargs)
- self._do_render_template_fields(
- parent=unmapped_task,
- template_fields=template_fields,
- context=context,
- jinja_env=jinja_env,
- seen_oids=set(),
- session=session,
- )
- return unmapped_task
-
- def _resolve_expansion_kwargs(
- self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session
- ) -> None:
- """Update mapped fields in place in kwargs dict"""
- from airflow.models.xcom_arg import XComArg
-
- expansion_kwargs = self._get_expansion_kwargs()
-
- for k, v in expansion_kwargs.items():
- if isinstance(v, XComArg):
- v = v.resolve(context, session=session)
- v = self._expand_mapped_field(k, v, context, session=session)
- template_fields.discard(k)
- kwargs[k] = v
-
- def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
- map_index = context["ti"].map_index
- if map_index < 0:
- return value
- expansion_kwargs = self._get_expansion_kwargs()
- all_lengths = self._resolve_map_lengths(context["run_id"],
session=session)
-
- def _find_index_for_this_field(index: int) -> int:
- # Need to use self.mapped_kwargs for the original argument order.
- for mapped_key in reversed(list(expansion_kwargs)):
- mapped_length = all_lengths[mapped_key]
- if mapped_length < 1:
- raise RuntimeError(f"cannot expand field mapped to length
{mapped_length!r}")
- if mapped_key == key:
- return index % mapped_length
- index //= mapped_length
- return -1
-
- found_index = _find_index_for_this_field(map_index)
- if found_index < 0:
- return value
- if isinstance(value, collections.abc.Sequence):
- return value[found_index]
- if not isinstance(value, dict):
- raise TypeError(f"can't map over value of type {type(value)}")
- for i, (k, v) in enumerate(value.items()):
- if i == found_index:
- return k, v
- raise IndexError(f"index {map_index} is over mapped length")
-
def iter_mapped_dependencies(self) -> Iterator["Operator"]:
"""Upstream dependencies that provide XComs used by this task for task
mapping."""
from airflow.models.xcom_arg import XComArg
- for ref in XComArg.iter_xcom_args(self._get_expansion_kwargs()):
+ for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()):
yield ref.operator
@cached_property
def parse_time_mapped_ti_count(self) -> Optional[int]:
- """
- Number of mapped TaskInstances that can be created at DagRun create time.
+ """Number of mapped TaskInstances that can be created at DagRun create time.
- :return: None if non-literal mapped arg encountered, or else total number of mapped TIs this task
- should have
+ :return: None if non-literal mapped arg encountered, or the total
+ number of mapped TIs this task should have.
"""
- total = 0
-
- for value in self._get_expansion_kwargs().values():
- if not isinstance(value, MAPPABLE_LITERAL_TYPES):
- # None literal type encountered, so give up
- return None
- if total == 0:
- total = len(value)
- else:
- total *= len(value)
- return total
+ return self._get_specified_expand_input().get_parse_time_mapped_ti_count()
@cache
def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> Optional[int]:
- """
- Number of mapped TaskInstances that can be created at run time, or None if upstream tasks are not
- complete yet.
+ """Number of mapped TaskInstances that can be created at run time.
- :return: None if upstream tasks are not complete yet, or else total number of mapped TIs this task
- should have
+ :return: None if upstream tasks are not complete yet, or the total
+ number of mapped TIs this task should have.
"""
- lengths = self._get_map_lengths(run_id, session=session)
- expansion_kwargs = self._get_expansion_kwargs()
-
- if not lengths or not expansion_kwargs:
+ try:
+ return self._get_specified_expand_input().get_total_map_length(run_id, session=session)
+ except NotFullyPopulated:
return None
- total = 1
- for name in expansion_kwargs:
- val = lengths.get(name)
- if val is None:
- return None
- total *= val
+ def _get_template_fields_to_render(self, expanded: Iterable[str]) -> Iterable[str]:
+ # Since the mapped kwargs are already resolved during unmapping,
+ # they must be removed from the list of templated fields to avoid
+ # being rendered again (which breaks escaping).
+ return set(self.template_fields).difference(expanded)
+
+ def render_template_fields(
+ self,
+ context: Context,
+ jinja_env: Optional["jinja2.Environment"] = None,
+ ) -> Optional["BaseOperator"]:
+ if not jinja_env:
+ jinja_env = self.get_template_env()
+
+ # Ideally we'd like to pass in session as an argument to this function,
+ # but we can't easily change this function signature since operators
+ # could override this. We can't use @provide_session since it closes and
+ # expunges everything, which we don't want to do when we are so "deep"
+ # in the weeds here. We don't close this session for the same reason.
+ session = settings.Session()
- return total
+ mapped_kwargs = self._expand_mapped_kwargs((context, session))
+ unmapped_task = self.unmap(mapped_kwargs)
+ self._do_render_template_fields(
+ parent=unmapped_task,
+ template_fields=self._get_template_fields_to_render(mapped_kwargs),
+ context=context,
+ jinja_env=jinja_env,
+ seen_oids=set(),
+ session=session,
+ )
+ return unmapped_task
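
The reworked unmap() takes a single argument whose type selects the behavior;
the call sites touched by this commit use three shapes (pseudo-usage only,
assuming a mapped task object and an active context and session):

    task.unmap(None)                # no resolution possible (parse time, error paths)
    task.unmap((context, session))  # resolve mapped XComArgs at execution time
    task.unmap(mapped_kwargs)       # a mapping: kwargs were already resolved
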
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c198c6f09d..470c6360f0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -91,7 +91,6 @@ from airflow.exceptions import (
TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
- UnmappableXComTypePushed,
XComForMappingNotPushed,
)
from airflow.models.base import Base, StringID
@@ -1854,7 +1853,14 @@ class TaskInstance(Base, LoggingMixin):
return tb or error.__traceback__
@provide_session
- def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None) -> None:
+ def handle_failure(
+ self,
+ error: Any,
+ test_mode: Optional[bool] = None,
+ context: Optional[Context] = None,
+ force_fail: bool = False,
+ session: Session = NEW_SESSION,
+ ) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
test_mode = self.test_mode
@@ -1898,11 +1904,11 @@ class TaskInstance(Base, LoggingMixin):
# only mark task instance as FAILED if the next task instance
# try_number exceeds the max_tries ... or if force_fail is truthy
- task = None
+ task: Optional[BaseOperator] = None
try:
- task = self.task.unmap()
+ task = self.task.unmap((context, session))
except Exception:
- self.log.error("Unable to unmap task, can't determine if we need
to send an alert email or not")
+ self.log.error("Unable to unmap task to determine if we need to
send an alert email")
if force_fail or not self.is_eligible_to_retry():
self.state = State.FAILED
@@ -2135,7 +2141,7 @@ class TaskInstance(Base, LoggingMixin):
rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
if rendered_task_instance_fields:
- self.task = self.task.unmap()
+ self.task = self.task.unmap(None)
for field_name, rendered_value in rendered_task_instance_fields.items():
setattr(self.task, field_name, rendered_value)
return
@@ -2311,18 +2317,20 @@ class TaskInstance(Base, LoggingMixin):
self.log.debug("Task Duration set to %s", self.duration)
def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
- # TODO: We don't push TaskMap for mapped task instances because it's not
- # currently possible for a downstream to depend on one individual mapped
- # task instance, only a task as a whole. This will change in AIP-42
- # Phase 2, and we'll need to further analyze the mapped task case.
- if next(task.iter_mapped_dependants(), None) is None:
+ validators = {m.validate_upstream_return_value for m in task.iter_mapped_dependants()}
+ if not validators: # No mapped dependants, no need to validate.
return
if value is None:
raise XComForMappingNotPushed()
+ # TODO: We don't push TaskMap for mapped task instances because it's not
+ # currently possible for a downstream to depend on one individual mapped
+ # task instance. This will change when we implement task group mapping,
+ # and we'll need to further analyze the mapped task case.
if task.is_mapped:
return
- if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
- raise UnmappableXComTypePushed(value)
+ for validator in validators:
+ validator(value)
+ assert isinstance(value, collections.abc.Collection) # The validators type-guard this.
task_map = TaskMap.from_task_instance_xcom(self, value)
max_map_length = conf.getint("core", "max_map_length", fallback=1024)
if task_map.length > max_map_length:
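
The de-duplication described in validate_upstream_return_value works because
instance access to a staticmethod (like validate_xcom on the expand input)
yields one shared function object, while bound methods hash per instance. A
minimal standalone sketch with hypothetical classes:

    class ExpandInputLike:
        @staticmethod
        def validate_xcom(value):
            if not isinstance(value, (list, dict)):
                raise TypeError(value)

    a, b = ExpandInputLike(), ExpandInputLike()
    # The same function object is returned for every instance, so a set
    # collapses the validators to a single entry:
    assert len({a.validate_xcom, b.validate_xcom}) == 1

    class WithBoundMethod:
        def validate(self, value):
            pass

    c, d = WithBoundMethod(), WithBoundMethod()
    # Bound methods differ per instance and would not de-duplicate:
    assert len({c.validate, d.validate}) == 2
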
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 1550387eed..f9df99eb58 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -251,13 +251,13 @@
"doc_yaml": { "type": "string" },
"doc_rst": { "type": "string" },
"_is_mapped": { "const": true, "$comment": "only present when True" },
- "mapped_kwargs": { "type": "object" },
+ "expand_input": { "type": "object" },
"partial_kwargs": { "type": "object" }
},
"dependencies": {
- "mapped_kwargs": ["partial_kwargs", "_is_mapped"],
- "partial_kwargs": ["mapped_kwargs", "_is_mapped"],
- "_is_mapped": ["mapped_kwargs", "partial_kwargs"]
+ "expand_input": ["partial_kwargs", "_is_mapped"],
+ "partial_kwargs": ["expand_input", "_is_mapped"],
+ "_is_mapped": ["expand_input", "partial_kwargs"]
},
"additionalProperties": true
},
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index f4a4257f57..201d27b1c8 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -37,6 +37,7 @@ from airflow.models import Dataset
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, create_timetable
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
@@ -99,8 +100,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]:
don't need to store them.
"""
# Use the private _expand() method to avoid the empty kwargs check.
- default_partial_kwargs = BaseOperator.partial(task_id="_")._expand().partial_kwargs
- return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR]
+ default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
+ return BaseSerialization._serialize(default)[Encoding.VAR]
def encode_relativedelta(var: relativedelta.relativedelta) -> Dict[str, Any]:
@@ -193,16 +194,37 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
class _XComRef(NamedTuple):
- """
- Used to store info needed to create XComArg when deserializing MappedOperator.
+ """Used to store info needed to create XComArg.
- We can't turn it in to a XComArg until we've loaded _all_ the tasks, so when deserializing an operator we
- need to create _something_, and then post-process it in deserialize_dag
+ We can't turn it into an XComArg until we've loaded _all_ the tasks, so when
+ deserializing an operator, we need to create something in its place, and
+ post-process it in ``deserialize_dag``.
"""
task_id: str
key: str
+ def deref(self, dag: DAG) -> XComArg:
+ return XComArg(operator=dag.get_task(self.task_id), key=self.key)
+
+
+class _ExpandInputRef(NamedTuple):
+ """Used to store info needed to create a mapped operator's expand input.
+
+ This references an ``ExpandInput`` type, but replaces ``XComArg`` objects
+ with ``_XComRef`` (see documentation on the latter type for reasoning).
+ """
+
+ key: str
+ value: Union[_XComRef, Dict[str, Any]]
+
+ def deref(self, dag: DAG) -> ExpandInput:
+ if isinstance(self.value, _XComRef):
+ value: Any = self.value.deref(dag)
+ else:
+ value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
+ return create_expand_input(self.key, value)
+
class BaseSerialization:
"""BaseSerialization provides utils for serialization."""
@@ -598,8 +620,12 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator))
- # Handle mapped_kwargs and mapped_op_kwargs.
- serialized_op[op._expansion_kwargs_attr] = cls._serialize(op._get_expansion_kwargs())
+ # Handle expand_input and op_kwargs_expand_input.
+ expansion_kwargs = op._get_specified_expand_input()
+ serialized_op[op._expand_input_attr] = {
+ "type": get_map_type_key(expansion_kwargs),
+ "value": cls._serialize(expansion_kwargs.value),
+ }
# Simplify partial_kwargs by comparing it to the most barebone object.
# Remove all entries that are simply default values.
@@ -749,6 +775,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
v = cls._deserialize_params_dict(v)
elif k == "partial_kwargs":
v = {arg: cls._deserialize(value) for arg, value in v.items()}
+ elif k in {"expand_input", "op_kwargs_expand_input"}:
+ v = _ExpandInputRef(v["type"], cls._deserialize(v["value"]))
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
elif k in ("_outlets", "_inlets"):
@@ -781,7 +809,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
op = MappedOperator(
operator_class=op_data,
- mapped_kwargs={},
+ expand_input=EXPAND_INPUT_EMPTY,
partial_kwargs={},
task_id=encoded_op["task_id"],
params={},
@@ -799,7 +827,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
task_group=None,
start_date=None,
end_date=None,
- expansion_kwargs_attr=encoded_op["_expansion_kwargs_attr"],
+ disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
+ expand_input_attr=encoded_op["_expand_input_attr"],
)
else:
op = SerializedBaseOperator(task_id=encoded_op['task_id'])
@@ -1078,13 +1107,11 @@ class SerializedDAG(DAG, BaseSerialization):
if task.subdag is not None:
setattr(task.subdag, 'parent_dag', dag)
- if isinstance(task, MappedOperator):
- expansion_kwargs = task._get_expansion_kwargs()
- for k, v in expansion_kwargs.items():
- if not isinstance(v, _XComRef):
- continue
-
- expansion_kwargs[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)
+ # Dereference expand_input and op_kwargs_expand_input.
+ for k in ("expand_input", "op_kwargs_expand_input"):
+ kwargs_ref = getattr(task, k, None)
+ if isinstance(kwargs_ref, _ExpandInputRef):
+ setattr(task, k, kwargs_ref.deref(dag))
for task_id in task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
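
With these changes, a mapped operator's expand input serializes as a tagged
pair instead of a bare kwargs dict; roughly (inner BaseSerialization encoding
elided for brevity):

    serialized_expand_input = {
        "type": "dict-of-lists",            # from get_map_type_key()
        "value": {"bash_command": [1, 2]},  # encoded via cls._serialize() in practice
    }
    # Deserialization wraps this in an _ExpandInputRef; deref() later rebuilds
    # the real input, e.g. create_expand_input("dict-of-lists", value) returns
    # DictOfListsExpandInput(value) with _XComRef entries replaced by XComArgs.
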
diff --git a/airflow/www/views.py b/airflow/www/views.py
index c630d47af8..906d73e943 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1372,7 +1372,7 @@ class Airflow(AirflowBaseView):
# only matters if get_rendered_template_fields() raised an exception.
# The following rendering won't show useful values in this case anyway,
# but we'll display some quasi-meaningful field names.
- task = ti.task.unmap()
+ task = ti.task.unmap(None)
title = "Rendered Template"
html_dict = {}
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e6779724bf..9f7be831d9 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1514,6 +1514,7 @@ unittests
unix
unmappable
unmapped
+unmapping
unpause
unpaused
unpausing
diff --git a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
index c8c07c9e80..909693de5a 100755
--- a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
@@ -50,8 +50,9 @@ IGNORED = {
"partial",
"shallow_copy_attrs",
# Only on MappedOperator.
- "mapped_kwargs",
+ "expand_input",
"partial_kwargs",
+ "validate_upstream_return_value",
}
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py
index 7509a89032..4151a4b60f 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -22,6 +22,7 @@ import pytest
from airflow import DAG
from airflow.models import DagBag
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.security import permissions
@@ -70,7 +71,7 @@ class TestTaskEndpoint:
EmptyOperator(task_id=self.task_id3)
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
- EmptyOperator.partial(task_id=self.mapped_task_id)._expand()
+ EmptyOperator.partial(task_id=self.mapped_task_id)._expand(EXPAND_INPUT_EMPTY, strict=False)
task1 >> task2
dag_bag = DagBag(os.devnull, include_examples=False)
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 63514fcf25..58ae1c5f87 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -28,6 +28,7 @@ from airflow.decorators.base import DecoratedMappedOperator
from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
+from airflow.models.expandinput import DictOfListsExpandInput
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
@@ -613,7 +614,7 @@ def test_mapped_decorator():
assert isinstance(t2, XComArg)
assert isinstance(t2.operator, DecoratedMappedOperator)
assert t2.operator.task_id == "print_everything"
- assert t2.operator.mapped_op_kwargs == {"any_key": [1, 2], "works": t1}
+ assert t2.operator.op_kwargs_expand_input == DictOfListsExpandInput({"any_key": [1, 2], "works": t1})
assert t0.operator.task_id == "print_info"
assert t1.operator.task_id == "print_info__1"
@@ -656,7 +657,7 @@ def test_partial_mapped_decorator() -> None:
assert isinstance(doubled, XComArg)
assert isinstance(doubled.operator, DecoratedMappedOperator)
- assert doubled.operator.mapped_op_kwargs == {"number": literal}
+ assert doubled.operator.op_kwargs_expand_input == DictOfListsExpandInput({"number": literal})
assert doubled.operator.partial_kwargs["op_kwargs"] == {"multiple": 2}
assert isinstance(trippled.operator, DecoratedMappedOperator) # For type-checking on partial_kwargs.
@@ -678,7 +679,7 @@ def test_mapped_decorator_unmap_merge_op_kwargs():
task2.partial(arg1=1).expand(arg2=task1())
- unmapped = dag.get_task("task2").unmap()
+ unmapped = dag.get_task("task2").unmap(None)
assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
@@ -697,11 +698,11 @@ def test_mapped_decorator_converts_partial_kwargs():
mapped_task2 = dag.get_task("task2")
assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30)
- assert mapped_task2.unmap().retry_delay == timedelta(seconds=30)
+ assert mapped_task2.unmap(None).retry_delay == timedelta(seconds=30)
mapped_task1 = dag.get_task("task1")
assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30)
# Operator default.
- mapped_task1.unmap().retry_delay == timedelta(seconds=300) # Operator default.
+ mapped_task1.unmap(None).retry_delay == timedelta(seconds=300) # Operator default.
def test_mapped_render_template_fields(dag_maker, session):
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 397f519a50..17c339620b 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1165,7 +1165,7 @@ def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
]
-def test_mapped_literal_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
+def test_mapped_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
"""Test that when the length of mapped literal increases at runtime,
additional ti is added"""
from airflow.models import Variable
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index c720fd96d9..92d6226097 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -102,9 +102,7 @@ def test_map_xcom_arg():
def test_partial_on_instance() -> None:
"""`.partial` on an instance should fail -- it's only designed to be
called on classes"""
with pytest.raises(TypeError):
- MockOperator(
- task_id='a',
- ).partial()
+ MockOperator(task_id='a').partial()
def test_partial_on_class() -> None:
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index c0868b0b7f..e17f34bd78 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -58,6 +58,7 @@ from airflow.models import (
XCom,
)
from airflow.models.dataset import Dataset, DatasetDagRunQueue, DatasetEvent, DatasetTaskRef
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance
@@ -1058,7 +1059,7 @@ class TestTaskInstance:
with dag_maker(dag_id="test_xcom", session=session):
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
- task_1 = EmptyOperator.partial(task_id="task_1")._expand()
+ task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
EmptyOperator(task_id="task_2")
dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6,
1, 0, 0, 0))
@@ -2477,9 +2478,9 @@ class TestTaskInstanceRecordTaskMapXComPush:
(None, XComForMappingNotPushed, "did not push XCom for task mapping"),
],
)
- def test_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message):
- """If an unmappable return value is used to map, fail the task that pushed the XCom."""
- with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+ def test_expand_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message):
+ """If an unmappable return value is used for expand(), fail the task that pushed the XCom."""
+ with dag_maker(dag_id="test_expand_error_if_unmappable_type") as dag:
@dag.task()
def push_something():
@@ -2802,7 +2803,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v
with dag_maker(dag_id="test_xcom", session=session):
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
- task_1 = EmptyOperator.partial(task_id="task_1")._expand()
+ task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
EmptyOperator(task_id="task_2")
dagrun = dag_maker.create_dagrun()
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index af5a701366..418cb4af89 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1674,7 +1674,7 @@ def test_kubernetes_optional():
module.SerializedDAG.to_dict(make_simple_dag()["simple_dag"])
-def test_mapped_operator_serde():
+def test_operator_expand_serde():
literal = [1, 2, {'a': 'b'}]
real_op = BashOperator.partial(task_id='a', executor_config={'dict': {'sub': 'value'}}).expand(
bash_command=literal
@@ -1688,9 +1688,12 @@ def test_mapped_operator_serde():
'_task_module': 'airflow.operators.bash',
'_task_type': 'BashOperator',
'downstream_task_ids': [],
- 'mapped_kwargs': {
- "__type": "dict",
- "__var": {'bash_command': [1, 2, {"__type": "dict", "__var": {'a':
'b'}}]},
+ 'expand_input': {
+ "type": "dict-of-lists",
+ "value": {
+ "__type": "dict",
+ "__var": {'bash_command': [1, 2, {"__type": "dict", "__var":
{'a': 'b'}}]},
+ },
},
'partial_kwargs': {
'executor_config': {
@@ -1705,7 +1708,8 @@ def test_mapped_operator_serde():
'template_fields_renderers': {'bash_command': 'bash', 'env': 'json'},
'ui_color': '#f0ede4',
'ui_fgcolor': '#000',
- '_expansion_kwargs_attr': 'mapped_kwargs',
+ "_disallow_kwargs_override": False,
+ '_expand_input_attr': 'expand_input',
}
op = SerializedBaseOperator.deserialize_operator(serialized)
@@ -1722,11 +1726,11 @@ def test_mapped_operator_serde():
'ui_color': '#f0ede4',
'ui_fgcolor': '#000',
}
- assert op.mapped_kwargs['bash_command'] == literal
+ assert op.expand_input.value['bash_command'] == literal
assert op.partial_kwargs['executor_config'] == {'dict': {'sub': 'value'}}
-def test_mapped_operator_xcomarg_serde():
+def test_operator_expand_xcomarg_serde():
from airflow.models.xcom_arg import XComArg
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
@@ -1740,9 +1744,12 @@ def test_mapped_operator_xcomarg_serde():
'_task_module': 'tests.test_utils.mock_operators',
'_task_type': 'MockOperator',
'downstream_task_ids': [],
- 'mapped_kwargs': {
- "__type": "dict",
- "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id':
'op1', 'key': 'return_value'}}},
+ 'expand_input': {
+ "type": "dict-of-lists",
+ "value": {
+ "__type": "dict",
+ "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id':
'op1', 'key': 'return_value'}}},
+ },
},
'partial_kwargs': {},
'task_id': 'task_2',
@@ -1752,31 +1759,32 @@ def test_mapped_operator_xcomarg_serde():
'operator_extra_links': [],
'ui_color': '#fff',
'ui_fgcolor': '#000',
- '_expansion_kwargs_attr': 'mapped_kwargs',
+ "_disallow_kwargs_override": False,
+ '_expand_input_attr': 'expand_input',
}
op = SerializedBaseOperator.deserialize_operator(serialized)
assert op.deps is MappedOperator.deps_for(BaseOperator)
- arg = op.mapped_kwargs['arg2']
+ arg = op.expand_input.value['arg2']
assert arg.task_id == 'op1'
assert arg.key == XCOM_RETURN_KEY
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
- xcom_arg = serialized_dag.task_dict['task_2'].mapped_kwargs['arg2']
+ xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value['arg2']
assert isinstance(xcom_arg, XComArg)
assert xcom_arg.operator is serialized_dag.task_dict['op1']
-def test_mapped_operator_deserialized_unmap():
+def test_operator_expand_deserialized_unmap():
"""Unmap a deserialized mapped operator should be similar to deserializing
an non-mapped operator."""
normal = BashOperator(task_id='a', bash_command=[1, 2], executor_config={"a": "b"})
mapped = BashOperator.partial(task_id='a', executor_config={"a": "b"}).expand(bash_command=[1, 2])
serialize = SerializedBaseOperator._serialize
deserialize = SerializedBaseOperator.deserialize_operator
- assert deserialize(serialize(mapped)).unmap() == deserialize(serialize(normal))
+ assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
def test_task_resources_serde():
@@ -1799,10 +1807,10 @@ def test_task_resources_serde():
}
-def test_mapped_decorator_serde():
+def test_taskflow_expand_serde():
from airflow.decorators import task
from airflow.models.xcom_arg import XComArg
- from airflow.serialization.serialized_objects import _XComRef
+ from airflow.serialization.serialized_objects import _ExpandInputRef,
_XComRef
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
op1 = BaseOperator(task_id="op1")
@@ -1830,11 +1838,14 @@ def test_mapped_decorator_serde():
},
'retry_delay': {'__type': 'timedelta', '__var': 30.0},
},
- 'mapped_op_kwargs': {
- "__type": "dict",
- "__var": {
- 'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+ 'op_kwargs_expand_input': {
+ "type": "dict-of-lists",
+ "value": {
+ "__type": "dict",
+ "__var": {
+ 'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
+ 'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+ },
},
},
'operator_extra_links': [],
@@ -1844,7 +1855,8 @@ def test_mapped_decorator_serde():
'template_ext': [],
'template_fields': ['op_args', 'op_kwargs'],
'template_fields_renderers': {"op_args": "py", "op_kwargs": "py"},
- '_expansion_kwargs_attr': 'mapped_op_kwargs',
+ "_disallow_kwargs_override": False,
+ '_expand_input_attr': 'op_kwargs_expand_input',
}
deserialized = SerializedBaseOperator.deserialize_operator(serialized)
@@ -1853,10 +1865,10 @@ def test_mapped_decorator_serde():
assert deserialized.upstream_task_ids == set()
assert deserialized.downstream_task_ids == set()
- assert deserialized.mapped_op_kwargs == {
- "arg2": {"a": 1, "b": 2},
- "arg3": _XComRef("op1", XCOM_RETURN_KEY),
- }
+ assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
+ key="dict-of-lists",
+ value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1",
XCOM_RETURN_KEY)},
+ )
assert deserialized.partial_kwargs == {
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
@@ -1868,10 +1880,10 @@ def test_mapped_decorator_serde():
# here so we don't need to duplicate tests between pickled and non-pickled
# DAGs everywhere else.
pickled = pickle.loads(pickle.dumps(deserialized))
- assert pickled.mapped_op_kwargs == {
- "arg2": {"a": 1, "b": 2},
- "arg3": _XComRef("op1", XCOM_RETURN_KEY),
- }
+ assert pickled.op_kwargs_expand_input == _ExpandInputRef(
+ key="dict-of-lists",
+ value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1",
XCOM_RETURN_KEY)},
+ )
assert pickled.partial_kwargs == {
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},