This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-7-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 22df7b111261c78fbeeb38191226f9694986bd05 Author: Ephraim Anierobi <[email protected]> AuthorDate: Tue Aug 29 17:48:43 2023 +0100 Fix MappedTaskGroup tasks not respecting upstream dependency (#33732) * Fix MappedTaskGroup tasks not respecting upstream dependency When a MappedTaskGroup has upstream dependencies, the tasks in the group don't wait for the upstream tasks before they start running; this causes the tasks to fail. From my investigation, the tasks inside the MappedTaskGroup don't have upstream tasks while the MappedTaskGroup has the upstream tasks properly set. Due to this, the tasks' dependencies are met even though the group has upstreams that haven't finished. The fix was to set upstreams after creating the task group with the factory. Closes: https://github.com/apache/airflow/issues/33446 * set the relationship in __exit__ (cherry picked from commit fe27031382e2034b59a23db1c6b9bdbfef259137) --- airflow/utils/task_group.py | 7 ++++-- tests/decorators/test_task_group.py | 46 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 167eb53b71..1c0d1370d7 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -565,8 +565,6 @@ class MappedTaskGroup(TaskGroup): def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None: super().__init__(**kwargs) self._expand_input = expand_input - for op, _ in expand_input.iter_references(): - self.set_upstream(op) def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this mapped task group.""" @@ -619,6 +617,11 @@ class MappedTaskGroup(TaskGroup): (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), ) + def __exit__(self, exc_type, exc_val, exc_tb): + for op, _ in self._expand_input.iter_references(): + self.set_upstream(op) + super().__exit__(exc_type, exc_val, exc_tb) + class 
TaskGroupContext: """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 3462c3a1d8..4c741ef1c1 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -191,6 +191,52 @@ def test_expand_kwargs_create_mapped(): assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")} +def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog): + with dag_maker() as dag: + + @dag.task + def t1(): + return [{"a": 1}, {"a": 2}] + + @task_group("tg1") + def tg1(a, b): + @dag.task() + def t2(): + return [a, b] + + t2() + + tg1.expand_kwargs(t1()) + + dr = dag_maker.create_dagrun() + dr.task_instance_scheduling_decisions() + assert "Cannot expand" not in caplog.text + assert "missing upstream values: ['expand_kwargs() argument']" not in caplog.text + + +def test_task_group_expand_with_upstream(dag_maker, session, caplog): + with dag_maker() as dag: + + @dag.task + def t1(): + return [1, 2, 3] + + @task_group("tg1") + def tg1(a, b): + @dag.task() + def t2(): + return [a, b] + + t2() + + tg1.partial(a=1).expand(b=t1()) + + dr = dag_maker.create_dagrun() + dr.task_instance_scheduling_decisions() + assert "Cannot expand" not in caplog.text + assert "missing upstream values: ['b']" not in caplog.text + + def test_override_dag_default_args(): @dag( dag_id="test_dag",
