This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-5-test by this push:
new 4f6512046a Adding an example dag for dynamic task mapping (#28325)
4f6512046a is described below
commit 4f6512046a0f18f1e522bb8d7bf1b8725eaa78b1
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 16 01:59:15 2022 +0530
Adding an example dag for dynamic task mapping (#28325)
(cherry picked from commit b263dbcb0f84fd9029591d1447a7c843cb970f15)
---
.../example_dags/example_dynamic_task_mapping.py | 38 +++++++++++
docker_tests/test_docker_compose_quick_start.py | 1 -
.../concepts/dynamic-task-mapping.rst | 23 +------
docs/build_docs.py | 5 +-
docs/exts/docs_build/spelling_checks.py | 2 +-
tests/serialization/test_dag_serialization.py | 77 +++++++++++++++-------
6 files changed, 99 insertions(+), 47 deletions(-)
diff --git a/airflow/example_dags/example_dynamic_task_mapping.py b/airflow/example_dags/example_dynamic_task_mapping.py
new file mode 100644
index 0000000000..dce6cda209
--- /dev/null
+++ b/airflow/example_dags/example_dynamic_task_mapping.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAG demonstrating the usage of dynamic task mapping."""
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+
+with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3,
4)) as dag:
+
+ @task
+ def add_one(x: int):
+ return x + 1
+
+ @task
+ def sum_it(values):
+ total = sum(values)
+ print(f"Total was {total}")
+
+ added_values = add_one.expand(x=[1, 2, 3])
+ sum_it(added_values)
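For readers skimming the diff: each element of the list passed to ``expand()`` becomes its own ``add_one`` task instance at run time, and ``sum_it`` receives all of their return values. A hedged sketch of the same API combining fixed and mapped arguments via ``partial()`` follows; the DAG id and task names are illustrative, not part of this commit:

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.decorators import task

    # Illustrative only, not part of this commit: partial() pins constant
    # arguments while expand() maps over the rest.
    with DAG(dag_id="example_partial_mapping", start_date=datetime(2022, 3, 4)) as dag:

        @task
        def add(x: int, y: int):
            return x + y

        # Three mapped task instances; y is fixed at 10 for each of them.
        add.partial(y=10).expand(x=[1, 2, 3])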
diff --git a/docker_tests/test_docker_compose_quick_start.py b/docker_tests/test_docker_compose_quick_start.py
index 6f25f62578..fd553ed175 100644
--- a/docker_tests/test_docker_compose_quick_start.py
+++ b/docker_tests/test_docker_compose_quick_start.py
@@ -27,7 +27,6 @@ from time import monotonic, sleep
from unittest import mock
import requests
-
from docker_tests.command_utils import run_command
from docker_tests.constants import SOURCE_ROOT
from docker_tests.docker_tests_utils import docker_image
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index 5ae0e9fb82..d15c0ada77 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -30,27 +30,10 @@ Simple mapping
In its simplest form you can map over a list defined directly in your DAG file
using the ``expand()`` function instead of calling your task directly.
-.. code-block:: python
-
-    from datetime import datetime
-
-    from airflow import DAG
-    from airflow.decorators import task
-
-
-    with DAG(dag_id="simple_mapping", start_date=datetime(2022, 3, 4)) as dag:
-
-        @task
-        def add_one(x: int):
-            return x + 1
-
-        @task
-        def sum_it(values):
-            total = sum(values)
-            print(f"Total was {total}")
+A simple example of dynamic task mapping is shown below:
-        added_values = add_one.expand(x=[1, 2, 3])
-        sum_it(added_values)
+.. exampleinclude:: /../../airflow/example_dags/example_dynamic_task_mapping.py
+    :language: python
This will show ``Total was 9`` in the task logs when executed.
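Worth noting alongside this doc change: ``expand()`` also accepts multiple keyword arguments, and Airflow creates one mapped task instance per element of the cross product. A hedged sketch, assuming the same TaskFlow imports as the example DAG above (the task name is illustrative, not part of this diff):

    # Illustrative sketch: two lists expand to 2 x 3 = 6 mapped instances,
    # one per (x, y) combination.
    @task
    def multiply(x: int, y: int):
        return x * y

    multiply.expand(x=[1, 2], y=[10, 20, 30])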
diff --git a/docs/build_docs.py b/docs/build_docs.py
index d1fb06ccac..cd6c83249d 100755
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -25,9 +25,6 @@ from collections import defaultdict
from itertools import filterfalse, tee
from typing import Callable, Iterable, NamedTuple, TypeVar
-from rich.console import Console
-from tabulate import tabulate
-
from docs.exts.docs_build import dev_index_generator, lint_checks
from docs.exts.docs_build.code_utils import CONSOLE_WIDTH, PROVIDER_INIT_FILE
from docs.exts.docs_build.docs_builder import DOCS_DIR, AirflowDocsBuilder, get_available_packages
@@ -36,6 +33,8 @@ from docs.exts.docs_build.fetch_inventories import fetch_inventories
from docs.exts.docs_build.github_action_utils import with_group
from docs.exts.docs_build.package_filter import process_package_filters
from docs.exts.docs_build.spelling_checks import SpellingError, display_spelling_error_summary
+from rich.console import Console
+from tabulate import tabulate
TEXT_RED = "\033[31m"
TEXT_RESET = "\033[0m"
diff --git a/docs/exts/docs_build/spelling_checks.py b/docs/exts/docs_build/spelling_checks.py
index bbaa9fa5dd..f89bfa50dc 100644
--- a/docs/exts/docs_build/spelling_checks.py
+++ b/docs/exts/docs_build/spelling_checks.py
@@ -21,10 +21,10 @@ import re
from functools import total_ordering
from typing import NamedTuple
+from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
from rich.console import Console
from airflow.utils.code_utils import prepare_code_snippet
-from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
CURRENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
DOCS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir, os.pardir))
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 44411f5c07..ec07d60954 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -30,6 +30,7 @@ from glob import glob
from pathlib import Path
from unittest import mock
+import attr
import pendulum
import pytest
from dateutil.relativedelta import FR, relativedelta
@@ -42,6 +43,7 @@ from airflow.hooks.base import BaseHook
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.models import DAG, Connection, DagBag, Operator
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
@@ -534,32 +536,47 @@ class TestStringifiedDAGs:
        serialized_task,
        task,
    ):
-        """Verify non-airflow operators are casted to BaseOperator."""
-        assert isinstance(serialized_task, SerializedBaseOperator)
+        """Verify non-Airflow operators are cast to BaseOperator or MappedOperator."""
         assert not isinstance(task, SerializedBaseOperator)
-        assert isinstance(task, BaseOperator)
+        assert isinstance(task, (BaseOperator, MappedOperator))

         # Every task should have a task_group property -- even if it's the DAG's root task group
         assert serialized_task.task_group
-        fields_to_check = task.get_serialized_fields() - {
-            # Checked separately
-            "_task_type",
-            "_operator_name",
-            "subdag",
-            # Type is excluded, so don't check it
-            "_log",
-            # List vs tuple. Check separately
-            "template_ext",
-            "template_fields",
-            # We store the string, real dag has the actual code
-            "on_failure_callback",
-            "on_success_callback",
-            "on_retry_callback",
-            # Checked separately
-            "resources",
-            "params",
-        }
+        if isinstance(task, BaseOperator):
+            assert isinstance(serialized_task, SerializedBaseOperator)
+            fields_to_check = task.get_serialized_fields() - {
+                # Checked separately
+                "_task_type",
+                "_operator_name",
+                "subdag",
+                # Type is excluded, so don't check it
+                "_log",
+                # List vs tuple. Check separately
+                "template_ext",
+                "template_fields",
+                # We store the string, real dag has the actual code
+                "on_failure_callback",
+                "on_success_callback",
+                "on_retry_callback",
+                # Checked separately
+                "resources",
+            }
+        else:  # Promised to be mapped by the assert above.
+            assert isinstance(serialized_task, MappedOperator)
+            fields_to_check = {f.name for f in attr.fields(MappedOperator)}
+            fields_to_check -= {
+                # Matching logic in BaseOperator.get_serialized_fields().
+                "dag",
+                "task_group",
+                # List vs tuple. Check separately.
+                "operator_extra_links",
+                "template_ext",
+                "template_fields",
+                # Checked separately.
+                "operator_class",
+                "partial_kwargs",
+            }

        assert serialized_task.task_type == task.task_type
@@ -580,9 +597,25 @@ class TestStringifiedDAGs:
        assert serialized_task.resources == task.resources

        # Ugly hack as some operators override params var in their init
-        if isinstance(task.params, ParamsDict):
+        if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
             assert serialized_task.params.dump() == task.params.dump()
+        if isinstance(task, MappedOperator):
+            # MappedOperator.operator_class holds a backup of the serialized
+            # data; checking its entirety basically duplicates this validation
+            # function, so we just do some sanity checks.
+            assert serialized_task.operator_class["_task_type"] == type(task).__name__
+            assert serialized_task.operator_class["_operator_name"] == task._operator_name
+
+            # Serialization cleans up default values in partial_kwargs; this
+            # adds them back to both sides.
+            default_partial_kwargs = (
+                BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
+            )
+            serialized_partial_kwargs = {**default_partial_kwargs, **serialized_task.partial_kwargs}
+            original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs}
+            assert serialized_partial_kwargs == original_partial_kwargs
+
        # Check that for Deserialized task, task.subdag is None for all other Operators
        # except for the SubDagOperator where task.subdag is an instance of DAG object
        if task.task_type == "SubDagOperator":
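A closing note on the ``partial_kwargs`` comparison in the test above: merging the same defaults dict under both sides re-pads whatever the serializer stripped, so the assert only fires on real differences. A minimal standalone sketch of that normalization pattern, in plain Python with illustrative keys (not Airflow's actual defaults):

    # Illustrative normalization sketch: serialization drops values equal to
    # the defaults, so both sides are merged over the same defaults before
    # comparing.
    defaults = {"retries": 0, "queue": "default"}
    original = {"retries": 0, "queue": "high_priority"}  # full kwargs
    serialized = {"queue": "high_priority"}              # defaults stripped

    assert {**defaults, **serialized} == {**defaults, **original}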