This is an automated email from the ASF dual-hosted git repository. utkarsharma pushed a commit to branch sync_2-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 6b377104d0c2026e2a545c8ad339d3e6fd42e901 Author: Shahar Epstein <[email protected]> AuthorDate: Fri Nov 8 09:12:16 2024 +0200 Prevent using trigger_rule="always" in a dynamic mapped task (#43810) (cherry picked from commit c753ca295d72d4e3dd74b9131d3ca4c47899cd96) --- airflow/utils/task_group.py | 22 +++++++++++++++---- .../dynamic-task-mapping.rst | 5 +++++ tests/decorators/test_task_group.py | 25 +++++++++++++++++++++- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index d1dd9822be2..f5e95bde1a8 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -37,6 +37,7 @@ from airflow.exceptions import ( from airflow.models.taskmixin import DAGNode from airflow.serialization.enums import DagAttributeTypes from airflow.utils.helpers import validate_group_key, validate_instance_args +from airflow.utils.trigger_rule import TriggerRule if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -220,10 +221,15 @@ class TaskGroup(DAGNode): def __iter__(self): for child in self.children.values(): - if isinstance(child, TaskGroup): - yield from child - else: - yield child + yield from self._iter_child(child) + + @staticmethod + def _iter_child(child): + """Iterate over the children of this TaskGroup.""" + if isinstance(child, TaskGroup): + yield from child + else: + yield child def add(self, task: DAGNode) -> DAGNode: """ @@ -599,6 +605,14 @@ class MappedTaskGroup(TaskGroup): super().__init__(**kwargs) self._expand_input = expand_input + def __iter__(self): + from airflow.models.abstractoperator import AbstractOperator + + for child in self.children.values(): + if isinstance(child, AbstractOperator) and child.trigger_rule == TriggerRule.ALWAYS: + raise ValueError("Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'") + yield from self._iter_child(child) + def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used 
by this mapped task group.""" from airflow.models.xcom_arg import XComArg diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst index fd7d5707854..df74038fd2c 100644 --- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst +++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst @@ -84,6 +84,11 @@ The grid view also provides visibility into your mapped tasks in the details pan Although we show a "reduce" task here (``sum_it``) you don't have to have one, the mapped tasks will still be executed even if they have no downstream tasks. +.. warning:: ``TriggerRule.ALWAYS`` cannot be utilized in expanded tasks + + Assigning ``trigger_rule=TriggerRule.ALWAYS`` in expanded tasks is forbidden, as expanded parameters will be undefined with the task's immediate execution. + This is enforced at the time of the DAG parsing, and will raise an error if you try to use it. 
+ Task-generated Mapping ---------------------- diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 6120f94af3a..2dab23ca38f 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -22,10 +22,11 @@ from datetime import timedelta import pendulum import pytest -from airflow.decorators import dag, task_group +from airflow.decorators import dag, task, task_group from airflow.models.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput, MappedArgument from airflow.operators.empty import EmptyOperator from airflow.utils.task_group import MappedTaskGroup +from airflow.utils.trigger_rule import TriggerRule def test_task_group_with_overridden_kwargs(): @@ -133,6 +134,28 @@ def test_expand_fail_empty(): assert str(ctx.value) == "no arguments to expand against" +@pytest.mark.db_test +def test_expand_fail_trigger_rule_always(dag_maker, session): + @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1)) + def pipeline(): + @task + def get_param(): + return ["a", "b", "c"] + + @task(trigger_rule=TriggerRule.ALWAYS) + def t1(param): + return param + + @task_group() + def tg(param): + t1(param) + + with pytest.raises( + ValueError, match="Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'" + ): + tg.expand(param=get_param()) + + def test_expand_create_mapped(): saved = {}
