This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-6-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 35624862c620b89dac0e887191f2938d467f8f7c Author: Luiz Armesto <[email protected]> AuthorDate: Sat Apr 29 18:27:22 2023 -0300 Use the Task Group explicitly passed to 'partial' if any (#30933) (cherry picked from commit 4ee2de1e38a85abb89f9f313a3424c7368e12d1a) --- airflow/models/baseoperator.py | 2 +- tests/models/test_mappedoperator.py | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 37106c580f..cec321b2b3 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -258,7 +258,7 @@ def partial( dag = dag or DagContext.get_current_dag() if dag: - task_group = TaskGroupContext.get_current_task_group(dag) + task_group = task_group or TaskGroupContext.get_current_task_group(dag) if task_group: task_id = task_group.child_id(task_id) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 84ddd9fb66..cfd77b55ef 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -30,6 +30,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom_arg import XComArg from airflow.utils.state import TaskInstanceState +from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE @@ -573,3 +574,45 @@ def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): tis = dr.get_task_instances(session=session) for ti in tis: ti.run() + + +def test_task_mapping_with_task_group_context(): + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + with TaskGroup("test-group") as group: + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +def test_task_mapping_with_explicit_task_group(): + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + group = TaskGroup("test-group") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2", task_group=group).expand(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish]
