This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-3-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4081a65b77c5105d2b3d892d5f25adc3bdb3b450 Author: Tzu-ping Chung <[email protected]> AuthorDate: Wed Jun 22 15:48:50 2022 +0800 Remove special serde logic for mapped op_kwargs (#23860) Co-authored-by: Daniel Standish <[email protected]> (cherry picked from commit 5877f45d65d5aa864941efebd2040661b6f89cb1) --- airflow/decorators/base.py | 10 +-------- airflow/models/mappedoperator.py | 1 + airflow/serialization/serialized_objects.py | 29 +++++---------------------- tests/serialization/test_dag_serialization.py | 29 +++++++++++++++------------ 4 files changed, 23 insertions(+), 46 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 1b14cd0668..92cf0691e4 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -39,7 +39,7 @@ from typing import ( import attr import typing_extensions -from airflow.compat.functools import cache, cached_property +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY from airflow.models.baseoperator import ( @@ -420,14 +420,6 @@ class DecoratedMappedOperator(MappedOperator): def __hash__(self): return id(self) - @classmethod - @cache - def get_serialized_fields(cls): - # The magic super() doesn't work here, so we use the explicit form. - # Not using super(..., cls) to work around pyupgrade bug. - sup = super(DecoratedMappedOperator, DecoratedMappedOperator) - return sup.get_serialized_fields() | {"mapped_op_kwargs"} - def __attrs_post_init__(self): # The magic super() doesn't work here, so we use the explicit form. # Not using super(..., self) to work around pyupgrade bug. diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 663ceeece1..6b202d2cc6 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -323,6 +323,7 @@ class MappedOperator(AbstractOperator): "dag", "deps", "is_mapped", + "mapped_kwargs", # This is needed to be able to accept XComArg. "subdag", "task_group", "upstream_task_ids", diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3e674b2f8d..dd3dc4404e 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -16,7 +16,7 @@ # under the License. """Serialized DAG and BaseOperator""" -import contextlib + import datetime import enum import logging @@ -592,6 +592,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator)) + # Handle mapped_kwargs and mapped_op_kwargs. + serialized_op[op._expansion_kwargs_attr] = cls._serialize(op._get_expansion_kwargs()) + # Simplify partial_kwargs by comparing it to the most barebone object. # Remove all entries that are simply default values. serialized_partial = serialized_op["partial_kwargs"] @@ -603,20 +606,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): if v == default: del serialized_partial[k] - # Simplify op_kwargs format. It must be a dict, so we flatten it. - with contextlib.suppress(KeyError): - op_kwargs = serialized_op["mapped_kwargs"]["op_kwargs"] - assert op_kwargs[Encoding.TYPE] == DAT.DICT - serialized_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] - with contextlib.suppress(KeyError): - op_kwargs = serialized_op["partial_kwargs"]["op_kwargs"] - assert op_kwargs[Encoding.TYPE] == DAT.DICT - serialized_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] - with contextlib.suppress(KeyError): - op_kwargs = serialized_op["mapped_op_kwargs"] - assert op_kwargs[Encoding.TYPE] == DAT.DICT - serialized_op["mapped_op_kwargs"] = op_kwargs[Encoding.VAR] - serialized_op["_is_mapped"] = True return serialized_op @@ -752,15 +741,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): v = cls._deserialize_deps(v) elif k == "params": v = cls._deserialize_params_dict(v) - elif k in ("mapped_kwargs", "partial_kwargs"): - if "op_kwargs" not in v: - op_kwargs: Optional[dict] = None - else: - op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()} - v = {arg: cls._deserialize(value) for arg, value in v.items()} - if op_kwargs is not None: - v["op_kwargs"] = op_kwargs - elif k == "mapped_op_kwargs": + elif k == "partial_kwargs": v = {arg: cls._deserialize(value) for arg, value in v.items()} elif k in cls._decorated_fields or k not in op.get_serialized_fields(): v = cls._deserialize(v) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index fe9fc7c7e5..7d6a43e933 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1688,18 +1688,13 @@ def test_mapped_operator_serde(): '_task_type': 'BashOperator', 'downstream_task_ids': [], 'mapped_kwargs': { - 'bash_command': [ - 1, - 2, - {"__type": "dict", "__var": {'a': 'b'}}, - ] + "__type": "dict", + "__var": {'bash_command': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]}, }, 'partial_kwargs': { 'executor_config': { '__type': 'dict', - '__var': { - 'dict': {"__type": "dict", "__var": {'sub': 'value'}}, - }, + '__var': {'dict': {"__type": "dict", "__var": {'sub': 'value'}}}, }, }, 'task_id': 'a', @@ -1744,7 +1739,10 @@ def test_mapped_operator_xcomarg_serde(): '_task_module': 'tests.test_utils.mock_operators', '_task_type': 'MockOperator', 'downstream_task_ids': [], - 'mapped_kwargs': {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}}, + 'mapped_kwargs': { + "__type": "dict", + "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}}, + }, 'partial_kwargs': {}, 'task_id': 'task_2', 'template_fields': ['arg1', 'arg2'], @@ -1825,13 +1823,18 @@ def test_mapped_decorator_serde(): 'downstream_task_ids': [], 'partial_kwargs': { 'op_args': [], - 'op_kwargs': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]}, + 'op_kwargs': { + '__type': 'dict', + '__var': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]}, + }, 'retry_delay': {'__type': 'timedelta', '__var': 30.0}, }, - 'mapped_kwargs': {}, 'mapped_op_kwargs': { - 'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}}, - 'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}, + "__type": "dict", + "__var": { + 'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}}, + 'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}, + }, }, 'operator_extra_links': [], 'ui_color': '#ffefeb',
