This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c7a4acbb1 Move mapped kwargs introspection to separate type (#24971)
1c7a4acbb1 is described below

commit 1c7a4acbb1fe7882247b3329de87002ff938657d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jul 12 15:51:47 2022 +0800

    Move mapped kwargs introspection to separate type (#24971)
---
 airflow/decorators/base.py                         | 130 ++++---
 airflow/models/baseoperator.py                     |   2 +-
 airflow/models/expandinput.py                      | 201 +++++++++++
 airflow/models/mappedoperator.py                   | 380 ++++++++-------------
 airflow/models/taskinstance.py                     |  34 +-
 airflow/serialization/schema.json                  |   8 +-
 airflow/serialization/serialized_objects.py        |  61 +++-
 airflow/www/views.py                               |   2 +-
 docs/spelling_wordlist.txt                         |   1 +
 .../pre_commit_base_operator_partial_arguments.py  |   3 +-
 .../api_connexion/endpoints/test_task_endpoint.py  |   3 +-
 tests/decorators/test_python.py                    |  11 +-
 tests/models/test_dagrun.py                        |   2 +-
 tests/models/test_mappedoperator.py                |   4 +-
 tests/models/test_taskinstance.py                  |  11 +-
 tests/serialization/test_dag_serialization.py      |  74 ++--
 16 files changed, 541 insertions(+), 386 deletions(-)
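
The heart of this change is the new ExpandInput type: the raw mapped_kwargs
dict previously stored on MappedOperator becomes a value object that owns all
map-length introspection. A minimal sketch of the new surface, with
illustrative values (the class is added in airflow/models/expandinput.py
below):

    from airflow.models.expandinput import DictOfListsExpandInput

    # expand(x=[1, 2], y=["a", "b", "c"]) now stores this instead of a raw dict.
    expand_input = DictOfListsExpandInput({"x": [1, 2], "y": ["a", "b", "c"]})

    # All values are literals, so the mapped TI count is known at parse time.
    assert expand_input.get_parse_time_mapped_ti_count() == 2 * 3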

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 2a2ce2da96..9598571240 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -25,11 +25,13 @@ from typing import (
     Collection,
     Dict,
     Generic,
+    Iterable,
     Iterator,
     Mapping,
     Optional,
     Sequence,
     Set,
+    Tuple,
     Type,
     TypeVar,
     cast,
@@ -38,6 +40,7 @@ from typing import (
 
 import attr
 import typing_extensions
+from sqlalchemy.orm import Session
 
 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
@@ -50,6 +53,7 @@ from airflow.models.baseoperator import (
     parse_retries,
 )
 from airflow.models.dag import DAG, DagContext
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY, DictOfListsExpandInput, ExpandInput
 from airflow.models.mappedoperator import (
     MappedOperator,
     ValidationSource,
@@ -67,7 +71,6 @@ from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
     import jinja2  # Slow import.
-    from sqlalchemy.orm import Session
 
     from airflow.models.mappedoperator import Mappable
 
@@ -314,10 +317,14 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
     def expand(self, **map_kwargs: "Mappable") -> XComArg:
         if not map_kwargs:
             raise TypeError("no arguments to expand against")
-
         self._validate_arg_names("expand", map_kwargs)
         prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
-        ensure_xcomarg_return_value(map_kwargs)
+        # Since the input is already checked at parse time, we can set strict
+        # to False to skip the checks on execution.
+        return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
+
+    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
+        ensure_xcomarg_return_value(expand_input.value)
 
         task_kwargs = self.kwargs.copy()
         dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
@@ -363,7 +370,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
         _MappedOperator = cast(Any, DecoratedMappedOperator)
         operator = _MappedOperator(
             operator_class=self.operator_class,
-            mapped_kwargs={},
+            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
             partial_kwargs=partial_kwargs,
             task_id=task_id,
             params=params,
@@ -383,36 +390,26 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
             end_date=end_date,
             multiple_outputs=self.multiple_outputs,
             python_callable=self.function,
-            mapped_op_kwargs=map_kwargs,
+            op_kwargs_expand_input=expand_input,
+            disallow_kwargs_override=strict,
             # Different from classic operators, kwargs passed to a taskflow
             # task's expand() contribute to the op_kwargs operator argument, not
             # the operator arguments themselves, and should expand against it.
-            expansion_kwargs_attr="mapped_op_kwargs",
+            expand_input_attr="op_kwargs_expand_input",
         )
         return XComArg(operator=operator)
 
-    def partial(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+    def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
         self._validate_arg_names("partial", kwargs)
+        old_kwargs = self.kwargs.get("op_kwargs", {})
+        prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
+        kwargs.update(old_kwargs)
+        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
 
-        op_kwargs = self.kwargs.get("op_kwargs", {})
-        op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")
-
-        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": op_kwargs})
-
-    def override(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+    def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
         return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
 
 
-def _merge_kwargs(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> Dict[str, Any]:
-    duplicated_keys = set(kwargs1).intersection(kwargs2)
-    if len(duplicated_keys) == 1:
-        raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
-    elif duplicated_keys:
-        duplicated_keys_display = ", ".join(sorted(duplicated_keys))
-        raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
-    return {**kwargs1, **kwargs2}
-
-
 @attr.define(kw_only=True, repr=False)
 class DecoratedMappedOperator(MappedOperator):
     """MappedOperator implementation for @task-decorated task function."""
@@ -420,9 +417,9 @@ class DecoratedMappedOperator(MappedOperator):
     multiple_outputs: bool
     python_callable: Callable
 
-    # We can't save these in mapped_kwargs because op_kwargs need to be present
+    # We can't save these in expand_input because op_kwargs need to be present
     # in partial_kwargs, and MappedOperator prevents duplication.
-    mapped_op_kwargs: Dict[str, "Mappable"]
+    op_kwargs_expand_input: ExpandInput
 
     def __hash__(self):
         return id(self)
@@ -431,40 +428,38 @@ class DecoratedMappedOperator(MappedOperator):
         # The magic super() doesn't work here, so we use the explicit form.
         # Not using super(..., self) to work around pyupgrade bug.
         super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
-        XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs)
-
-    def _get_unmap_kwargs(self) -> Dict[str, Any]:
-        partial_kwargs = self.partial_kwargs.copy()
-        op_kwargs = _merge_kwargs(
-            partial_kwargs.pop("op_kwargs"),
-            self.mapped_op_kwargs,
-            fail_reason="mapping already partial",
-        )
-        self._combined_op_kwargs = op_kwargs
-        return {
-            "dag": self.dag,
-            "task_group": self.task_group,
-            "task_id": self.task_id,
-            "op_kwargs": op_kwargs,
+        XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
+
+    def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Dict[str, Any]:
+        # We only use op_kwargs_expand_input so this must always be empty.
+        assert self.expand_input is EXPAND_INPUT_EMPTY
+        return {"op_kwargs": super()._expand_mapped_kwargs(resolve)}
+
+    def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+        if strict:
+            prevent_duplicates(
+                self.partial_kwargs["op_kwargs"],
+                mapped_kwargs["op_kwargs"],
+                fail_reason="mapping already partial",
+            )
+        self._combined_op_kwargs = {**self.partial_kwargs["op_kwargs"], **mapped_kwargs["op_kwargs"]}
+        self._already_resolved_op_kwargs = {
+            k for k, v in self.op_kwargs_expand_input.value.items() if isinstance(v, XComArg)
+        }
+        kwargs = {
             "multiple_outputs": self.multiple_outputs,
             "python_callable": self.python_callable,
-            **partial_kwargs,
-            **self.mapped_kwargs,
+            "op_kwargs": self._combined_op_kwargs,
         }
+        return super()._get_unmap_kwargs(kwargs, strict=False)
 
-    def _resolve_expansion_kwargs(
-        self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: "Session"
-    ) -> None:
-        expansion_kwargs = self._get_expansion_kwargs()
-
-        self._already_resolved_op_kwargs = set()
-        for k, v in expansion_kwargs.items():
-            if isinstance(v, XComArg):
-                self._already_resolved_op_kwargs.add(k)
-                v = v.resolve(context, session=session)
-            v = self._expand_mapped_field(k, v, context, session=session)
-            kwargs['op_kwargs'][k] = v
-            template_fields.discard(k)
+    def _get_template_fields_to_render(self, expanded: Iterable[str]) -> Iterable[str]:
+        # Different from a regular MappedOperator, we still want to render
+        # some fields in op_kwargs (those that are NOT passed as XComArg from
+        # upstream). Already-rendered op_kwargs keys are detected in a different
+        # way (see render_template below and _get_unmap_kwargs above).
+        assert list(expanded) == ["op_kwargs"]
+        return self.template_fields
 
     def render_template(
         self,
@@ -473,17 +468,20 @@ class DecoratedMappedOperator(MappedOperator):
         jinja_env: Optional["jinja2.Environment"] = None,
         seen_oids: Optional[Set] = None,
     ) -> Any:
-        if hasattr(self, '_combined_op_kwargs') and value is self._combined_op_kwargs:
-            # Avoid rendering values that came out of resolved XComArgs
-            return {
-                k: v
-                if k in self._already_resolved_op_kwargs
-                else super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
-                    self, v, context, jinja_env=jinja_env, seen_oids=seen_oids
-                )
-                for k, v in value.items()
-            }
-        return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)
+        if value is not getattr(self, "_combined_op_kwargs", object()):
+            return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)
+
+        def _render_if_not_already_resolved(key: str, value: Any):
+            if key in self._already_resolved_op_kwargs:
+                return value
+            # The magic super() doesn't work here, so we use the explicit form.
+            # Not using super(..., self) to work around pyupgrade bug.
+            return super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
+                self, value, context=context, jinja_env=jinja_env, seen_oids=seen_oids
+            )
+
+        # Avoid rendering values that came out of resolved XComArgs.
+        return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()}
 
 
 class Task(Generic[Function]):
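
The partial() rewrite above also moves the duplicate-key check from unmap time
(the removed _merge_kwargs) to definition time, which is why expand() can pass
strict=False. A rough sketch of the resulting behavior; the task is
illustrative, and the exact error text comes from prevent_duplicates with the
fail_reason strings seen in the diff:

    from airflow.decorators import task

    @task
    def scale(number: int, factor: int = 2):
        return number * factor

    s = scale.partial(factor=3)
    s.partial(factor=4)         # raises TypeError ("duplicate partial ... factor")
    s.expand(number=[1, 2, 3])  # OK; execution later skips the re-check
                                # (strict=False) since this path already validated.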
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index ab9543c127..05ef24a124 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1511,7 +1511,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         if cls.mapped_arguments_validated_by_init:
             cls(**kwargs, _airflow_from_mapped=True)
 
-    def unmap(self) -> "BaseOperator":
+    def unmap(self, ctx: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
         """:meta private:"""
         return self
 
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
new file mode 100644
index 0000000000..86623b41a1
--- /dev/null
+++ b/airflow/models/expandinput.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import collections
+import collections.abc
+import functools
+import operator
+from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from airflow.exceptions import UnmappableXComTypePushed
+from airflow.utils.context import Context
+
+if TYPE_CHECKING:
+    from airflow.models.xcom_arg import XComArg
+
+# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
+# mapping since we need the value to be ordered).
+Mappable = Union["XComArg", Sequence, dict]
+
+MAPPABLE_LITERAL_TYPES = (dict, list)
+
+
+class NotFullyPopulated(RuntimeError):
+    """Raise when ``get_map_lengths`` cannot populate all mapping metadata.
+    This generally means that not all upstream tasks have finished when the
+    function is called.
+    """
+
+    def __init__(self, missing: set[str]) -> None:
+        self.missing = missing
+
+    def __str__(self) -> str:
+        keys = ", ".join(repr(k) for k in sorted(self.missing))
+        return f"Failed to populate all mapping metadata; missing: {keys}"
+
+
+class DictOfListsExpandInput(NamedTuple):
+    """Storage type of a mapped operator's mapped kwargs.
+
+    This is created from ``expand(**kwargs)``.
+    """
+
+    value: dict[str, Mappable]
+
+    @staticmethod
+    def validate_xcom(value: Any) -> None:
+        if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
+            raise UnmappableXComTypePushed(value)
+
+    def get_parse_time_mapped_ti_count(self) -> int | None:
+        if not self.value:
+            return 0
+        literal_values = [len(v) for v in self.value.values() if isinstance(v, MAPPABLE_LITERAL_TYPES)]
+        if len(literal_values) != len(self.value):
+            return None  # Non-literal type encountered, so give up.
+        return functools.reduce(operator.mul, literal_values, 1)
+
+    def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
+        """Return dict of argument name to map length.
+
+        If any arguments are not known right now (upstream task not finished),
+        they will not be present in the dict.
+        """
+        from airflow.models.taskmap import TaskMap
+        from airflow.models.xcom import XCom
+        from airflow.models.xcom_arg import XComArg
+
+        # Populate literal mapped arguments first.
+        map_lengths: dict[str, int] = collections.defaultdict(int)
+        map_lengths.update((k, len(v)) for k, v in self.value.items() if not isinstance(v, XComArg))
+
+        try:
+            dag_id = next(v.operator.dag_id for v in self.value.values() if isinstance(v, XComArg))
+        except StopIteration:  # All mapped arguments are literal. We're done.
+            return map_lengths
+
+        # Build a reverse mapping of what arguments each task contributes to.
+        mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
+        non_mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
+        for k, v in self.value.items():
+            if not isinstance(v, XComArg):
+                continue
+            assert v.operator.dag_id == dag_id
+            if v.operator.is_mapped:
+                mapped_dep_keys[v.operator.task_id].add(k)
+            else:
+                non_mapped_dep_keys[v.operator.task_id].add(k)
+            # TODO: It's not possible now, but in the future we may support
+            # depending on one single mapped task instance. When that happens,
+            # we need to further analyze the mapped case to contain only tasks
+            # we depend on "as a whole", and put those we only depend on
+            # individually to the non-mapped lookup.
+
+        # Collect lengths from unmapped upstreams.
+        taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
+            TaskMap.dag_id == dag_id,
+            TaskMap.run_id == run_id,
+            TaskMap.task_id.in_(non_mapped_dep_keys),
+            TaskMap.map_index < 0,
+        )
+        for task_id, length in taskmap_query:
+            for mapped_arg_name in non_mapped_dep_keys[task_id]:
+                map_lengths[mapped_arg_name] += length
+
+        # Collect lengths from mapped upstreams.
+        xcom_query = (
+            session.query(XCom.task_id, func.count(XCom.map_index))
+            .group_by(XCom.task_id)
+            .filter(
+                XCom.dag_id == dag_id,
+                XCom.run_id == run_id,
+                XCom.task_id.in_(mapped_dep_keys),
+                XCom.map_index >= 0,
+            )
+        )
+        for task_id, length in xcom_query:
+            for mapped_arg_name in mapped_dep_keys[task_id]:
+                map_lengths[mapped_arg_name] += length
+
+        if len(map_lengths) < len(self.value):
+            raise NotFullyPopulated(set(self.value).difference(map_lengths))
+        return map_lengths
+
+    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
+        if not self.value:
+            return 0
+        lengths = self._get_map_lengths(run_id, session=session)
+        return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
+
+    def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
+        from airflow.models.xcom_arg import XComArg
+
+        if isinstance(value, XComArg):
+            value = value.resolve(context, session=session)
+        map_index = context["ti"].map_index
+        if map_index < 0:
+            raise RuntimeError("can't resolve task-mapping argument without 
expanding")
+        all_lengths = self._get_map_lengths(context["run_id"], session=session)
+
+        def _find_index_for_this_field(index: int) -> int:
+            # Need to use the original user input to retain argument order.
+            for mapped_key in reversed(list(self.value)):
+                mapped_length = all_lengths[mapped_key]
+                if mapped_length < 1:
+                    raise RuntimeError(f"cannot expand field mapped to length 
{mapped_length!r}")
+                if mapped_key == key:
+                    return index % mapped_length
+                index //= mapped_length
+            return -1
+
+        found_index = _find_index_for_this_field(map_index)
+        if found_index < 0:
+            return value
+        if isinstance(value, collections.abc.Sequence):
+            return value[found_index]
+        if not isinstance(value, dict):
+            raise TypeError(f"can't map over value of type {type(value)}")
+        for i, (k, v) in enumerate(value.items()):
+            if i == found_index:
+                return k, v
+        raise IndexError(f"index {map_index} is over mapped length")
+
+    def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+        return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
+
+
+ExpandInput = DictOfListsExpandInput
+
+EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value.
+
+_EXPAND_INPUT_TYPES = {
+    "dict-of-lists": DictOfListsExpandInput,
+}
+
+
+def get_map_type_key(expand_input: ExpandInput) -> str:
+    return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input))
+
+
+def create_expand_input(kind: str, value: Any) -> ExpandInput:
+    return _EXPAND_INPUT_TYPES[kind](value)
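
A worked example of the arithmetic in _find_index_for_this_field may help: the
last-declared mapped argument varies fastest, and each map_index decomposes
into per-argument indexes by repeated modulo and floor division. A standalone
sketch of the same logic (the helper name is hypothetical):

    def decompose(map_index: int, lengths: dict) -> dict:
        # Walk the arguments in reverse declaration order, mirroring
        # _find_index_for_this_field: take the remainder, then divide.
        indexes = {}
        for key, length in reversed(list(lengths.items())):
            indexes[key] = map_index % length
            map_index //= length
        return indexes

    # expand(a=["x", "y"], b=[10, 20, 30]) creates 2 * 3 = 6 task instances.
    # The instance with map_index 4 sees a="y" (index 1) and b=20 (index 1).
    assert decompose(4, {"a": 2, "b": 3}) == {"a": 1, "b": 1}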
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index e34b1501bc..9e75e9f4aa 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -19,12 +19,11 @@
 import collections
 import collections.abc
 import datetime
-import functools
-import operator
 import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     ClassVar,
     Collection,
     Dict,
@@ -61,6 +60,13 @@ from airflow.models.abstractoperator import (
     AbstractOperator,
     TaskStateChangeCallback,
 )
+from airflow.models.expandinput import (
+    MAPPABLE_LITERAL_TYPES,
+    DictOfListsExpandInput,
+    ExpandInput,
+    Mappable,
+    NotFullyPopulated,
+)
 from airflow.models.pool import Pool
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -80,19 +86,11 @@ if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
-    from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup
 
-    # BaseOperator.expand() can be called on an XComArg, sequence, or dict (not
-    # any mapping since we need the value to be ordered).
-    Mappable = Union[XComArg, Sequence, dict]
-
 ValidationSource = Union[Literal["expand"], Literal["partial"]]
 
 
-MAPPABLE_LITERAL_TYPES = (dict, list)
-
-
 # For isinstance() check.
 @cache
 def get_mappable_types() -> Tuple[type, ...]:
@@ -194,16 +192,17 @@ class OperatorPartial:
     def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
         if not mapped_kwargs:
             raise TypeError("no arguments to expand against")
-        return self._expand(**mapped_kwargs)
-
-    def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
-        self._expand_called = True
+        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
+        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
+        # Since the input is already checked at parse time, we can set strict
+        # to False to skip the checks on execution.
+        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
 
+    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
         from airflow.operators.empty import EmptyOperator
 
-        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
-        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
-        ensure_xcomarg_return_value(mapped_kwargs)
+        self._expand_called = True
+        ensure_xcomarg_return_value(expand_input.value)
 
         partial_kwargs = self.kwargs.copy()
         task_id = partial_kwargs.pop("task_id")
@@ -215,7 +214,7 @@ class OperatorPartial:
 
         op = MappedOperator(
             operator_class=self.operator_class,
-            mapped_kwargs=mapped_kwargs,
+            expand_input=expand_input,
             partial_kwargs=partial_kwargs,
             task_id=task_id,
             params=params,
@@ -233,9 +232,10 @@ class OperatorPartial:
             task_group=task_group,
             start_date=start_date,
             end_date=end_date,
-            # For classic operators, this points to mapped_kwargs because kwargs
+            disallow_kwargs_override=strict,
+            # For classic operators, this points to expand_input because kwargs
             # to BaseOperator.expand() contribute to operator arguments.
-            expansion_kwargs_attr="mapped_kwargs",
+            expand_input_attr="expand_input",
         )
         return op
 
@@ -261,7 +261,7 @@ class MappedOperator(AbstractOperator):
     # that can be used to unmap this into a SerializedBaseOperator.
     operator_class: Union[Type["BaseOperator"], Dict[str, Any]]
 
-    mapped_kwargs: Dict[str, "Mappable"]
+    expand_input: ExpandInput
     partial_kwargs: Dict[str, Any]
 
     # Needed for serialization.
@@ -285,7 +285,14 @@ class MappedOperator(AbstractOperator):
     upstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
     downstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
 
-    _expansion_kwargs_attr: str
+    _disallow_kwargs_override: bool
+    """Whether execution fails if ``expand_input`` has duplicates to 
``partial_kwargs``.
+
+    If *False*, values from ``expand_input`` under duplicate keys override 
those
+    under corresponding keys in ``partial_kwargs``.
+    """
+
+    _expand_input_attr: str
     """Where to get kwargs to calculate expansion length against.
 
     This should be a name to call ``getattr()`` on.
@@ -315,8 +322,7 @@ class MappedOperator(AbstractOperator):
             self.task_group.add(self)
         if self.dag:
             self.dag.add_task(self)
-        for k, v in self.mapped_kwargs.items():
-            XComArg.apply_upstream_relationship(self, v)
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
         for k, v in self.partial_kwargs.items():
             if k in self.template_fields:
                 XComArg.apply_upstream_relationship(self, v)
@@ -334,7 +340,7 @@ class MappedOperator(AbstractOperator):
             "dag",
             "deps",
             "is_mapped",
-            "mapped_kwargs",  # This is needed to be able to accept XComArg.
+            "expand_input",  # This is needed to be able to accept XComArg.
             "subdag",
             "task_group",
             "upstream_task_ids",
@@ -361,7 +367,9 @@ class MappedOperator(AbstractOperator):
         """
         if not isinstance(self.operator_class, type):
             return  # No need to validate deserialized operator.
-        self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs())
+        kwargs = self._expand_mapped_kwargs(None)
+        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+        self.operator_class.validate_mapped_arguments(**kwargs)
 
     @property
     def task_type(self) -> str:
@@ -520,7 +528,38 @@ class MappedOperator(AbstractOperator):
         """Implementing DAGNode."""
         return DagAttributeTypes.OP, self.task_id
 
-    def _get_unmap_kwargs(self) -> Dict[str, Any]:
+    def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Dict[str, Any]:
+        """Get the kwargs to create the unmapped operator.
+
+        If *resolve* is not *None*, it must be a two-tuple to provide context to
+        resolve XComArgs (a templating context, and a database session).
+
+        When resolving is not possible (e.g. to perform parse-time validation),
+        *resolve* can be set to *None*. This will cause the dict-of-lists
+        variant to simply return a dict of XComArgs corresponding to each kwarg
+        to pass to the unmapped operator. Since it is impossible to perform any
+        operation on the list-of-dicts variant before execution time, an empty
+        dict will be returned for this case.
+        """
+        kwargs = self._get_specified_expand_input()
+        if resolve is not None:
+            return kwargs.resolve(*resolve)
+        if isinstance(kwargs, DictOfListsExpandInput):
+            return kwargs.value
+        return {}
+
+    def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+        """Get init kwargs to unmap the underlying operator class.
+
+        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
+        """
+        if strict:
+            prevent_duplicates(
+                self.partial_kwargs,
+                mapped_kwargs,
+                fail_reason="unmappable or already specified",
+            )
+        # Ordering is significant; mapped kwargs should override partial ones.
         return {
             "task_id": self.task_id,
             "dag": self.dag,
@@ -529,29 +568,35 @@ class MappedOperator(AbstractOperator):
             "start_date": self.start_date,
             "end_date": self.end_date,
             **self.partial_kwargs,
-            **self.mapped_kwargs,
+            **mapped_kwargs,
         }
 
-    def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator":
-        """
-        Get the "normal" Operator after applying the current mapping.
+    def unmap(self, resolve: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
+        """Get the "normal" Operator after applying the current mapping.
 
-        If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a
-        SerializedBaseOperator that aims to "look like" the real operator.
+        If ``operator_class`` is not a class (i.e. this DAG has been
+        deserialized), this returns a SerializedBaseOperator that aims to
+        "look like" the actual unmapping result.
 
-        :param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when
-            ``operator_class`` is still an actual class.
+        :param resolve: Only used if ``operator_class`` is a real class. If this
+            is a two-tuple (context, session), the information is used to
+            resolve the mapped arguments into init arguments. If this is a
+            mapping, no resolving happens; the mapping directly provides those
+            init arguments resolved from mapped kwargs.
 
         :meta private:
         """
         if isinstance(self.operator_class, type):
-            # We can't simply specify task_id here because BaseOperator further
+            if isinstance(resolve, collections.abc.Mapping):
+                kwargs = resolve
+            else:
+                kwargs = self._expand_mapped_kwargs(resolve)
+            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
+            # We need to overwrite task_id here because BaseOperator further
             # mangles the task_id based on the task hierarchy (namely, group_id
-            # is prepended, and '__N' appended to deduplicate). Instead of
-            # recreating the whole logic here, we just overwrite task_id later.
-            if unmap_kwargs is None:
-                unmap_kwargs = self._get_unmap_kwargs()
-            op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
+            # is prepended, and '__N' appended to deduplicate). This is hacky,
+            # but better than duplicating the whole mangling logic.
             op.task_id = self.task_id
             return op
 
@@ -565,80 +610,23 @@ class MappedOperator(AbstractOperator):
         SerializedBaseOperator.populate_operator(op, self.operator_class)
         return op
 
-    def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
-        """The kwargs to calculate expansion length against."""
-        return getattr(self, self._expansion_kwargs_attr)
-
-    def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
-        """Return dict of argument name to map length.
-
-        If any arguments are not known right now (upstream task not finished) they will not be present in the
-        dict.
-        """
-        # TODO: Find a way to cache this.
-        from airflow.models.taskmap import TaskMap
-        from airflow.models.xcom import XCom
-        from airflow.models.xcom_arg import XComArg
+    def _get_specified_expand_input(self) -> ExpandInput:
+        """Input received from the expand call on the operator."""
+        return getattr(self, self._expand_input_attr)
 
-        expansion_kwargs = self._get_expansion_kwargs()
-
-        # Populate literal mapped arguments first.
-        map_lengths: Dict[str, int] = collections.defaultdict(int)
-        map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))
-
-        # Build a reverse mapping of what arguments each task contributes to.
-        mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
-        non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
-        for k, v in expansion_kwargs.items():
-            if not isinstance(v, XComArg):
-                continue
-            if v.operator.is_mapped:
-                mapped_dep_keys[v.operator.task_id].add(k)
-            else:
-                non_mapped_dep_keys[v.operator.task_id].add(k)
-            # TODO: It's not possible now, but in the future (AIP-42 Phase 2)
-            # we will add support for depending on one single mapped task
-            # instance. When that happens, we need to further analyze the mapped
-            # case to contain only tasks we depend on "as a whole", and put
-            # those we only depend on individually to the non-mapped lookup.
-
-        # Collect lengths from unmapped upstreams.
-        taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
-            TaskMap.dag_id == self.dag_id,
-            TaskMap.run_id == run_id,
-            TaskMap.task_id.in_(non_mapped_dep_keys),
-            TaskMap.map_index < 0,
-        )
-        for task_id, length in taskmap_query:
-            for mapped_arg_name in non_mapped_dep_keys[task_id]:
-                map_lengths[mapped_arg_name] += length
-
-        # Collect lengths from mapped upstreams.
-        xcom_query = (
-            session.query(XCom.task_id, func.count(XCom.map_index))
-            .group_by(XCom.task_id)
-            .filter(
-                XCom.dag_id == self.dag_id,
-                XCom.run_id == run_id,
-                XCom.task_id.in_(mapped_dep_keys),
-                XCom.map_index >= 0,
-            )
-        )
-        for task_id, length in xcom_query:
-            for mapped_arg_name in mapped_dep_keys[task_id]:
-                map_lengths[mapped_arg_name] += length
-        return map_lengths
+    @property
+    def validate_upstream_return_value(self) -> Callable[[Any], None]:
+        """Validate an upstream's return value satisfies this task's needs.
 
-    @cache
-    def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
-        """Return dict of argument name to map length, or throw if some are not resolvable"""
-        expansion_kwargs = self._get_expansion_kwargs()
-        map_lengths = self._get_map_lengths(run_id, session=session)
-        if len(map_lengths) < len(expansion_kwargs):
-            keys = ", ".join(repr(k) for k in 
sorted(set(expansion_kwargs).difference(map_lengths)))
-            raise RuntimeError(f"Failed to populate all mapping metadata; 
missing: {keys}")
+        This is implemented as a property (instead of a function calling
+        ``validate_xcom``) so the call site in TaskInstance can de-duplicate
+        validation functions. If this were an instance method, each
+        ``validate_upstream_return_value`` would be a different object (due to
+        how Python handles bound methods), and de-duplication won't work.
 
-        return map_lengths
+        :meta private:
+        """
+        return self._get_specified_expand_input().validate_xcom
 
     def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
         """Create the mapped task instances for mapped task.
@@ -649,9 +637,7 @@ class MappedOperator(AbstractOperator):
         from airflow.models.taskinstance import TaskInstance
         from airflow.settings import task_instance_mutation_hook
 
-        total_length = functools.reduce(
-            operator.mul, self._resolve_map_lengths(run_id, session=session).values()
-        )
+        total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
 
         state: Optional[TaskInstanceState] = None
         unmapped_ti: Optional[TaskInstance] = (
@@ -728,143 +714,63 @@ class MappedOperator(AbstractOperator):
         # we don't need to create a copy of the MappedOperator here.
         return self
 
-    def render_template_fields(
-        self,
-        context: Context,
-        jinja_env: Optional["jinja2.Environment"] = None,
-    ) -> Optional["BaseOperator"]:
-        """Template all attributes listed in template_fields.
-
-        Different from the BaseOperator implementation, this renders the
-        template fields on the *unmapped* BaseOperator.
-
-        :param context: Dict with values to apply on content
-        :param jinja_env: Jinja environment
-        :return: The unmapped, populated BaseOperator
-        """
-        if not jinja_env:
-            jinja_env = self.get_template_env()
-        # Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor
-        # could be called with an XComArg, rather than the value it resolves to.
-        #
-        # We also need to resolve _all_ mapped arguments, even if they aren't marked as templated
-        kwargs = self._get_unmap_kwargs()
-
-        template_fields = set(self.template_fields)
-
-        # Ideally we'd like to pass in session as an argument to this function, but since operators _could_
-        # override this we can't easily change this function signature.
-        # We can't use @provide_session, as that closes and expunges everything, which we don't want to do
-        # when we are so "deep" in the weeds here.
-        #
-        # Nor do we want to close the session -- that would expunge all the things from the internal cache
-        # which we don't want to do either
-        session = settings.Session()
-        self._resolve_expansion_kwargs(kwargs, template_fields, context, session)
-
-        unmapped_task = self.unmap(unmap_kwargs=kwargs)
-        self._do_render_template_fields(
-            parent=unmapped_task,
-            template_fields=template_fields,
-            context=context,
-            jinja_env=jinja_env,
-            seen_oids=set(),
-            session=session,
-        )
-        return unmapped_task
-
-    def _resolve_expansion_kwargs(
-        self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session
-    ) -> None:
-        """Update mapped fields in place in kwargs dict"""
-        from airflow.models.xcom_arg import XComArg
-
-        expansion_kwargs = self._get_expansion_kwargs()
-
-        for k, v in expansion_kwargs.items():
-            if isinstance(v, XComArg):
-                v = v.resolve(context, session=session)
-            v = self._expand_mapped_field(k, v, context, session=session)
-            template_fields.discard(k)
-            kwargs[k] = v
-
-    def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
-        map_index = context["ti"].map_index
-        if map_index < 0:
-            return value
-        expansion_kwargs = self._get_expansion_kwargs()
-        all_lengths = self._resolve_map_lengths(context["run_id"], session=session)
-
-        def _find_index_for_this_field(index: int) -> int:
-            # Need to use self.mapped_kwargs for the original argument order.
-            for mapped_key in reversed(list(expansion_kwargs)):
-                mapped_length = all_lengths[mapped_key]
-                if mapped_length < 1:
-                    raise RuntimeError(f"cannot expand field mapped to length 
{mapped_length!r}")
-                if mapped_key == key:
-                    return index % mapped_length
-                index //= mapped_length
-            return -1
-
-        found_index = _find_index_for_this_field(map_index)
-        if found_index < 0:
-            return value
-        if isinstance(value, collections.abc.Sequence):
-            return value[found_index]
-        if not isinstance(value, dict):
-            raise TypeError(f"can't map over value of type {type(value)}")
-        for i, (k, v) in enumerate(value.items()):
-            if i == found_index:
-                return k, v
-        raise IndexError(f"index {map_index} is over mapped length")
-
     def iter_mapped_dependencies(self) -> Iterator["Operator"]:
         """Upstream dependencies that provide XComs used by this task for task 
mapping."""
         from airflow.models.xcom_arg import XComArg
 
-        for ref in XComArg.iter_xcom_args(self._get_expansion_kwargs()):
+        for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()):
             yield ref.operator
 
     @cached_property
     def parse_time_mapped_ti_count(self) -> Optional[int]:
-        """
-        Number of mapped TaskInstances that can be created at DagRun create time.
+        """Number of mapped TaskInstances that can be created at DagRun create time.
 
-        :return: None if non-literal mapped arg encountered, or else total number of mapped TIs this task
-            should have
+        :return: None if non-literal mapped arg encountered, or the total
+            number of mapped TIs this task should have.
         """
-        total = 0
-
-        for value in self._get_expansion_kwargs().values():
-            if not isinstance(value, MAPPABLE_LITERAL_TYPES):
-                # None literal type encountered, so give up
-                return None
-            if total == 0:
-                total = len(value)
-            else:
-                total *= len(value)
-        return total
+        return self._get_specified_expand_input().get_parse_time_mapped_ti_count()
 
     @cache
     def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> Optional[int]:
-        """
-        Number of mapped TaskInstances that can be created at run time, or None if upstream tasks are not
-        complete yet.
+        """Number of mapped TaskInstances that can be created at run time.
 
-        :return: None if upstream tasks are not complete yet, or the total number of mapped TIs this task
-            should have
+        :return: None if upstream tasks are not complete yet, or the total
+            number of mapped TIs this task should have.
         """
-        lengths = self._get_map_lengths(run_id, session=session)
-        expansion_kwargs = self._get_expansion_kwargs()
-
-        if not lengths or not expansion_kwargs:
+        try:
+            return self._get_specified_expand_input().get_total_map_length(run_id, session=session)
+        except NotFullyPopulated:
             return None
 
-        total = 1
-        for name in expansion_kwargs:
-            val = lengths.get(name)
-            if val is None:
-                return None
-            total *= val
+    def _get_template_fields_to_render(self, expanded: Iterable[str]) -> Iterable[str]:
+        # Since the mapped kwargs are already resolved during unmapping,
+        # they must be removed from the list of templated fields to avoid
+        # being rendered again (which breaks escaping).
+        return set(self.template_fields).difference(expanded)
+
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+    ) -> Optional["BaseOperator"]:
+        if not jinja_env:
+            jinja_env = self.get_template_env()
+
+        # Ideally we'd like to pass in session as an argument to this function,
+        # but we can't easily change this function signature since operators
+        # could override this. We can't use @provide_session since it closes and
+        # expunges everything, which we don't want to do when we are so "deep"
+        # in the weeds here. We don't close this session for the same reason.
+        session = settings.Session()
 
-        return total
+        mapped_kwargs = self._expand_mapped_kwargs((context, session))
+        unmapped_task = self.unmap(mapped_kwargs)
+        self._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=self._get_template_fields_to_render(mapped_kwargs),
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+            session=session,
+        )
+        return unmapped_task
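
The validate_upstream_return_value property above relies on validate_xcom
being a staticmethod: every DictOfListsExpandInput hands back the same
function object, so a set can de-duplicate validators across many mapped
dependants. A small illustration of the underlying Python behavior (the class
here is hypothetical):

    class Input:
        @staticmethod
        def validate_xcom(value): ...

        def bound(self): ...

    a, b = Input(), Input()
    assert a.validate_xcom is b.validate_xcom    # one shared function object
    assert len({a.validate_xcom, b.validate_xcom}) == 1
    assert len({a.bound, b.bound}) == 2          # bound methods differ per instance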
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c198c6f09d..470c6360f0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -91,7 +91,6 @@ from airflow.exceptions import (
     TaskDeferralError,
     TaskDeferred,
     UnmappableXComLengthPushed,
-    UnmappableXComTypePushed,
     XComForMappingNotPushed,
 )
 from airflow.models.base import Base, StringID
@@ -1854,7 +1853,14 @@ class TaskInstance(Base, LoggingMixin):
         return tb or error.__traceback__
 
     @provide_session
-    def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None) -> None:
+    def handle_failure(
+        self,
+        error: Any,
+        test_mode: Optional[bool] = None,
+        context: Optional[Context] = None,
+        force_fail: bool = False,
+        session: Session = NEW_SESSION,
+    ) -> None:
         """Handle Failure for the TaskInstance"""
         if test_mode is None:
             test_mode = self.test_mode
@@ -1898,11 +1904,11 @@ class TaskInstance(Base, LoggingMixin):
         # only mark task instance as FAILED if the next task instance
         # try_number exceeds the max_tries ... or if force_fail is truthy
 
-        task = None
+        task: Optional[BaseOperator] = None
         try:
-            task = self.task.unmap()
+            task = self.task.unmap((context, session))
         except Exception:
-            self.log.error("Unable to unmap task, can't determine if we need 
to send an alert email or not")
+            self.log.error("Unable to unmap task to determine if we need to 
send an alert email")
 
         if force_fail or not self.is_eligible_to_retry():
             self.state = State.FAILED
@@ -2135,7 +2141,7 @@ class TaskInstance(Base, LoggingMixin):
 
         rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
         if rendered_task_instance_fields:
-            self.task = self.task.unmap()
+            self.task = self.task.unmap(None)
             for field_name, rendered_value in rendered_task_instance_fields.items():
                 setattr(self.task, field_name, rendered_value)
             return
@@ -2311,18 +2317,20 @@ class TaskInstance(Base, LoggingMixin):
         self.log.debug("Task Duration set to %s", self.duration)
 
     def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
-        # TODO: We don't push TaskMap for mapped task instances because it's not
-        # currently possible for a downstream to depend on one individual mapped
-        # task instance, only a task as a whole. This will change in AIP-42
-        # Phase 2, and we'll need to further analyze the mapped task case.
-        if next(task.iter_mapped_dependants(), None) is None:
+        validators = {m.validate_upstream_return_value for m in task.iter_mapped_dependants()}
+        if not validators:  # No mapped dependants, no need to validate.
             return
         if value is None:
             raise XComForMappingNotPushed()
+        # TODO: We don't push TaskMap for mapped task instances because it's not
+        # currently possible for a downstream to depend on one individual mapped
+        # task instance. This will change when we implement task group mapping,
+        # and we'll need to further analyze the mapped task case.
         if task.is_mapped:
             return
-        if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
-            raise UnmappableXComTypePushed(value)
+        for validator in validators:
+            validator(value)
+        assert isinstance(value, collections.abc.Collection)  # The validators type-guard this.
         task_map = TaskMap.from_task_instance_xcom(self, value)
         max_map_length = conf.getint("core", "max_map_length", fallback=1024)
         if task_map.length > max_map_length:
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 1550387eed..f9df99eb58 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -251,13 +251,13 @@
         "doc_yaml":  { "type": "string" },
         "doc_rst":  { "type": "string" },
         "_is_mapped": { "const": true, "$comment": "only present when True" },
-        "mapped_kwargs": { "type": "object" },
+        "expand_input": { "type": "object" },
         "partial_kwargs": { "type": "object" }
       },
       "dependencies": {
-        "mapped_kwargs": ["partial_kwargs", "_is_mapped"],
-        "partial_kwargs": ["mapped_kwargs", "_is_mapped"],
-        "_is_mapped": ["mapped_kwargs", "partial_kwargs"]
+        "expand_input": ["partial_kwargs", "_is_mapped"],
+        "partial_kwargs": ["expand_input", "_is_mapped"],
+        "_is_mapped": ["expand_input", "partial_kwargs"]
       },
       "additionalProperties": true
     },
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index f4a4257f57..201d27b1c8 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -37,6 +37,7 @@ from airflow.models import Dataset
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, create_timetable
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_expand_input, get_map_type_key
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 from airflow.models.param import Param, ParamsDict
@@ -99,8 +100,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]:
     don't need to store them.
     """
     # Use the private _expand() method to avoid the empty kwargs check.
-    default_partial_kwargs = BaseOperator.partial(task_id="_")._expand().partial_kwargs
-    return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR]
+    default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
+    return BaseSerialization._serialize(default)[Encoding.VAR]
 
 
 def encode_relativedelta(var: relativedelta.relativedelta) -> Dict[str, Any]:
@@ -193,16 +194,37 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
 
 
 class _XComRef(NamedTuple):
-    """
-    Used to store info needed to create XComArg when deserializing MappedOperator.
+    """Used to store info needed to create XComArg.
 
-    We can't turn it in to a XComArg until we've loaded _all_ the tasks, so when deserializing an operator we
-    need to create _something_, and then post-process it in deserialize_dag
+    We can't turn it into an XComArg until we've loaded _all_ the tasks, so when
+    deserializing an operator, we need to create something in its place, and
+    post-process it in ``deserialize_dag``.
     """
 
     task_id: str
     key: str
 
+    def deref(self, dag: DAG) -> XComArg:
+        return XComArg(operator=dag.get_task(self.task_id), key=self.key)
+
+
+class _ExpandInputRef(NamedTuple):
+    """Used to store info needed to create a mapped operator's expand input.
+
+    This references an ``ExpandInput`` type, but replaces ``XComArg`` objects
+    with ``_XComRef`` (see documentation on the latter type for reasoning).
+    """
+
+    key: str
+    value: Union[_XComRef, Dict[str, Any]]
+
+    def deref(self, dag: DAG) -> ExpandInput:
+        if isinstance(self.value, _XComRef):
+            value: Any = self.value.deref(dag)
+        else:
+            value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
+        return create_expand_input(self.key, value)
+
 
 class BaseSerialization:
     """BaseSerialization provides utils for serialization."""
@@ -598,8 +620,12 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
         serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator))
 
-        # Handle mapped_kwargs and mapped_op_kwargs.
-        serialized_op[op._expansion_kwargs_attr] = cls._serialize(op._get_expansion_kwargs())
+        # Handle expand_input and op_kwargs_expand_input.
+        expansion_kwargs = op._get_specified_expand_input()
+        serialized_op[op._expand_input_attr] = {
+            "type": get_map_type_key(expansion_kwargs),
+            "value": cls._serialize(expansion_kwargs.value),
+        }
 
         # Simplify partial_kwargs by comparing it to the most barebone object.
         # Remove all entries that are simply default values.
@@ -749,6 +775,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                 v = cls._deserialize_params_dict(v)
             elif k == "partial_kwargs":
                 v = {arg: cls._deserialize(value) for arg, value in v.items()}
+            elif k in {"expand_input", "op_kwargs_expand_input"}:
+                v = _ExpandInputRef(v["type"], cls._deserialize(v["value"]))
             elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                 v = cls._deserialize(v)
             elif k in ("_outlets", "_inlets"):
@@ -781,7 +809,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
             op = MappedOperator(
                 operator_class=op_data,
-                mapped_kwargs={},
+                expand_input=EXPAND_INPUT_EMPTY,
                 partial_kwargs={},
                 task_id=encoded_op["task_id"],
                 params={},
@@ -799,7 +827,8 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
                 task_group=None,
                 start_date=None,
                 end_date=None,
-                expansion_kwargs_attr=encoded_op["_expansion_kwargs_attr"],
+                disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
+                expand_input_attr=encoded_op["_expand_input_attr"],
             )
         else:
             op = SerializedBaseOperator(task_id=encoded_op['task_id'])
@@ -1078,13 +1107,11 @@ class SerializedDAG(DAG, BaseSerialization):
             if task.subdag is not None:
                 setattr(task.subdag, 'parent_dag', dag)
 
-            if isinstance(task, MappedOperator):
-                expansion_kwargs = task._get_expansion_kwargs()
-                for k, v in expansion_kwargs.items():
-                    if not isinstance(v, _XComRef):
-                        continue
-
-                    expansion_kwargs[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)
+            # Dereference expand_input and op_kwargs_expand_input.
+            for k in ("expand_input", "op_kwargs_expand_input"):
+                kwargs_ref = getattr(task, k, None)
+                if isinstance(kwargs_ref, _ExpandInputRef):
+                    setattr(task, k, kwargs_ref.deref(dag))
 
             for task_id in task.downstream_task_ids:
                 # Bypass set_upstream etc here - it does more than we want
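
Putting serialize_mapped_operator() and _ExpandInputRef together, the round
trip for a mapped operator's input looks roughly like this (payload values are
illustrative; XComArg values are really encoded by BaseSerialization and come
back as _XComRef placeholders):

    # Written by serialize_mapped_operator(), keyed by op._expand_input_attr:
    encoded = {"type": "dict-of-lists", "value": {"number": [1, 2, 3]}}

    # Deserialization stashes an _ExpandInputRef; deserialize_dag later calls
    # .deref(dag) to rebuild the real ExpandInput once all tasks exist.
    ref = _ExpandInputRef(encoded["type"], encoded["value"])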
diff --git a/airflow/www/views.py b/airflow/www/views.py
index c630d47af8..906d73e943 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1372,7 +1372,7 @@ class Airflow(AirflowBaseView):
         # only matters if get_rendered_template_fields() raised an exception.
         # The following rendering won't show useful values in this case anyway,
         # but we'll display some quasi-meaningful field names.
-        task = ti.task.unmap()
+        task = ti.task.unmap(None)
 
         title = "Rendered Template"
         html_dict = {}
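
The new None argument is read here from the call site alone: it plausibly signals that no resolved expand input is available, so unmap should rebuild the operator from its partial kwargs only, which is all this error-rendering path needs. A hypothetical sketch under that assumption (not the actual signature in airflow/models/mappedoperator.py):

    def unmap(self, resolve):
        if resolve is None:
            # No runtime expansion information; build the unmapped operator
            # from partial_kwargs alone (enough for rendering/error paths).
            ...
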
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e6779724bf..9f7be831d9 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1514,6 +1514,7 @@ unittests
 unix
 unmappable
 unmapped
+unmapping
 unpause
 unpaused
 unpausing
diff --git a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
index c8c07c9e80..909693de5a 100755
--- a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
@@ -50,8 +50,9 @@ IGNORED = {
     "partial",
     "shallow_copy_attrs",
     # Only on MappedOperator.
-    "mapped_kwargs",
+    "expand_input",
     "partial_kwargs",
+    "validate_upstream_return_value",
 }
 
 
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py
index 7509a89032..4151a4b60f 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -22,6 +22,7 @@ import pytest
 
 from airflow import DAG
 from airflow.models import DagBag
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.security import permissions
@@ -70,7 +71,7 @@ class TestTaskEndpoint:
             EmptyOperator(task_id=self.task_id3)
             # Use the private _expand() method to avoid the empty kwargs check.
             # We don't care about how the operator runs here, only its presence.
-            EmptyOperator.partial(task_id=self.mapped_task_id)._expand()
+            EmptyOperator.partial(task_id=self.mapped_task_id)._expand(EXPAND_INPUT_EMPTY, strict=False)
 
         task1 >> task2
         dag_bag = DagBag(os.devnull, include_examples=False)
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 63514fcf25..58ae1c5f87 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -28,6 +28,7 @@ from airflow.decorators.base import DecoratedMappedOperator
 from airflow.exceptions import AirflowException
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
+from airflow.models.expandinput import DictOfListsExpandInput
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCOM_RETURN_KEY
@@ -613,7 +614,7 @@ def test_mapped_decorator():
     assert isinstance(t2, XComArg)
     assert isinstance(t2.operator, DecoratedMappedOperator)
     assert t2.operator.task_id == "print_everything"
-    assert t2.operator.mapped_op_kwargs == {"any_key": [1, 2], "works": t1}
+    assert t2.operator.op_kwargs_expand_input == DictOfListsExpandInput({"any_key": [1, 2], "works": t1})
 
     assert t0.operator.task_id == "print_info"
     assert t1.operator.task_id == "print_info__1"
@@ -656,7 +657,7 @@ def test_partial_mapped_decorator() -> None:
 
     assert isinstance(doubled, XComArg)
     assert isinstance(doubled.operator, DecoratedMappedOperator)
-    assert doubled.operator.mapped_op_kwargs == {"number": literal}
+    assert doubled.operator.op_kwargs_expand_input == DictOfListsExpandInput({"number": literal})
     assert doubled.operator.partial_kwargs["op_kwargs"] == {"multiple": 2}
 
     assert isinstance(trippled.operator, DecoratedMappedOperator)  # For type-checking on partial_kwargs.
@@ -678,7 +679,7 @@ def test_mapped_decorator_unmap_merge_op_kwargs():
 
         task2.partial(arg1=1).expand(arg2=task1())
 
-    unmapped = dag.get_task("task2").unmap()
+    unmapped = dag.get_task("task2").unmap(None)
     assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
 
 
@@ -697,11 +698,11 @@ def test_mapped_decorator_converts_partial_kwargs():
 
     mapped_task2 = dag.get_task("task2")
     assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30)
-    assert mapped_task2.unmap().retry_delay == timedelta(seconds=30)
+    assert mapped_task2.unmap(None).retry_delay == timedelta(seconds=30)
 
     mapped_task1 = dag.get_task("task1")
     assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30)  # Operator default.
-    mapped_task1.unmap().retry_delay == timedelta(seconds=300)  # Operator default.
+    assert mapped_task1.unmap(None).retry_delay == timedelta(seconds=300)  # Operator default.
 
 
 def test_mapped_render_template_fields(dag_maker, session):
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 397f519a50..17c339620b 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1165,7 +1165,7 @@ def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
     ]
 
 
-def test_mapped_literal_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
+def test_mapped_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
     """Test that when the length of mapped literal increases at runtime, 
additional ti is added"""
     from airflow.models import Variable
 
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index c720fd96d9..92d6226097 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -102,9 +102,7 @@ def test_map_xcom_arg():
 def test_partial_on_instance() -> None:
     """`.partial` on an instance should fail -- it's only designed to be 
called on classes"""
     with pytest.raises(TypeError):
-        MockOperator(
-            task_id='a',
-        ).partial()
+        MockOperator(task_id='a').partial()
 
 
 def test_partial_on_class() -> None:
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index c0868b0b7f..e17f34bd78 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -58,6 +58,7 @@ from airflow.models import (
     XCom,
 )
 from airflow.models.dataset import Dataset, DatasetDagRunQueue, DatasetEvent, DatasetTaskRef
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskinstance import TaskInstance
@@ -1058,7 +1059,7 @@ class TestTaskInstance:
         with dag_maker(dag_id="test_xcom", session=session):
             # Use the private _expand() method to avoid the empty kwargs check.
             # We don't care about how the operator runs here, only its presence.
-            task_1 = EmptyOperator.partial(task_id="task_1")._expand()
+            task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
             EmptyOperator(task_id="task_2")
 
         dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
@@ -2477,9 +2478,9 @@ class TestTaskInstanceRecordTaskMapXComPush:
             (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
         ],
     )
-    def test_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message):
-        """If an unmappable return value is used to map, fail the task that pushed the XCom."""
-        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+    def test_expand_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message):
+        """If an unmappable return value is used for expand(), fail the task that pushed the XCom."""
+        with dag_maker(dag_id="test_expand_error_if_unmappable_type") as dag:
 
             @dag.task()
             def push_something():
@@ -2802,7 +2803,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v
     with dag_maker(dag_id="test_xcom", session=session):
         # Use the private _expand() method to avoid the empty kwargs check.
         # We don't care about how the operator runs here, only its presence.
-        task_1 = EmptyOperator.partial(task_id="task_1")._expand()
+        task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
         EmptyOperator(task_id="task_2")
 
     dagrun = dag_maker.create_dagrun()
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index af5a701366..418cb4af89 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1674,7 +1674,7 @@ def test_kubernetes_optional():
         module.SerializedDAG.to_dict(make_simple_dag()["simple_dag"])
 
 
-def test_mapped_operator_serde():
+def test_operator_expand_serde():
     literal = [1, 2, {'a': 'b'}]
     real_op = BashOperator.partial(task_id='a', executor_config={'dict': {'sub': 'value'}}).expand(
         bash_command=literal
@@ -1688,9 +1688,12 @@ def test_mapped_operator_serde():
         '_task_module': 'airflow.operators.bash',
         '_task_type': 'BashOperator',
         'downstream_task_ids': [],
-        'mapped_kwargs': {
-            "__type": "dict",
-            "__var": {'bash_command': [1, 2, {"__type": "dict", "__var": {'a': 
'b'}}]},
+        'expand_input': {
+            "type": "dict-of-lists",
+            "value": {
+                "__type": "dict",
+                "__var": {'bash_command': [1, 2, {"__type": "dict", "__var": 
{'a': 'b'}}]},
+            },
         },
         'partial_kwargs': {
             'executor_config': {
@@ -1705,7 +1708,8 @@ def test_mapped_operator_serde():
         'template_fields_renderers': {'bash_command': 'bash', 'env': 'json'},
         'ui_color': '#f0ede4',
         'ui_fgcolor': '#000',
-        '_expansion_kwargs_attr': 'mapped_kwargs',
+        "_disallow_kwargs_override": False,
+        '_expand_input_attr': 'expand_input',
     }
 
     op = SerializedBaseOperator.deserialize_operator(serialized)
@@ -1722,11 +1726,11 @@ def test_mapped_operator_serde():
         'ui_color': '#f0ede4',
         'ui_fgcolor': '#000',
     }
-    assert op.mapped_kwargs['bash_command'] == literal
+    assert op.expand_input.value['bash_command'] == literal
     assert op.partial_kwargs['executor_config'] == {'dict': {'sub': 'value'}}
 
 
-def test_mapped_operator_xcomarg_serde():
+def test_operator_expand_xcomarg_serde():
     from airflow.models.xcom_arg import XComArg
 
     with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
@@ -1740,9 +1744,12 @@ def test_mapped_operator_xcomarg_serde():
         '_task_module': 'tests.test_utils.mock_operators',
         '_task_type': 'MockOperator',
         'downstream_task_ids': [],
-        'mapped_kwargs': {
-            "__type": "dict",
-            "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id': 
'op1', 'key': 'return_value'}}},
+        'expand_input': {
+            "type": "dict-of-lists",
+            "value": {
+                "__type": "dict",
+                "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id': 
'op1', 'key': 'return_value'}}},
+            },
         },
         'partial_kwargs': {},
         'task_id': 'task_2',
@@ -1752,31 +1759,32 @@ def test_mapped_operator_xcomarg_serde():
         'operator_extra_links': [],
         'ui_color': '#fff',
         'ui_fgcolor': '#000',
-        '_expansion_kwargs_attr': 'mapped_kwargs',
+        "_disallow_kwargs_override": False,
+        '_expand_input_attr': 'expand_input',
     }
 
     op = SerializedBaseOperator.deserialize_operator(serialized)
     assert op.deps is MappedOperator.deps_for(BaseOperator)
 
-    arg = op.mapped_kwargs['arg2']
+    arg = op.expand_input.value['arg2']
     assert arg.task_id == 'op1'
     assert arg.key == XCOM_RETURN_KEY
 
     serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
 
-    xcom_arg = serialized_dag.task_dict['task_2'].mapped_kwargs['arg2']
+    xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value['arg2']
     assert isinstance(xcom_arg, XComArg)
     assert xcom_arg.operator is serialized_dag.task_dict['op1']
 
 
-def test_mapped_operator_deserialized_unmap():
+def test_operator_expand_deserialized_unmap():
     """Unmap a deserialized mapped operator should be similar to deserializing 
an non-mapped operator."""
     normal = BashOperator(task_id='a', bash_command=[1, 2], executor_config={"a": "b"})
     mapped = BashOperator.partial(task_id='a', executor_config={"a": "b"}).expand(bash_command=[1, 2])
 
     serialize = SerializedBaseOperator._serialize
     deserialize = SerializedBaseOperator.deserialize_operator
-    assert deserialize(serialize(mapped)).unmap() == deserialize(serialize(normal))
+    assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
 
 
 def test_task_resources_serde():
@@ -1799,10 +1807,10 @@ def test_task_resources_serde():
     }
 
 
-def test_mapped_decorator_serde():
+def test_taskflow_expand_serde():
     from airflow.decorators import task
     from airflow.models.xcom_arg import XComArg
-    from airflow.serialization.serialized_objects import _XComRef
+    from airflow.serialization.serialized_objects import _ExpandInputRef, _XComRef
 
     with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
         op1 = BaseOperator(task_id="op1")
@@ -1830,11 +1838,14 @@ def test_mapped_decorator_serde():
             },
             'retry_delay': {'__type': 'timedelta', '__var': 30.0},
         },
-        'mapped_op_kwargs': {
-            "__type": "dict",
-            "__var": {
-                'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
-                'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+        'op_kwargs_expand_input': {
+            "type": "dict-of-lists",
+            "value": {
+                "__type": "dict",
+                "__var": {
+                    'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
+                    'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+                },
             },
         },
         'operator_extra_links': [],
@@ -1844,7 +1855,8 @@ def test_mapped_decorator_serde():
         'template_ext': [],
         'template_fields': ['op_args', 'op_kwargs'],
         'template_fields_renderers': {"op_args": "py", "op_kwargs": "py"},
-        '_expansion_kwargs_attr': 'mapped_op_kwargs',
+        "_disallow_kwargs_override": False,
+        '_expand_input_attr': 'op_kwargs_expand_input',
     }
 
     deserialized = SerializedBaseOperator.deserialize_operator(serialized)
@@ -1853,10 +1865,10 @@ def test_mapped_decorator_serde():
     assert deserialized.upstream_task_ids == set()
     assert deserialized.downstream_task_ids == set()
 
-    assert deserialized.mapped_op_kwargs == {
-        "arg2": {"a": 1, "b": 2},
-        "arg3": _XComRef("op1", XCOM_RETURN_KEY),
-    }
+    assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
+        key="dict-of-lists",
+        value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", 
XCOM_RETURN_KEY)},
+    )
     assert deserialized.partial_kwargs == {
         "op_args": [],
         "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
@@ -1868,10 +1880,10 @@ def test_mapped_decorator_serde():
     # here so we don't need to duplicate tests between pickled and non-pickled
     # DAGs everywhere else.
     pickled = pickle.loads(pickle.dumps(deserialized))
-    assert pickled.mapped_op_kwargs == {
-        "arg2": {"a": 1, "b": 2},
-        "arg3": _XComRef("op1", XCOM_RETURN_KEY),
-    }
+    assert pickled.op_kwargs_expand_input == _ExpandInputRef(
+        key="dict-of-lists",
+        value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", 
XCOM_RETURN_KEY)},
+    )
     assert pickled.partial_kwargs == {
         "op_args": [],
         "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
