This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5d1270c32b Resolve XComArgs before trying to unmap MappedOperators (#22975)
5d1270c32b is described below
commit 5d1270c32b2739bcd91ed6b3e47fe8b5f8f75f13
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Apr 14 04:18:03 2022 +0100
Resolve XComArgs before trying to unmap MappedOperators (#22975)
Many operators do some type validation inside `__init__`
(DateTimeSensor, for instance, requires a str or a datetime), and that
validation fails when the operator is mapped, because it receives an
XComArg instead of the real value.
To fix this we had to change the order in which we unmap and resolve
templates:
- First we build the unmapping kwargs, resolving the expansion/mapping
  args as part of that step
- Then we create the operator (this fixes the problem of the
  constructor receiving an XComArg)
- Then we render templates, but only for values that _weren't_ already
  expanded
Unmapping the task early in LocalTaskJob causes problems, and it isn't
needed there anyway, as the task is (correctly) unmapped inside
TaskInstance._execute_task_with_callbacks via its call to
`self.render_templates()`.
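In sketch form, the new order looks roughly like this (condensed from
the MappedOperator.render_template_fields changes in the diff below;
a simplified illustration, not the exact committed code):

    def render_template_fields(self, context, jinja_env=None):
        if not jinja_env:
            jinja_env = self.get_template_env()
        # 1. Build the kwargs used to unmap, and resolve any expansion/
        #    mapping args (XComArgs) in them before the real operator
        #    constructor ever sees them.
        kwargs = self._get_unmap_kwargs()
        template_fields = set(self.template_fields)
        session = settings.Session()
        self._resolve_expansion_kwargs(kwargs, template_fields, context, session)
        # 2. Create the real operator from the already-resolved kwargs.
        unmapped_task = self.unmap(unmap_kwargs=kwargs)
        # 3. Render templates only for fields that weren't expanded above.
        self._do_render_template_fields(
            parent=unmapped_task,
            template_fields=template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=set(),
            session=session,
        )
        return unmapped_task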
---
airflow/decorators/base.py | 45 ++++++++++---
airflow/jobs/local_task_job.py | 5 --
airflow/models/abstractoperator.py | 22 +------
airflow/models/mappedoperator.py | 75 +++++++++++++++-------
airflow/models/xcom_arg.py | 10 ++-
.../concepts/dynamic-task-mapping.rst | 36 +++++++++++
tests/decorators/test_python.py | 40 ++++++++++++
tests/models/test_baseoperator.py | 43 +++++++++++++
8 files changed, 219 insertions(+), 57 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 3d66d01c91..9072439b2f 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import collections.abc
import functools
import inspect
import re
@@ -39,7 +38,6 @@ from typing import (
import attr
import typing_extensions
-from sqlalchemy.orm import Session
from airflow.compat.functools import cache, cached_property
from airflow.exceptions import AirflowException
@@ -68,6 +66,9 @@ from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
+ import jinja2 # Slow import.
+ from sqlalchemy.orm import Session
+
from airflow.models.mappedoperator import Mappable
@@ -430,6 +431,7 @@ class DecoratedMappedOperator(MappedOperator):
self.mapped_op_kwargs,
fail_reason="mapping already partial",
)
+ self._combined_op_kwargs = op_kwargs
return {
"dag": self.dag,
"task_group": self.task_group,
@@ -441,13 +443,38 @@ class DecoratedMappedOperator(MappedOperator):
**self.mapped_kwargs,
}
- def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any:
- if key != "op_kwargs" or not isinstance(content, collections.abc.Mapping):
- return content
- # The magic super() doesn't work here, so we use the explicit form.
- # Not using super(..., self) to work around pyupgrade bug.
- sup: Any = super(DecoratedMappedOperator, DecoratedMappedOperator)
- return {k: sup._expand_mapped_field(self, k, v, context, session=session) for k, v in content.items()}
+ def _resolve_expansion_kwargs(
+ self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: "Session"
+ ) -> None:
+ expansion_kwargs = self._get_expansion_kwargs()
+
+ self._already_resolved_op_kwargs = set()
+ for k, v in expansion_kwargs.items():
+ if isinstance(v, XComArg):
+ self._already_resolved_op_kwargs.add(k)
+ v = v.resolve(context, session=session)
+ v = self._expand_mapped_field(k, v, context, session=session)
+ kwargs['op_kwargs'][k] = v
+ template_fields.discard(k)
+
+ def render_template(
+ self,
+ value: Any,
+ context: Context,
+ jinja_env: Optional["jinja2.Environment"] = None,
+ seen_oids: Optional[Set] = None,
+ ) -> Any:
+ if hasattr(self, '_combined_op_kwargs') and value is self._combined_op_kwargs:
+ # Avoid rendering values that came out of resolved XComArgs
+ return {
+ k: v
+ if k in self._already_resolved_op_kwargs
+ else super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
+ self, v, context, jinja_env=jinja_env, seen_oids=seen_oids
+ )
+ for k, v in value.items()
+ }
+ return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)
class Task(Generic[Function]):
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index d7c6fa2153..9b2c3510eb 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -104,11 +104,6 @@ class LocalTaskJob(BaseJob):
try:
self.task_runner.start()
- # Unmap the task _after_ it has forked/execed. (This is a bit of a kludge, but if we unmap before
- # fork, then the "run_raw_task" command will see the mapping index and an Non-mapped task and
- # fail)
- self.task_instance.task = self.task_instance.task.unmap()
-
heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
# task callback invocation happens either here or in
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index c64b8554af..2d8d008dfc 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -35,8 +35,6 @@ from typing import (
Union,
)
-from sqlalchemy.orm import Session
-
from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -52,6 +50,7 @@ TaskStateChangeCallback = Callable[[Context], None]
if TYPE_CHECKING:
import jinja2 # Slow import.
+ from sqlalchemy.orm import Session
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG
@@ -330,7 +329,7 @@ class AbstractOperator(LoggingMixin, DAGNode):
jinja_env: "jinja2.Environment",
seen_oids: Set,
*,
- session: Session = NEW_SESSION,
+ session: "Session" = NEW_SESSION,
) -> None:
for attr_name in template_fields:
try:
@@ -342,29 +341,14 @@ class AbstractOperator(LoggingMixin, DAGNode):
)
if not value:
continue
- rendered_content = self._render_template_field(
- attr_name,
+ rendered_content = self.render_template(
value,
context,
jinja_env,
seen_oids,
- session=session,
)
setattr(parent, attr_name, rendered_content)
- def _render_template_field(
- self,
- key: str,
- value: Any,
- context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- seen_oids: Optional[Set] = None,
- *,
- session: Session,
- ) -> Any:
- """Override point for MappedOperator to perform further resolution."""
- return self.render_template(value, context, jinja_env, seen_oids)
-
def render_template(
self,
content: Any,
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 15ecee8798..ddc32906df 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -44,6 +44,7 @@ import pendulum
from sqlalchemy import func, or_
from sqlalchemy.orm.session import Session
+from airflow import settings
from airflow.compat.functools import cache, cached_property
from airflow.exceptions import AirflowException, UnmappableOperator
from airflow.models.abstractoperator import (
@@ -473,6 +474,7 @@ class MappedOperator(AbstractOperator):
return DagAttributeTypes.OP, self.task_id
def _get_unmap_kwargs(self) -> Dict[str, Any]:
+
return {
"task_id": self.task_id,
"dag": self.dag,
@@ -484,14 +486,26 @@ class MappedOperator(AbstractOperator):
**self.mapped_kwargs,
}
- def unmap(self) -> "BaseOperator":
- """Get the "normal" Operator after applying the current mapping."""
+ def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator":
+ """
+ Get the "normal" Operator after applying the current mapping.
+
+ If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a
+ SerializedBaseOperator that aims to "look like" the real operator.
+
+ :param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when
+ ``operator_class`` is still an actual class.
+
+ :meta private:
+ """
if isinstance(self.operator_class, type):
# We can't simply specify task_id here because BaseOperator further
# mangles the task_id based on the task hierarchy (namely, group_id
# is prepended, and '__N' appended to deduplicate). Instead of
# recreating the whole logic here, we just overwrite task_id later.
- op = self.operator_class(**self._get_unmap_kwargs(), _airflow_from_mapped=True)
+ if unmap_kwargs is None:
+ unmap_kwargs = self._get_unmap_kwargs()
+ op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
op.task_id = self.task_id
return op
@@ -569,6 +583,7 @@ class MappedOperator(AbstractOperator):
map_lengths[mapped_arg_name] += length
return map_lengths
+ @cache
def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
"""Return dict of argument name to map length, or throw if some are not resolvable"""
expansion_kwargs = self._get_expansion_kwargs()
@@ -686,33 +701,49 @@ class MappedOperator(AbstractOperator):
"""
if not jinja_env:
jinja_env = self.get_template_env()
- unmapped_task = self.unmap()
+ # Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor
+ # could be called with an XComArg, rather than the value it resolves to.
+ #
+ # We also need to resolve _all_ mapped arguments, even if they aren't marked as templated
+ kwargs = self._get_unmap_kwargs()
+
+ template_fields = set(self.template_fields)
+
+ # Ideally we'd like to pass in session as an argument to this function, but since operators _could_
+ # override this we can't easily change this function signature.
+ # We can't use @provide_session, as that closes and expunges everything, which we don't want to do
+ # when we are so "deep" in the weeds here.
+ #
+ # Nor do we want to close the session -- that would expunge all the things from the internal cache
+ # which we don't want to do either
+ session = settings.Session()
+ self._resolve_expansion_kwargs(kwargs, template_fields, context, session)
+
+ unmapped_task = self.unmap(unmap_kwargs=kwargs)
self._do_render_template_fields(
parent=unmapped_task,
- template_fields=unmapped_task.template_fields,
+ template_fields=template_fields,
context=context,
jinja_env=jinja_env,
seen_oids=set(),
+ session=session,
)
return unmapped_task
- def _render_template_field(
- self,
- key: str,
- value: Any,
- context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- seen_oids: Optional[Set] = None,
- *,
- session: Session,
- ) -> Any:
- """Override the ordinary template rendering to add more logic.
-
- Specifically, if we're rendering a mapped argument, we need to "unmap"
- the value as well to assign it to the unmapped operator.
- """
- value = super()._render_template_field(key, value, context, jinja_env, seen_oids, session=session)
- return self._expand_mapped_field(key, value, context, session=session)
+ def _resolve_expansion_kwargs(
+ self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session
+ ) -> None:
+ """Update mapped fields in place in kwargs dict"""
+ from airflow.models.xcom_arg import XComArg
+
+ expansion_kwargs = self._get_expansion_kwargs()
+
+ for k, v in expansion_kwargs.items():
+ if isinstance(v, XComArg):
+ v = v.resolve(context, session=session)
+ v = self._expand_mapped_field(k, v, context, session=session)
+ template_fields.discard(k)
+ kwargs[k] = v
def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
map_index = context["ti"].map_index
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 7e4b274626..449fab8af5 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -22,9 +22,12 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
from airflow.models.operator import Operator
@@ -136,12 +139,15 @@ class XComArg(DependencyMixin):
"""Proxy to underlying operator set_downstream method. Required by
TaskMixin."""
self.operator.set_downstream(task_or_task_list, edge_modifier)
- def resolve(self, context: Context) -> Any:
+ @provide_session
+ def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
"""
Pull XCom value for the existing arg. This method is run during ``op.execute()``
in respectable context.
"""
- result = context["ti"].xcom_pull(task_ids=self.operator.task_id, key=str(self.key), default=NOTSET)
+ result = context["ti"].xcom_pull(
+ task_ids=self.operator.task_id, key=str(self.key), default=NOTSET, session=session
+ )
if result is NOTSET:
raise AirflowException(
f'XComArg result from {self.operator.task_id} at {context["ti"].dag_id} '
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index b326a31672..d1152971dc 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -224,6 +224,42 @@ Currently it is only possible to map against a dict, a list, or one of those typ
If an upstream task returns an unmappable type, the mapped task will fail at run-time with an ``UnmappableXComTypePushed`` exception. For instance, you can't have the upstream task return a plain string – it must be a list or a dict.
+How do templated fields and mapped arguments interact?
+======================================================
+
+All arguments to an operator can be mapped, even those that do not accept templated parameters.
+
+If a field is marked as being templated and is mapped, it **will not be templated**.
+
+For example, this will print ``{{ ds }}`` and not a date stamp:
+
+.. code-block:: python
+
+ @task
+ def make_list():
+ return ["{{ ds }}"]
+
+
+ @task
+ def printer(val):
+ print(val)
+
+
+ printer.expand(val=make_list())
+
+If you want values to be interpolated, either use the template variable as a regular argument to the task, or call ``task.render_template`` yourself; both alternatives are shown below:
+
+.. code-block:: python
+
+ @task
+ def make_list(ds):
+ return [ds]
+
+
+ @task
+ def make_list(**context):
+ return [context["task"].render_template("{{ ds }}", context)]
+
Placing limits on mapped tasks
==============================
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 8406de45e6..418907fd78 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -27,6 +27,10 @@ from airflow.decorators import task as task_decorator
from airflow.decorators.base import DecoratedMappedOperator
from airflow.exceptions import AirflowException
from airflow.models import DAG
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskmap import TaskMap
+from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.models.xcom_arg import XComArg
from airflow.utils import timezone
from airflow.utils.state import State
@@ -641,3 +645,39 @@ def test_mapped_decorator_converts_partial_kwargs():
mapped_task1 = dag.get_task("task1")
assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30) # Operator default.
mapped_task1.unmap().retry_delay == timedelta(seconds=300) # Operator default.
+
+
+def test_mapped_render_template_fields(dag_maker, session):
+ @task_decorator
+ def fn(arg1, arg2):
+ ...
+
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ xcom_arg = XComArg(task1)
+ mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=xcom_arg)
+
+ dr = dag_maker.create_dagrun()
+ ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+ ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=1,
+ keys=None,
+ )
+ )
+ session.flush()
+
+ mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
+ mapped_ti.map_index = 0
+ op = mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+ assert op
+
+ assert op.op_kwargs['arg1'] == "{{ ds }}"
+ assert op.op_kwargs['arg2'] == "fn"
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 64a774a0da..3c90463c60 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -33,6 +33,7 @@ from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, c
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
+from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.models.xcom_arg import XComArg
from airflow.utils.context import Context
from airflow.utils.edgemodifier import Label
@@ -944,3 +945,45 @@ def test_mapped_task_applies_default_args_taskflow(dag_maker):
assert dag.get_task("simple").execution_timeout == timedelta(minutes=30)
assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
+
+
+def test_mapped_render_template_fields_validating_operator(dag_maker, session):
+ class MyOperator(MockOperator):
+ def __init__(self, value, arg1, **kwargs):
+ assert isinstance(value, str), "value should have been resolved before unmapping"
+ assert isinstance(arg1, str), "value should have been resolved before unmapping"
+ super().__init__(arg1=arg1, **kwargs)
+ self.value = value
+
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ xcom_arg = XComArg(task1)
+ mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand(
+ value=xcom_arg, arg1=xcom_arg
+ )
+
+ dr = dag_maker.create_dagrun()
+ ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+ ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session)
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=1,
+ keys=None,
+ )
+ )
+ session.flush()
+
+ mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
+ mapped_ti.map_index = 0
+ op = mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+ assert isinstance(op, MyOperator)
+
+ assert op.value == "{{ ds }}", "Should not be templated!"
+ assert op.arg1 == "{{ ds }}"
+ assert op.arg2 == "a"