This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ec9651f88 Split out and handle 'params' in mapped operator (#26100)
0ec9651f88 is described below

commit 0ec9651f88b7073f8de4df70b21e70be0bd08b2a
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Oct 6 12:47:32 2022 +0800

    Split out and handle 'params' in mapped operator (#26100)
---
 airflow/decorators/base.py                  |  5 ++--
 airflow/models/abstractoperator.py          |  6 ++---
 airflow/models/baseoperator.py              | 15 +++++++-----
 airflow/models/mappedoperator.py            | 32 ++++++++++++++++++++-----
 airflow/models/param.py                     | 36 +++++++++++++++++++++++++++++
 airflow/models/taskinstance.py              | 12 ++--------
 airflow/serialization/serialized_objects.py |  3 +++
 airflow/utils/context.py                    | 27 ++++++++++++----------
 airflow/utils/context.pyi                   |  9 ++++----
 tests/conftest.py                           |  3 ++-
 tests/models/test_dagrun.py                 | 27 ++++++++++++++++++++++
 tests/models/test_mappedoperator.py         | 20 ++++++++++++++++
 tests/models/test_taskinstance.py           | 16 ++++++-------
 13 files changed, 157 insertions(+), 54 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 474d90777e..715a5515cc 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -346,7 +346,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
         dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
         task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
 
-        partial_kwargs, default_params = get_merged_defaults(
+        partial_kwargs, partial_params = get_merged_defaults(
             dag=dag,
             task_group=task_group,
             task_params=task_kwargs.pop("params", None),
@@ -357,7 +357,6 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
         task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
         if task_group:
             task_id = task_group.child_id(task_id)
-        params = partial_kwargs.pop("params", None) or default_params
 
         # Logic here should be kept in sync with BaseOperatorMeta.partial().
         if "task_concurrency" in partial_kwargs:
@@ -397,7 +396,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
             expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
             partial_kwargs=partial_kwargs,
             task_id=task_id,
-            params=params,
+            params=partial_params,
             deps=MappedOperator.deps_for(self.operator_class),
             operator_extra_links=self.operator_class.operator_extra_links,
             template_ext=self.operator_class.template_ext,
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 6f3569ef3a..a7e3380dac 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -375,12 +375,12 @@ class AbstractOperator(LoggingMixin, DAGNode):
         context: Context,
         jinja_env: jinja2.Environment | None = None,
     ) -> None:
-        """Template all attributes listed in template_fields.
+        """Template all attributes listed in *self.template_fields*.
 
         If the operator is mapped, this should return the unmapped, fully
         rendered, and map-expanded operator. The mapped operator should not be
-        modified. However, ``context`` will be modified in-place to reference
-        the unmapped operator for template rendering.
+        modified. However, *context* may be modified in-place to reference the
+        unmapped operator for template rendering.
 
         If the operator is not mapped, this should modify the operator in-place.
         """
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9d4489ade5..dc54ba6b90 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -237,7 +237,7 @@ def partial(
         task_id = task_group.child_id(task_id)
 
     # Merge DAG and task group level defaults into user-supplied values.
-    partial_kwargs, default_params = get_merged_defaults(
+    partial_kwargs, partial_params = get_merged_defaults(
         dag=dag,
         task_group=task_group,
         task_params=params,
@@ -253,7 +253,6 @@ def partial(
     partial_kwargs.setdefault("end_date", end_date)
     partial_kwargs.setdefault("owner", owner)
     partial_kwargs.setdefault("email", email)
-    partial_kwargs.setdefault("params", default_params)
     partial_kwargs.setdefault("trigger_rule", trigger_rule)
     partial_kwargs.setdefault("depends_on_past", depends_on_past)
     partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past)
@@ -304,7 +303,11 @@ def partial(
     partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
     partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])
 
-    return OperatorPartial(operator_class=operator_class, kwargs=partial_kwargs)
+    return OperatorPartial(
+        operator_class=operator_class,
+        kwargs=partial_kwargs,
+        params=partial_params,
+    )
 
 
 class BaseOperatorMeta(abc.ABCMeta):
@@ -1181,12 +1184,12 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         context: Context,
         jinja_env: jinja2.Environment | None = None,
     ) -> None:
-        """Template all attributes listed in template_fields.
+        """Template all attributes listed in *self.template_fields*.
 
         This mutates the attributes in-place and is irreversible.
 
-        :param context: Dict with values to apply on content
-        :param jinja_env: Jinja environment
+        :param context: Context dict with values to apply on content.
+        :param jinja_env: Jinja environment to use for rendering.
         """
         if not jinja_env:
             jinja_env = self.get_template_env()
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index bc276e73da..deba1e9d8c 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 
 import collections
 import collections.abc
+import contextlib
+import copy
 import datetime
 import warnings
 from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
@@ -53,6 +55,7 @@ from airflow.models.expandinput import (
     OperatorExpandKwargsArgument,
     get_mappable_types,
 )
+from airflow.models.param import ParamsDict
 from airflow.models.pool import Pool
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -137,6 +140,7 @@ class OperatorPartial:
 
     operator_class: type[BaseOperator]
     kwargs: dict[str, Any]
+    params: ParamsDict | dict
 
     _expand_called: bool = False  # Set when expand() is called to ease user debugging.
 
@@ -187,7 +191,6 @@ class OperatorPartial:
 
         partial_kwargs = self.kwargs.copy()
         task_id = partial_kwargs.pop("task_id")
-        params = partial_kwargs.pop("params")
         dag = partial_kwargs.pop("dag")
         task_group = partial_kwargs.pop("task_group")
         start_date = partial_kwargs.pop("start_date")
@@ -203,7 +206,7 @@ class OperatorPartial:
             expand_input=expand_input,
             partial_kwargs=partial_kwargs,
             task_id=task_id,
-            params=params,
+            params=self.params,
             deps=MappedOperator.deps_for(self.operator_class),
             operator_extra_links=self.operator_class.operator_extra_links,
             template_ext=self.operator_class.template_ext,
@@ -253,7 +256,7 @@ class MappedOperator(AbstractOperator):
 
     # Needed for serialization.
     task_id: str
-    params: dict | None
+    params: ParamsDict | dict
     deps: frozenset[BaseTIDep]
     operator_extra_links: Collection[BaseOperatorLink]
     template_ext: Sequence[str]
@@ -539,16 +542,24 @@ class MappedOperator(AbstractOperator):
                 mapped_kwargs,
                 fail_reason="unmappable or already specified",
             )
-        # Ordering is significant; mapped kwargs should override partial ones.
+
+        # If params appears in the mapped kwargs, we need to merge it into the
+        # partial params, overriding existing keys.
+        params = copy.copy(self.params)
+        with contextlib.suppress(KeyError):
+            params.update(mapped_kwargs["params"])
+
+        # Ordering is significant; mapped kwargs should override partial ones,
+        # and the specially handled params should be respected.
         return {
             "task_id": self.task_id,
             "dag": self.dag,
             "task_group": self.task_group,
-            "params": self.params,
             "start_date": self.start_date,
             "end_date": self.end_date,
             **self.partial_kwargs,
             **mapped_kwargs,
+            "params": params,
         }
 
     def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
@@ -588,7 +599,7 @@ class MappedOperator(AbstractOperator):
         # mapped operator to a new SerializedBaseOperator instance.
         from airflow.serialization.serialized_objects import SerializedBaseOperator
 
-        op = SerializedBaseOperator(task_id=self.task_id, _airflow_from_mapped=True)
+        op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
         SerializedBaseOperator.populate_operator(op, self.operator_class)
         return op
 
@@ -739,6 +750,15 @@ class MappedOperator(AbstractOperator):
         context: Context,
         jinja_env: jinja2.Environment | None = None,
     ) -> None:
+        """Template all attributes listed in *self.template_fields*.
+
+        This updates *context* to reference the map-expanded task and relevant
+        information, without modifying the mapped operator. The expanded task
+        in *context* is then rendered in-place.
+
+        :param context: Context dict with values to apply on content.
+        :param jinja_env: Jinja environment to use for rendering.
+        """
         if not jinja_env:
             jinja_env = self.get_template_env()
 
diff --git a/airflow/models/param.py b/airflow/models/param.py
index a7d36c0aab..be625227c5 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import contextlib
 import copy
 import json
+import logging
 import warnings
 from typing import TYPE_CHECKING, Any, ItemsView, MutableMapping, ValuesView
 
@@ -29,6 +30,10 @@ from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
+    from airflow.models.dagrun import DagRun
+    from airflow.models.operator import Operator
+
+logger = logging.getLogger(__name__)
 
 
 class Param:
@@ -132,6 +137,16 @@ class ParamsDict(MutableMapping[str, Any]):
         self.__dict = params_dict
         self.suppress_exception = suppress_exception
 
+    def __bool__(self) -> bool:
+        return bool(self.__dict)
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, ParamsDict):
+            return self.dump() == other.dump()
+        if isinstance(other, dict):
+            return self.dump() == other
+        return NotImplemented
+
     def __copy__(self) -> ParamsDict:
         return ParamsDict(self.__dict, self.suppress_exception)
 
@@ -253,3 +268,24 @@ class DagParam(ResolveMixin):
         with contextlib.suppress(KeyError):
             return context['params'][self._name]
         raise AirflowException(f'No value could be resolved for parameter {self._name}')
+
+
+def process_params(
+    dag: DAG,
+    task: Operator,
+    dag_run: DagRun | None,
+    *,
+    suppress_exception: bool,
+) -> dict[str, Any]:
+    """Merge, validate params, and convert them into a simple dict."""
+    from airflow.configuration import conf
+
+    params = ParamsDict(suppress_exception=suppress_exception)
+    with contextlib.suppress(AttributeError):
+        params.update(dag.params)
+    if task.params:
+        params.update(task.params)
+    if conf.getboolean('core', 'dag_run_conf_overrides_params') and dag_run and dag_run.conf:
+        logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
+        params.update(dag_run.conf)
+    return params.validate()
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 190542063c..30fb9b5eb7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -97,7 +97,7 @@ from airflow.exceptions import (
 )
 from airflow.models.base import Base, StringID
 from airflow.models.log import Log
-from airflow.models.param import ParamsDict
+from airflow.models.param import process_params
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskmap import TaskMap
 from airflow.models.taskreschedule import TaskReschedule
@@ -1947,15 +1947,7 @@ class TaskInstance(Base, LoggingMixin):
         dag_run = self.get_dagrun(session)
         data_interval = dag.get_run_data_interval(dag_run)
 
-        # Validates Params and convert them into a simple dict.
-        params = ParamsDict(suppress_exception=ignore_param_exceptions)
-        with contextlib.suppress(AttributeError):
-            params.update(dag.params)
-        if task.params:
-            params.update(task.params)
-        if conf.getboolean('core', 'dag_run_conf_overrides_params'):
-            self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run)
-        validated_params = params.validate()
+        validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
 
         logical_date = timezone.coerce_datetime(self.execution_date)
         ds = logical_date.strftime('%Y-%m-%d')
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 8f5fcf8b26..ee4bc9c805 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -866,6 +866,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                 v = cls._deserialize_deps(v)
             elif k == "params":
                 v = cls._deserialize_params_dict(v)
+                if op.params:  # Merge existing params if needed.
+                    v, new = op.params, v
+                    v.update(new)
             elif k == "partial_kwargs":
                 v = {arg: cls.deserialize(value) for arg, value in v.items()}
             elif k in {"expand_input", "op_kwargs_expand_input"}:
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index f356e7495f..9475abd5a4 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -258,6 +258,21 @@ def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
     context.update(*args, **kwargs)
 
 
+def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
+    """Update context after task unmapping.
+
+    Since ``get_template_context()`` is called before unmapping, the context
+    contains information about the mapped task. We need to do some in-place
+    updates to ensure the template context reflects the unmapped task instead.
+
+    :meta private:
+    """
+    from airflow.models.param import process_params
+
+    context["task"] = context["ti"].task = task
+    context["params"] = process_params(context["dag"], task, context["dag_run"], suppress_exception=False)
+
+
 def context_copy_partial(source: Context, keys: Container[str]) -> Context:
     """Create a context by copying items under selected keys in ``source``.
 
@@ -304,15 +319,3 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
         return lazy_object_proxy.Proxy(factory)
 
     return {k: _create_value(k, v) for k, v in source._context.items()}
-
-
-def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
-    """Update context after task unmapping.
-
-    Since ``get_template_context()`` is called before unmapping, the context
-    contains information about the mapped task. We need to do some in-place
-    updates to ensure the template context reflects the unmapped task instead.
-
-    :meta private:
-    """
-    context["task"] = context["ti"].task = task
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 7d58e1b3a1..bc82fdb334 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -25,7 +25,7 @@
 # undefined attribute errors from Mypy. Hopefully there will be a mechanism to
 # declare "these are defined, but don't error if others are accessed" someday.
 
-from typing import Any, Container, Iterable, Mapping, Optional, Set, Tuple, Union, overload
+from typing import Any, Collection, Container, Iterable, Mapping, Optional, Set, Tuple, Union, overload
 
 from pendulum import DateTime
 
@@ -99,10 +99,11 @@ class Context(TypedDict, total=False):
 class AirflowContextDeprecationWarning(DeprecationWarning): ...
 
 @overload
-def context_merge(source: Context, additions: Mapping[str, Any], **kwargs: Any) -> None: ...
+def context_merge(context: Context, additions: Mapping[str, Any], **kwargs: Any) -> None: ...
 @overload
-def context_merge(source: Context, additions: Iterable[Tuple[str, Any]], **kwargs: Any) -> None: ...
+def context_merge(context: Context, additions: Iterable[Tuple[str, Any]], **kwargs: Any) -> None: ...
 @overload
-def context_merge(source: Context, **kwargs: Any) -> None: ...
+def context_merge(context: Context, **kwargs: Any) -> None: ...
+def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: ...
 def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
 def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
diff --git a/tests/conftest.py b/tests/conftest.py
index 1026dc3c01..1d73178f13 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -746,7 +746,7 @@ def create_task_instance(dag_maker, create_dummy_dag):
             from airflow.utils import timezone
 
             execution_date = timezone.utcnow()
-        create_dummy_dag(with_dagrun_type=None, **kwargs)
+        _, task = create_dummy_dag(with_dagrun_type=None, **kwargs)
 
         dagrun_kwargs = {"execution_date": execution_date, "state": dagrun_state}
         if run_id is not None:
@@ -757,6 +757,7 @@ def create_task_instance(dag_maker, create_dummy_dag):
             dagrun_kwargs["data_interval"] = data_interval
         dagrun = dag_maker.create_dagrun(**dagrun_kwargs)
         (ti,) = dagrun.task_instances
+        ti.task = task
         ti.state = state
 
         dag_maker.session.flush()
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 50e9e9a3d8..29d44a6527 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1945,3 +1945,30 @@ def test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session,
     (tis, _) = dr.update_state()
     assert len(tis)
     assert dr.state != DagRunState.FAILED
+
+
[email protected](
+    "partial_params, mapped_params, expected",
+    [
+        pytest.param(None, [{"a": 1}], [[("a", 1)]], id="simple"),
+        pytest.param({"b": 2}, [{"a": 1}], [[("a", 1), ("b", 2)]], id="merge"),
+        pytest.param({"b": 2}, [{"a": 1, "b": 3}], [[("a", 1), ("b", 3)]], id="override"),
+    ],
+)
+def test_mapped_expand_against_params(dag_maker, partial_params, mapped_params, expected):
+    results = []
+
+    class PullOperator(BaseOperator):
+        def execute(self, context):
+            results.append(sorted(context["params"].items()))
+
+    with dag_maker():
+        PullOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params)
+
+    dr: DagRun = dag_maker.create_dagrun()
+    decision = dr.task_instance_scheduling_decisions()
+
+    for ti in decision.schedulable_tis:
+        ti.run()
+
+    assert sorted(results) == expected
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 1faf42be3d..865e146985 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -25,6 +25,7 @@ import pytest
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
+from airflow.models.param import ParamsDict
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCOM_RETURN_KEY
@@ -273,6 +274,25 @@ def test_mapped_task_applies_default_args_taskflow(dag_maker):
     assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30)
 
 
[email protected](
+    "dag_params, task_params, expected_partial_params",
+    [
+        pytest.param(None, None, ParamsDict(), id="none"),
+        pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"),
+        pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"),
+        pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"),
+    ],
+)
+def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expected_partial_params):
+    with dag_maker(params=dag_params) as dag:
+        MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}])
+
+    t = dag.get_task("t")
+    assert isinstance(t, MappedOperator)
+    assert t.params == expected_partial_params
+    assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]}
+
+
 def test_mapped_render_template_fields_validating_operator(dag_maker, session):
     class MyOperator(MockOperator):
         def __init__(self, value, arg1, **kwargs):
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index b1e36dcbaa..2e02d77800 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -61,6 +61,7 @@ from airflow.models import (
 )
 from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
 from airflow.models.expandinput import EXPAND_INPUT_EMPTY
+from airflow.models.param import process_params
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskinstance import TaskInstance
@@ -1628,27 +1629,24 @@ class TestTaskInstance:
         ti = create_task_instance()
         dag_run = ti.dag_run
         dag_run.conf = {"override": True}
-        params = {"override": False}
-
-        ti.overwrite_params_with_dag_run_conf(params, dag_run)
+        ti.task.params = {"override": False}
 
+        params = process_params(ti.task.dag, ti.task, dag_run, suppress_exception=False)
         assert params["override"] is True
 
     def test_overwrite_params_with_dag_run_none(self, create_task_instance):
         ti = create_task_instance()
-        params = {"override": False}
-
-        ti.overwrite_params_with_dag_run_conf(params, None)
+        ti.task.params = {"override": False}
 
+        params = process_params(ti.task.dag, ti.task, None, suppress_exception=False)
         assert params["override"] is False
 
     def test_overwrite_params_with_dag_run_conf_none(self, create_task_instance):
         ti = create_task_instance()
-        params = {"override": False}
         dag_run = ti.dag_run
+        ti.task.params = {"override": False}
 
-        ti.overwrite_params_with_dag_run_conf(params, dag_run)
-
+        params = process_params(ti.task.dag, ti.task, dag_run, suppress_exception=False)
         assert params["override"] is False
 
     @pytest.mark.parametrize("use_native_obj", [True, False])

Reply via email to