This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0ec9651f88 Split out and handle 'params' in mapped operator (#26100)
0ec9651f88 is described below
commit 0ec9651f88b7073f8de4df70b21e70be0bd08b2a
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Oct 6 12:47:32 2022 +0800
Split out and handle 'params' in mapped operator (#26100)
---
airflow/decorators/base.py | 5 ++--
airflow/models/abstractoperator.py | 6 ++---
airflow/models/baseoperator.py | 15 +++++++-----
airflow/models/mappedoperator.py | 32 ++++++++++++++++++++-----
airflow/models/param.py | 36 +++++++++++++++++++++++++++++
airflow/models/taskinstance.py | 12 ++--------
airflow/serialization/serialized_objects.py | 3 +++
airflow/utils/context.py | 27 ++++++++++++----------
airflow/utils/context.pyi | 9 ++++----
tests/conftest.py | 3 ++-
tests/models/test_dagrun.py | 27 ++++++++++++++++++++++
tests/models/test_mappedoperator.py | 20 ++++++++++++++++
tests/models/test_taskinstance.py | 16 ++++++-------
13 files changed, 157 insertions(+), 54 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 474d90777e..715a5515cc 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -346,7 +346,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
task_group = task_kwargs.pop("task_group", None) or
TaskGroupContext.get_current_task_group(dag)
- partial_kwargs, default_params = get_merged_defaults(
+ partial_kwargs, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=task_kwargs.pop("params", None),
@@ -357,7 +357,6 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
task_group)
if task_group:
task_id = task_group.child_id(task_id)
- params = partial_kwargs.pop("params", None) or default_params
# Logic here should be kept in sync with BaseOperatorMeta.partial().
if "task_concurrency" in partial_kwargs:
@@ -397,7 +396,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values
go to op_kwargs_expand_input.
partial_kwargs=partial_kwargs,
task_id=task_id,
- params=params,
+ params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
diff --git a/airflow/models/abstractoperator.py
b/airflow/models/abstractoperator.py
index 6f3569ef3a..a7e3380dac 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -375,12 +375,12 @@ class AbstractOperator(LoggingMixin, DAGNode):
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
- """Template all attributes listed in template_fields.
+ """Template all attributes listed in *self.template_fields*.
If the operator is mapped, this should return the unmapped, fully
rendered, and map-expanded operator. The mapped operator should not be
- modified. However, ``context`` will be modified in-place to reference
- the unmapped operator for template rendering.
+ modified. However, *context* may be modified in-place to reference the
+ unmapped operator for template rendering.
If the operator is not mapped, this should modify the operator
in-place.
"""
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9d4489ade5..dc54ba6b90 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -237,7 +237,7 @@ def partial(
task_id = task_group.child_id(task_id)
# Merge DAG and task group level defaults into user-supplied values.
- partial_kwargs, default_params = get_merged_defaults(
+ partial_kwargs, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=params,
@@ -253,7 +253,6 @@ def partial(
partial_kwargs.setdefault("end_date", end_date)
partial_kwargs.setdefault("owner", owner)
partial_kwargs.setdefault("email", email)
- partial_kwargs.setdefault("params", default_params)
partial_kwargs.setdefault("trigger_rule", trigger_rule)
partial_kwargs.setdefault("depends_on_past", depends_on_past)
partial_kwargs.setdefault("ignore_first_depends_on_past",
ignore_first_depends_on_past)
@@ -304,7 +303,11 @@ def partial(
partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])
- return OperatorPartial(operator_class=operator_class, kwargs=partial_kwargs)
+ return OperatorPartial(
+ operator_class=operator_class,
+ kwargs=partial_kwargs,
+ params=partial_params,
+ )
class BaseOperatorMeta(abc.ABCMeta):
@@ -1181,12 +1184,12 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
- """Template all attributes listed in template_fields.
+ """Template all attributes listed in *self.template_fields*.
This mutates the attributes in-place and is irreversible.
- :param context: Dict with values to apply on content
- :param jinja_env: Jinja environment
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja environment to use for rendering.
"""
if not jinja_env:
jinja_env = self.get_template_env()
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index bc276e73da..deba1e9d8c 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -19,6 +19,8 @@ from __future__ import annotations
import collections
import collections.abc
+import contextlib
+import copy
import datetime
import warnings
from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable,
Iterator, Mapping, Sequence, Union
@@ -53,6 +55,7 @@ from airflow.models.expandinput import (
OperatorExpandKwargsArgument,
get_mappable_types,
)
+from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -137,6 +140,7 @@ class OperatorPartial:
operator_class: type[BaseOperator]
kwargs: dict[str, Any]
+ params: ParamsDict | dict
_expand_called: bool = False # Set when expand() is called to ease user
debugging.
@@ -187,7 +191,6 @@ class OperatorPartial:
partial_kwargs = self.kwargs.copy()
task_id = partial_kwargs.pop("task_id")
- params = partial_kwargs.pop("params")
dag = partial_kwargs.pop("dag")
task_group = partial_kwargs.pop("task_group")
start_date = partial_kwargs.pop("start_date")
@@ -203,7 +206,7 @@ class OperatorPartial:
expand_input=expand_input,
partial_kwargs=partial_kwargs,
task_id=task_id,
- params=params,
+ params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
@@ -253,7 +256,7 @@ class MappedOperator(AbstractOperator):
# Needed for serialization.
task_id: str
- params: dict | None
+ params: ParamsDict | dict
deps: frozenset[BaseTIDep]
operator_extra_links: Collection[BaseOperatorLink]
template_ext: Sequence[str]
@@ -539,16 +542,24 @@ class MappedOperator(AbstractOperator):
mapped_kwargs,
fail_reason="unmappable or already specified",
)
- # Ordering is significant; mapped kwargs should override partial ones.
+
+ # If params appears in the mapped kwargs, we need to merge it into the
+ # partial params, overriding existing keys.
+ params = copy.copy(self.params)
+ with contextlib.suppress(KeyError):
+ params.update(mapped_kwargs["params"])
+
+ # Ordering is significant; mapped kwargs should override partial ones,
+ # and the specially handled params should be respected.
return {
"task_id": self.task_id,
"dag": self.dag,
"task_group": self.task_group,
- "params": self.params,
"start_date": self.start_date,
"end_date": self.end_date,
**self.partial_kwargs,
**mapped_kwargs,
+ "params": params,
}
def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context,
Session]) -> BaseOperator:
@@ -588,7 +599,7 @@ class MappedOperator(AbstractOperator):
# mapped operator to a new SerializedBaseOperator instance.
from airflow.serialization.serialized_objects import
SerializedBaseOperator
- op = SerializedBaseOperator(task_id=self.task_id, _airflow_from_mapped=True)
+ op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
SerializedBaseOperator.populate_operator(op, self.operator_class)
return op
@@ -739,6 +750,15 @@ class MappedOperator(AbstractOperator):
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
+ """Template all attributes listed in *self.template_fields*.
+
+ This updates *context* to reference the map-expanded task and relevant
+ information, without modifying the mapped operator. The expanded task
+ in *context* is then rendered in-place.
+
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja environment to use for rendering.
+ """
if not jinja_env:
jinja_env = self.get_template_env()
diff --git a/airflow/models/param.py b/airflow/models/param.py
index a7d36c0aab..be625227c5 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import contextlib
import copy
import json
+import logging
import warnings
from typing import TYPE_CHECKING, Any, ItemsView, MutableMapping, ValuesView
@@ -29,6 +30,10 @@ from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.models.dag import DAG
+ from airflow.models.dagrun import DagRun
+ from airflow.models.operator import Operator
+
+logger = logging.getLogger(__name__)
class Param:
@@ -132,6 +137,16 @@ class ParamsDict(MutableMapping[str, Any]):
self.__dict = params_dict
self.suppress_exception = suppress_exception
+ def __bool__(self) -> bool:
+ return bool(self.__dict)
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, ParamsDict):
+ return self.dump() == other.dump()
+ if isinstance(other, dict):
+ return self.dump() == other
+ return NotImplemented
+
def __copy__(self) -> ParamsDict:
return ParamsDict(self.__dict, self.suppress_exception)
@@ -253,3 +268,24 @@ class DagParam(ResolveMixin):
with contextlib.suppress(KeyError):
return context['params'][self._name]
raise AirflowException(f'No value could be resolved for parameter
{self._name}')
+
+
+def process_params(
+ dag: DAG,
+ task: Operator,
+ dag_run: DagRun | None,
+ *,
+ suppress_exception: bool,
+) -> dict[str, Any]:
+ """Merge, validate params, and convert them into a simple dict."""
+ from airflow.configuration import conf
+
+ params = ParamsDict(suppress_exception=suppress_exception)
+ with contextlib.suppress(AttributeError):
+ params.update(dag.params)
+ if task.params:
+ params.update(task.params)
+ if conf.getboolean('core', 'dag_run_conf_overrides_params') and dag_run and dag_run.conf:
+ logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
+ params.update(dag_run.conf)
+ return params.validate()
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 190542063c..30fb9b5eb7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -97,7 +97,7 @@ from airflow.exceptions import (
)
from airflow.models.base import Base, StringID
from airflow.models.log import Log
-from airflow.models.param import ParamsDict
+from airflow.models.param import process_params
from airflow.models.taskfail import TaskFail
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
@@ -1947,15 +1947,7 @@ class TaskInstance(Base, LoggingMixin):
dag_run = self.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)
- # Validates Params and convert them into a simple dict.
- params = ParamsDict(suppress_exception=ignore_param_exceptions)
- with contextlib.suppress(AttributeError):
- params.update(dag.params)
- if task.params:
- params.update(task.params)
- if conf.getboolean('core', 'dag_run_conf_overrides_params'):
- self.overwrite_params_with_dag_run_conf(params=params,
dag_run=dag_run)
- validated_params = params.validate()
+ validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
logical_date = timezone.coerce_datetime(self.execution_date)
ds = logical_date.strftime('%Y-%m-%d')
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 8f5fcf8b26..ee4bc9c805 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -866,6 +866,9 @@ class SerializedBaseOperator(BaseOperator,
BaseSerialization):
v = cls._deserialize_deps(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
+ if op.params: # Merge existing params if needed.
+ v, new = op.params, v
+ v.update(new)
elif k == "partial_kwargs":
v = {arg: cls.deserialize(value) for arg, value in v.items()}
elif k in {"expand_input", "op_kwargs_expand_input"}:
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index f356e7495f..9475abd5a4 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -258,6 +258,21 @@ def context_merge(context: Context, *args: Any, **kwargs:
Any) -> None:
context.update(*args, **kwargs)
+def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
+ """Update context after task unmapping.
+
+ Since ``get_template_context()`` is called before unmapping, the context
+ contains information about the mapped task. We need to do some in-place
+ updates to ensure the template context reflects the unmapped task instead.
+
+ :meta private:
+ """
+ from airflow.models.param import process_params
+
+ context["task"] = context["ti"].task = task
+ context["params"] = process_params(context["dag"], task, context["dag_run"], suppress_exception=False)
+
+
def context_copy_partial(source: Context, keys: Container[str]) -> Context:
"""Create a context by copying items under selected keys in ``source``.
@@ -304,15 +319,3 @@ def lazy_mapping_from_context(source: Context) ->
Mapping[str, Any]:
return lazy_object_proxy.Proxy(factory)
return {k: _create_value(k, v) for k, v in source._context.items()}
-
-
-def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
- """Update context after task unmapping.
-
- Since ``get_template_context()`` is called before unmapping, the context
- contains information about the mapped task. We need to do some in-place
- updates to ensure the template context reflects the unmapped task instead.
-
- :meta private:
- """
- context["task"] = context["ti"].task = task
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 7d58e1b3a1..bc82fdb334 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -25,7 +25,7 @@
# undefined attribute errors from Mypy. Hopefully there will be a mechanism to
# declare "these are defined, but don't error if others are accessed" someday.
-from typing import Any, Container, Iterable, Mapping, Optional, Set, Tuple,
Union, overload
+from typing import Any, Collection, Container, Iterable, Mapping, Optional,
Set, Tuple, Union, overload
from pendulum import DateTime
@@ -99,10 +99,11 @@ class Context(TypedDict, total=False):
class AirflowContextDeprecationWarning(DeprecationWarning): ...
@overload
-def context_merge(source: Context, additions: Mapping[str, Any], **kwargs:
Any) -> None: ...
+def context_merge(context: Context, additions: Mapping[str, Any], **kwargs:
Any) -> None: ...
@overload
-def context_merge(source: Context, additions: Iterable[Tuple[str, Any]],
**kwargs: Any) -> None: ...
+def context_merge(context: Context, additions: Iterable[Tuple[str, Any]],
**kwargs: Any) -> None: ...
@overload
-def context_merge(source: Context, **kwargs: Any) -> None: ...
+def context_merge(context: Context, **kwargs: Any) -> None: ...
+def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
...
def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
diff --git a/tests/conftest.py b/tests/conftest.py
index 1026dc3c01..1d73178f13 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -746,7 +746,7 @@ def create_task_instance(dag_maker, create_dummy_dag):
from airflow.utils import timezone
execution_date = timezone.utcnow()
- create_dummy_dag(with_dagrun_type=None, **kwargs)
+ _, task = create_dummy_dag(with_dagrun_type=None, **kwargs)
dagrun_kwargs = {"execution_date": execution_date, "state":
dagrun_state}
if run_id is not None:
@@ -757,6 +757,7 @@ def create_task_instance(dag_maker, create_dummy_dag):
dagrun_kwargs["data_interval"] = data_interval
dagrun = dag_maker.create_dagrun(**dagrun_kwargs)
(ti,) = dagrun.task_instances
+ ti.task = task
ti.state = state
dag_maker.session.flush()
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 50e9e9a3d8..29d44a6527 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1945,3 +1945,30 @@ def
test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session,
(tis, _) = dr.update_state()
assert len(tis)
assert dr.state != DagRunState.FAILED
+
+
[email protected](
+ "partial_params, mapped_params, expected",
+ [
+ pytest.param(None, [{"a": 1}], [[("a", 1)]], id="simple"),
+ pytest.param({"b": 2}, [{"a": 1}], [[("a", 1), ("b", 2)]], id="merge"),
+ pytest.param({"b": 2}, [{"a": 1, "b": 3}], [[("a", 1), ("b", 3)]],
id="override"),
+ ],
+)
+def test_mapped_expand_against_params(dag_maker, partial_params, mapped_params, expected):
+ results = []
+
+ class PullOperator(BaseOperator):
+ def execute(self, context):
+ results.append(sorted(context["params"].items()))
+
+ with dag_maker():
+ PullOperator.partial(task_id="t",
params=partial_params).expand(params=mapped_params)
+
+ dr: DagRun = dag_maker.create_dagrun()
+ decision = dr.task_instance_scheduling_decisions()
+
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ assert sorted(results) == expected
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index 1faf42be3d..865e146985 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -25,6 +25,7 @@ import pytest
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
+from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
@@ -273,6 +274,25 @@ def
test_mapped_task_applies_default_args_taskflow(dag_maker):
assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
[email protected](
+ "dag_params, task_params, expected_partial_params",
+ [
+ pytest.param(None, None, ParamsDict(), id="none"),
+ pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"),
+ pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"),
+ pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}),
id="merge"),
+ ],
+)
+def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expected_partial_params):
+ with dag_maker(params=dag_params) as dag:
+ MockOperator.partial(task_id="t",
params=task_params).expand(params=[{"c": "x"}, {"d": 1}])
+
+ t = dag.get_task("t")
+ assert isinstance(t, MappedOperator)
+ assert t.params == expected_partial_params
+ assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]}
+
+
def test_mapped_render_template_fields_validating_operator(dag_maker, session):
class MyOperator(MockOperator):
def __init__(self, value, arg1, **kwargs):
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index b1e36dcbaa..2e02d77800 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -61,6 +61,7 @@ from airflow.models import (
)
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent,
DatasetModel
from airflow.models.expandinput import EXPAND_INPUT_EMPTY
+from airflow.models.param import process_params
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance
@@ -1628,27 +1629,24 @@ class TestTaskInstance:
ti = create_task_instance()
dag_run = ti.dag_run
dag_run.conf = {"override": True}
- params = {"override": False}
-
- ti.overwrite_params_with_dag_run_conf(params, dag_run)
+ ti.task.params = {"override": False}
+ params = process_params(ti.task.dag, ti.task, dag_run,
suppress_exception=False)
assert params["override"] is True
def test_overwrite_params_with_dag_run_none(self, create_task_instance):
ti = create_task_instance()
- params = {"override": False}
-
- ti.overwrite_params_with_dag_run_conf(params, None)
+ ti.task.params = {"override": False}
+ params = process_params(ti.task.dag, ti.task, None,
suppress_exception=False)
assert params["override"] is False
def test_overwrite_params_with_dag_run_conf_none(self,
create_task_instance):
ti = create_task_instance()
- params = {"override": False}
dag_run = ti.dag_run
+ ti.task.params = {"override": False}
- ti.overwrite_params_with_dag_run_conf(params, dag_run)
-
+ params = process_params(ti.task.dag, ti.task, dag_run,
suppress_exception=False)
assert params["override"] is False
@pytest.mark.parametrize("use_native_obj", [True, False])