This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d1270c32b Resolve XComArgs before trying to unmap MappedOperators 
(#22975)
5d1270c32b is described below

commit 5d1270c32b2739bcd91ed6b3e47fe8b5f8f75f13
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Apr 14 04:18:03 2022 +0100

    Resolve XComArgs before trying to unmap MappedOperators (#22975)
    
    Many operators do some type validation inside `__init__`
    (DateTimeSensor for instance -- which requires a str or a datetime)
    which then fail when mapped as they get an XComArg instead.
    
    To fix this we have had to change the order we unmap and resolve
    templates:
    
    - First we get the unmapping kwargs and resolve the expansion/mapping
      args in them
    - Then we create the operator (this should fix the constructor getting
      XComArg problem)
    - Then we render templates, but only for values that _weren't_ expanded
      already
    
    Unmapping the task early in LocalTaskJob causes problems, and it's just
    not needed, as the task is (correctly) unmapped inside the
    TaskInstance._execute_task_with_callbacks call to
    `self.render_templates()`.
---
 airflow/decorators/base.py                         | 45 ++++++++++---
 airflow/jobs/local_task_job.py                     |  5 --
 airflow/models/abstractoperator.py                 | 22 +------
 airflow/models/mappedoperator.py                   | 75 +++++++++++++++-------
 airflow/models/xcom_arg.py                         | 10 ++-
 .../concepts/dynamic-task-mapping.rst              | 36 +++++++++++
 tests/decorators/test_python.py                    | 40 ++++++++++++
 tests/models/test_baseoperator.py                  | 43 +++++++++++++
 8 files changed, 219 insertions(+), 57 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 3d66d01c91..9072439b2f 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import collections.abc
 import functools
 import inspect
 import re
@@ -39,7 +38,6 @@ from typing import (
 
 import attr
 import typing_extensions
-from sqlalchemy.orm import Session
 
 from airflow.compat.functools import cache, cached_property
 from airflow.exceptions import AirflowException
@@ -68,6 +66,9 @@ from airflow.utils.task_group import TaskGroup, 
TaskGroupContext
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
+    import jinja2  # Slow import.
+    from sqlalchemy.orm import Session
+
     from airflow.models.mappedoperator import Mappable
 
 
@@ -430,6 +431,7 @@ class DecoratedMappedOperator(MappedOperator):
             self.mapped_op_kwargs,
             fail_reason="mapping already partial",
         )
+        self._combined_op_kwargs = op_kwargs
         return {
             "dag": self.dag,
             "task_group": self.task_group,
@@ -441,13 +443,38 @@ class DecoratedMappedOperator(MappedOperator):
             **self.mapped_kwargs,
         }
 
-    def _expand_mapped_field(self, key: str, content: Any, context: Context, 
*, session: Session) -> Any:
-        if key != "op_kwargs" or not isinstance(content, 
collections.abc.Mapping):
-            return content
-        # The magic super() doesn't work here, so we use the explicit form.
-        # Not using super(..., self) to work around pyupgrade bug.
-        sup: Any = super(DecoratedMappedOperator, DecoratedMappedOperator)
-        return {k: sup._expand_mapped_field(self, k, v, context, 
session=session) for k, v in content.items()}
+    def _resolve_expansion_kwargs(
+        self, kwargs: Dict[str, Any], template_fields: Set[str], context: 
Context, session: "Session"
+    ) -> None:
+        expansion_kwargs = self._get_expansion_kwargs()
+
+        self._already_resolved_op_kwargs = set()
+        for k, v in expansion_kwargs.items():
+            if isinstance(v, XComArg):
+                self._already_resolved_op_kwargs.add(k)
+                v = v.resolve(context, session=session)
+            v = self._expand_mapped_field(k, v, context, session=session)
+            kwargs['op_kwargs'][k] = v
+            template_fields.discard(k)
+
+    def render_template(
+        self,
+        value: Any,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+        seen_oids: Optional[Set] = None,
+    ) -> Any:
+        if hasattr(self, '_combined_op_kwargs') and value is 
self._combined_op_kwargs:
+            # Avoid rendering values that came out of resolved XComArgs
+            return {
+                k: v
+                if k in self._already_resolved_op_kwargs
+                else super(DecoratedMappedOperator, 
DecoratedMappedOperator).render_template(
+                    self, v, context, jinja_env=jinja_env, seen_oids=seen_oids
+                )
+                for k, v in value.items()
+            }
+        return super().render_template(value, context, jinja_env=jinja_env, 
seen_oids=seen_oids)
 
 
 class Task(Generic[Function]):
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index d7c6fa2153..9b2c3510eb 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -104,11 +104,6 @@ class LocalTaskJob(BaseJob):
         try:
             self.task_runner.start()
 
-            # Unmap the task _after_ it has forked/execed. (This is a bit of a 
kludge, but if we unmap before
-            # fork, then the "run_raw_task" command will see the mapping index 
and an Non-mapped task and
-            # fail)
-            self.task_instance.task = self.task_instance.task.unmap()
-
             heartbeat_time_limit = conf.getint('scheduler', 
'scheduler_zombie_task_threshold')
 
             # task callback invocation happens either here or in
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index c64b8554af..2d8d008dfc 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -35,8 +35,6 @@ from typing import (
     Union,
 )
 
-from sqlalchemy.orm import Session
-
 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -52,6 +50,7 @@ TaskStateChangeCallback = Callable[[Context], None]
 
 if TYPE_CHECKING:
     import jinja2  # Slow import.
+    from sqlalchemy.orm import Session
 
     from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
     from airflow.models.dag import DAG
@@ -330,7 +329,7 @@ class AbstractOperator(LoggingMixin, DAGNode):
         jinja_env: "jinja2.Environment",
         seen_oids: Set,
         *,
-        session: Session = NEW_SESSION,
+        session: "Session" = NEW_SESSION,
     ) -> None:
         for attr_name in template_fields:
             try:
@@ -342,29 +341,14 @@ class AbstractOperator(LoggingMixin, DAGNode):
                 )
             if not value:
                 continue
-            rendered_content = self._render_template_field(
-                attr_name,
+            rendered_content = self.render_template(
                 value,
                 context,
                 jinja_env,
                 seen_oids,
-                session=session,
             )
             setattr(parent, attr_name, rendered_content)
 
-    def _render_template_field(
-        self,
-        key: str,
-        value: Any,
-        context: Context,
-        jinja_env: Optional["jinja2.Environment"] = None,
-        seen_oids: Optional[Set] = None,
-        *,
-        session: Session,
-    ) -> Any:
-        """Override point for MappedOperator to perform further resolution."""
-        return self.render_template(value, context, jinja_env, seen_oids)
-
     def render_template(
         self,
         content: Any,
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 15ecee8798..ddc32906df 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -44,6 +44,7 @@ import pendulum
 from sqlalchemy import func, or_
 from sqlalchemy.orm.session import Session
 
+from airflow import settings
 from airflow.compat.functools import cache, cached_property
 from airflow.exceptions import AirflowException, UnmappableOperator
 from airflow.models.abstractoperator import (
@@ -473,6 +474,7 @@ class MappedOperator(AbstractOperator):
         return DagAttributeTypes.OP, self.task_id
 
     def _get_unmap_kwargs(self) -> Dict[str, Any]:
+
         return {
             "task_id": self.task_id,
             "dag": self.dag,
@@ -484,14 +486,26 @@ class MappedOperator(AbstractOperator):
             **self.mapped_kwargs,
         }
 
-    def unmap(self) -> "BaseOperator":
-        """Get the "normal" Operator after applying the current mapping."""
+    def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> 
"BaseOperator":
+        """
+        Get the "normal" Operator after applying the current mapping.
+
+        If ``operator_class`` is not a class (i.e. this DAG has been 
deserialized) then this will return a
+        SerializedBaseOperator that aims to "look like" the real operator.
+
+        :param unmap_kwargs: Override the args to pass to the Operator 
constructor. Only used when
+            ``operator_class`` is still an actual class.
+
+        :meta private:
+        """
         if isinstance(self.operator_class, type):
             # We can't simply specify task_id here because BaseOperator further
             # mangles the task_id based on the task hierarchy (namely, group_id
             # is prepended, and '__N' appended to deduplicate). Instead of
             # recreating the whole logic here, we just overwrite task_id later.
-            op = self.operator_class(**self._get_unmap_kwargs(), 
_airflow_from_mapped=True)
+            if unmap_kwargs is None:
+                unmap_kwargs = self._get_unmap_kwargs()
+            op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
             op.task_id = self.task_id
             return op
 
@@ -569,6 +583,7 @@ class MappedOperator(AbstractOperator):
                 map_lengths[mapped_arg_name] += length
         return map_lengths
 
+    @cache
     def _resolve_map_lengths(self, run_id: str, *, session: Session) -> 
Dict[str, int]:
         """Return dict of argument name to map length, or throw if some are 
not resolvable"""
         expansion_kwargs = self._get_expansion_kwargs()
@@ -686,33 +701,49 @@ class MappedOperator(AbstractOperator):
         """
         if not jinja_env:
             jinja_env = self.get_template_env()
-        unmapped_task = self.unmap()
+        # Before we unmap we have to resolve the mapped arguments, otherwise 
the real operator constructor
+        # could be called with an XComArg, rather than the value it resolves 
to.
+        #
+        # We also need to resolve _all_ mapped arguments, even if they aren't 
marked as templated
+        kwargs = self._get_unmap_kwargs()
+
+        template_fields = set(self.template_fields)
+
+        # Ideally we'd like to pass in session as an argument to this 
function, but since operators _could_
+        # override this we can't easily change this function signature.
+        # We can't use @provide_session, as that closes and expunges 
everything, which we don't want to do
+        # when we are so "deep" in the weeds here.
+        #
+        # Nor do we want to close the session -- that would expunge all the 
things from the internal cache
+        # which we don't want to do either
+        session = settings.Session()
+        self._resolve_expansion_kwargs(kwargs, template_fields, context, 
session)
+
+        unmapped_task = self.unmap(unmap_kwargs=kwargs)
         self._do_render_template_fields(
             parent=unmapped_task,
-            template_fields=unmapped_task.template_fields,
+            template_fields=template_fields,
             context=context,
             jinja_env=jinja_env,
             seen_oids=set(),
+            session=session,
         )
         return unmapped_task
 
-    def _render_template_field(
-        self,
-        key: str,
-        value: Any,
-        context: Context,
-        jinja_env: Optional["jinja2.Environment"] = None,
-        seen_oids: Optional[Set] = None,
-        *,
-        session: Session,
-    ) -> Any:
-        """Override the ordinary template rendering to add more logic.
-
-        Specifically, if we're rendering a mapped argument, we need to "unmap"
-        the value as well to assign it to the unmapped operator.
-        """
-        value = super()._render_template_field(key, value, context, jinja_env, 
seen_oids, session=session)
-        return self._expand_mapped_field(key, value, context, session=session)
+    def _resolve_expansion_kwargs(
+        self, kwargs: Dict[str, Any], template_fields: Set[str], context: 
Context, session: Session
+    ) -> None:
+        """Update mapped fields in place in kwargs dict"""
+        from airflow.models.xcom_arg import XComArg
+
+        expansion_kwargs = self._get_expansion_kwargs()
+
+        for k, v in expansion_kwargs.items():
+            if isinstance(v, XComArg):
+                v = v.resolve(context, session=session)
+            v = self._expand_mapped_field(k, v, context, session=session)
+            template_fields.discard(k)
+            kwargs[k] = v
 
     def _expand_mapped_field(self, key: str, value: Any, context: Context, *, 
session: Session) -> Any:
         map_index = context["ti"].map_index
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 7e4b274626..449fab8af5 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -22,9 +22,12 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.utils.context import Context
 from airflow.utils.edgemodifier import EdgeModifier
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
     from airflow.models.operator import Operator
 
 
@@ -136,12 +139,15 @@ class XComArg(DependencyMixin):
         """Proxy to underlying operator set_downstream method. Required by 
TaskMixin."""
         self.operator.set_downstream(task_or_task_list, edge_modifier)
 
-    def resolve(self, context: Context) -> Any:
+    @provide_session
+    def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> 
Any:
         """
         Pull XCom value for the existing arg. This method is run during 
``op.execute()``
         in respectable context.
         """
-        result = context["ti"].xcom_pull(task_ids=self.operator.task_id, 
key=str(self.key), default=NOTSET)
+        result = context["ti"].xcom_pull(
+            task_ids=self.operator.task_id, key=str(self.key), default=NOTSET, 
session=session
+        )
         if result is NOTSET:
             raise AirflowException(
                 f'XComArg result from {self.operator.task_id} at 
{context["ti"].dag_id} '
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst 
b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index b326a31672..d1152971dc 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -224,6 +224,42 @@ Currently it is only possible to map against a dict, a 
list, or one of those typ
 
 If an upstream task returns an unmappable type, the mapped task will fail at 
run-time with an ``UnmappableXComTypePushed`` exception. For instance, you 
can't have the upstream task return a plain string – it must be a list or a 
dict.
 
+How do templated fields and mapped arguments interact?
+======================================================
+
+All arguments to an operator can be mapped, even those that do not accept 
templated parameters.
+
+If a field is marked as being templated and is mapped, it **will not be 
templated**.
+
+For example, this will print ``{{ ds }}`` and not a date stamp:
+
+.. code-block:: python
+
+    @task
+    def make_list():
+        return ["{{ ds }}"]
+
+
+    @task
+    def printer(val):
+        print(val)
+
+
+    printer.expand(val=make_list())
+
+If you want to interpolate values either call ``task.render_template`` 
yourself, or use interpolation:
+
+.. code-block:: python
+
+    @task
+    def make_list(ds):
+        return [ds]
+
+
+    @task
+    def make_list(**context):
+        return [context["task"].render_template("{{ ds }}", context)]
+
 Placing limits on mapped tasks
 ==============================
 
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 8406de45e6..418907fd78 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -27,6 +27,10 @@ from airflow.decorators import task as task_decorator
 from airflow.decorators.base import DecoratedMappedOperator
 from airflow.exceptions import AirflowException
 from airflow.models import DAG
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskmap import TaskMap
+from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.models.xcom_arg import XComArg
 from airflow.utils import timezone
 from airflow.utils.state import State
@@ -641,3 +645,39 @@ def test_mapped_decorator_converts_partial_kwargs():
     mapped_task1 = dag.get_task("task1")
     assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30) 
 # Operator default.
     mapped_task1.unmap().retry_delay == timedelta(seconds=300)  # Operator 
default.
+
+
+def test_mapped_render_template_fields(dag_maker, session):
+    @task_decorator
+    def fn(arg1, arg2):
+        ...
+
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        xcom_arg = XComArg(task1)
+        mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=xcom_arg)
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=1,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, 
session=session)
+    mapped_ti.map_index = 0
+    op = 
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert op
+
+    assert op.op_kwargs['arg1'] == "{{ ds }}"
+    assert op.op_kwargs['arg2'] == "fn"
diff --git a/tests/models/test_baseoperator.py 
b/tests/models/test_baseoperator.py
index 64a774a0da..3c90463c60 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -33,6 +33,7 @@ from airflow.models.baseoperator import BaseOperator, 
BaseOperatorMeta, chain, c
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
+from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.models.xcom_arg import XComArg
 from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
@@ -944,3 +945,45 @@ def 
test_mapped_task_applies_default_args_taskflow(dag_maker):
 
     assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
     assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
+
+
+def test_mapped_render_template_fields_validating_operator(dag_maker, session):
+    class MyOperator(MockOperator):
+        def __init__(self, value, arg1, **kwargs):
+            assert isinstance(value, str), "value should have been resolved 
before unmapping"
+            assert isinstance(arg1, str), "value should have been resolved 
before unmapping"
+            super().__init__(arg1=arg1, **kwargs)
+            self.value = value
+
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        xcom_arg = XComArg(task1)
+        mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id 
}}').expand(
+            value=xcom_arg, arg1=xcom_arg
+        )
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=1,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, 
session=session)
+    mapped_ti.map_index = 0
+    op = 
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert isinstance(op, MyOperator)
+
+    assert op.value == "{{ ds }}", "Should not be templated!"
+    assert op.arg1 == "{{ ds }}"
+    assert op.arg2 == "a"

Reply via email to