This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-5-test by this push:
     new 4f6512046a Adding an example dag for dynamic task mapping (#28325)
4f6512046a is described below

commit 4f6512046a0f18f1e522bb8d7bf1b8725eaa78b1
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 16 01:59:15 2022 +0530

    Adding an example dag for dynamic task mapping (#28325)
    
    (cherry picked from commit b263dbcb0f84fd9029591d1447a7c843cb970f15)
---
 .../example_dags/example_dynamic_task_mapping.py   | 38 +++++++++++
 docker_tests/test_docker_compose_quick_start.py    |  1 -
 .../concepts/dynamic-task-mapping.rst              | 23 +------
 docs/build_docs.py                                 |  5 +-
 docs/exts/docs_build/spelling_checks.py            |  2 +-
 tests/serialization/test_dag_serialization.py      | 77 +++++++++++++++-------
 6 files changed, 99 insertions(+), 47 deletions(-)

diff --git a/airflow/example_dags/example_dynamic_task_mapping.py 
b/airflow/example_dags/example_dynamic_task_mapping.py
new file mode 100644
index 0000000000..dce6cda209
--- /dev/null
+++ b/airflow/example_dags/example_dynamic_task_mapping.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAG demonstrating the usage of dynamic task mapping."""
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+
+with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3, 
4)) as dag:
+
+    @task
+    def add_one(x: int):
+        return x + 1
+
+    @task
+    def sum_it(values):
+        total = sum(values)
+        print(f"Total was {total}")
+
+    added_values = add_one.expand(x=[1, 2, 3])
+    sum_it(added_values)
diff --git a/docker_tests/test_docker_compose_quick_start.py 
b/docker_tests/test_docker_compose_quick_start.py
index 6f25f62578..fd553ed175 100644
--- a/docker_tests/test_docker_compose_quick_start.py
+++ b/docker_tests/test_docker_compose_quick_start.py
@@ -27,7 +27,6 @@ from time import monotonic, sleep
 from unittest import mock
 
 import requests
-
 from docker_tests.command_utils import run_command
 from docker_tests.constants import SOURCE_ROOT
 from docker_tests.docker_tests_utils import docker_image
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst 
b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index 5ae0e9fb82..d15c0ada77 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -30,27 +30,10 @@ Simple mapping
 
 In its simplest form you can map over a list defined directly in your DAG file 
using the ``expand()`` function instead of calling your task directly.
 
-.. code-block:: python
-
-    from datetime import datetime
-
-    from airflow import DAG
-    from airflow.decorators import task
-
-
-    with DAG(dag_id="simple_mapping", start_date=datetime(2022, 3, 4)) as dag:
-
-        @task
-        def add_one(x: int):
-            return x + 1
-
-        @task
-        def sum_it(values):
-            total = sum(values)
-            print(f"Total was {total}")
+If you want to see a simple usage of Dynamic Task Mapping, you can look below:
 
-        added_values = add_one.expand(x=[1, 2, 3])
-        sum_it(added_values)
+.. exampleinclude:: /../../airflow/example_dags/example_dynamic_task_mapping.py
+    :language: python
 
 This will show ``Total was 9`` in the task logs when executed.
 
diff --git a/docs/build_docs.py b/docs/build_docs.py
index d1fb06ccac..cd6c83249d 100755
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -25,9 +25,6 @@ from collections import defaultdict
 from itertools import filterfalse, tee
 from typing import Callable, Iterable, NamedTuple, TypeVar
 
-from rich.console import Console
-from tabulate import tabulate
-
 from docs.exts.docs_build import dev_index_generator, lint_checks
 from docs.exts.docs_build.code_utils import CONSOLE_WIDTH, PROVIDER_INIT_FILE
 from docs.exts.docs_build.docs_builder import DOCS_DIR, AirflowDocsBuilder, 
get_available_packages
@@ -36,6 +33,8 @@ from docs.exts.docs_build.fetch_inventories import 
fetch_inventories
 from docs.exts.docs_build.github_action_utils import with_group
 from docs.exts.docs_build.package_filter import process_package_filters
 from docs.exts.docs_build.spelling_checks import SpellingError, 
display_spelling_error_summary
+from rich.console import Console
+from tabulate import tabulate
 
 TEXT_RED = "\033[31m"
 TEXT_RESET = "\033[0m"
diff --git a/docs/exts/docs_build/spelling_checks.py 
b/docs/exts/docs_build/spelling_checks.py
index bbaa9fa5dd..f89bfa50dc 100644
--- a/docs/exts/docs_build/spelling_checks.py
+++ b/docs/exts/docs_build/spelling_checks.py
@@ -21,10 +21,10 @@ import re
 from functools import total_ordering
 from typing import NamedTuple
 
+from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
 from rich.console import Console
 
 from airflow.utils.code_utils import prepare_code_snippet
-from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
 
 CURRENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
 DOCS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir, os.pardir))
diff --git a/tests/serialization/test_dag_serialization.py 
b/tests/serialization/test_dag_serialization.py
index 44411f5c07..ec07d60954 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -30,6 +30,7 @@ from glob import glob
 from pathlib import Path
 from unittest import mock
 
+import attr
 import pendulum
 import pytest
 from dateutil.relativedelta import FR, relativedelta
@@ -42,6 +43,7 @@ from airflow.hooks.base import BaseHook
 from airflow.kubernetes.pod_generator import PodGenerator
 from airflow.models import DAG, Connection, DagBag, Operator
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
 from airflow.models.xcom import XCOM_RETURN_KEY, XCom
@@ -534,32 +536,47 @@ class TestStringifiedDAGs:
         serialized_task,
         task,
     ):
-        """Verify non-airflow operators are casted to BaseOperator."""
-        assert isinstance(serialized_task, SerializedBaseOperator)
+        """Verify non-Airflow operators are cast to BaseOperator or 
MappedOperator."""
         assert not isinstance(task, SerializedBaseOperator)
-        assert isinstance(task, BaseOperator)
+        assert isinstance(task, (BaseOperator, MappedOperator))
 
         # Every task should have a task_group property -- even if it's the 
DAG's root task group
         assert serialized_task.task_group
 
-        fields_to_check = task.get_serialized_fields() - {
-            # Checked separately
-            "_task_type",
-            "_operator_name",
-            "subdag",
-            # Type is excluded, so don't check it
-            "_log",
-            # List vs tuple. Check separately
-            "template_ext",
-            "template_fields",
-            # We store the string, real dag has the actual code
-            "on_failure_callback",
-            "on_success_callback",
-            "on_retry_callback",
-            # Checked separately
-            "resources",
-            "params",
-        }
+        if isinstance(task, BaseOperator):
+            assert isinstance(serialized_task, SerializedBaseOperator)
+            fields_to_check = task.get_serialized_fields() - {
+                # Checked separately
+                "_task_type",
+                "_operator_name",
+                "subdag",
+                # Type is excluded, so don't check it
+                "_log",
+                # List vs tuple. Check separately
+                "template_ext",
+                "template_fields",
+                # We store the string, real dag has the actual code
+                "on_failure_callback",
+                "on_success_callback",
+                "on_retry_callback",
+                # Checked separately
+                "resources",
+            }
+        else:  # Promised to be mapped by the assert above.
+            assert isinstance(serialized_task, MappedOperator)
+            fields_to_check = {f.name for f in attr.fields(MappedOperator)}
+            fields_to_check -= {
+                # Matching logic in BaseOperator.get_serialized_fields().
+                "dag",
+                "task_group",
+                # List vs tuple. Check separately.
+                "operator_extra_links",
+                "template_ext",
+                "template_fields",
+                # Checked separately.
+                "operator_class",
+                "partial_kwargs",
+            }
 
         assert serialized_task.task_type == task.task_type
 
@@ -580,9 +597,25 @@ class TestStringifiedDAGs:
             assert serialized_task.resources == task.resources
 
         # Ugly hack as some operators override params var in their init
-        if isinstance(task.params, ParamsDict):
+        if isinstance(task.params, ParamsDict) and 
isinstance(serialized_task.params, ParamsDict):
             assert serialized_task.params.dump() == task.params.dump()
 
+        if isinstance(task, MappedOperator):
+            # MappedOperator.operator_class holds a backup of the serialized
+            # data; checking its entirety basically duplicates this validation
+            # function, so we just do some sanity checks.
+            serialized_task.operator_class["_task_type"] == type(task).__name__
+            serialized_task.operator_class["_operator_name"] == 
task._operator_name
+
+            # Serialization cleans up default values in partial_kwargs, this
+            # adds them back to both sides.
+            default_partial_kwargs = (
+                BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, 
strict=False).partial_kwargs
+            )
+            serialized_partial_kwargs = {**default_partial_kwargs, 
**serialized_task.partial_kwargs}
+            original_partial_kwargs = {**default_partial_kwargs, 
**task.partial_kwargs}
+            assert serialized_partial_kwargs == original_partial_kwargs
+
         # Check that for Deserialized task, task.subdag is None for all other 
Operators
         # except for the SubDagOperator where task.subdag is an instance of 
DAG object
         if task.task_type == "SubDagOperator":

Reply via email to