This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-6-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3b0d5a91a7f279ee40cc19d25c739927dde0e8c1
Author: Luiz Armesto <[email protected]>
AuthorDate: Sat Apr 29 18:27:22 2023 -0300

    Use the Task Group explicitly passed to 'partial' if any (#30933)
    
    (cherry picked from commit 4ee2de1e38a85abb89f9f313a3424c7368e12d1a)
---
 airflow/models/baseoperator.py      |  2 +-
 tests/models/test_mappedoperator.py | 43 +++++++++++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 37106c580f..cec321b2b3 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -258,7 +258,7 @@ def partial(
 
     dag = dag or DagContext.get_current_dag()
     if dag:
-        task_group = TaskGroupContext.get_current_task_group(dag)
+        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
     if task_group:
         task_id = task_group.child_id(task_id)
 
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 84ddd9fb66..cfd77b55ef 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -30,6 +30,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom_arg import XComArg
 from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.xcom import XCOM_RETURN_KEY
 from tests.models import DEFAULT_DATE
@@ -573,3 +574,45 @@ def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session):
     tis = dr.get_task_instances(session=session)
     for ti in tis:
         ti.run()
+
+
+def test_task_mapping_with_task_group_context():
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        task1 = BaseOperator(task_id="op1")
+        finish = MockOperator(task_id="finish")
+
+        with TaskGroup("test-group") as group:
+            literal = ["a", "b", "c"]
+            mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal)
+
+            task1 >> group >> finish
+
+    assert task1.downstream_list == [mapped]
+    assert mapped.upstream_list == [task1]
+
+    assert mapped in dag.tasks
+    assert mapped.task_group == group
+
+    assert finish.upstream_list == [mapped]
+    assert mapped.downstream_list == [finish]
+
+
+def test_task_mapping_with_explicit_task_group():
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        task1 = BaseOperator(task_id="op1")
+        finish = MockOperator(task_id="finish")
+
+        group = TaskGroup("test-group")
+        literal = ["a", "b", "c"]
+        mapped = MockOperator.partial(task_id="task_2", task_group=group).expand(arg2=literal)
+
+        task1 >> group >> finish
+
+    assert task1.downstream_list == [mapped]
+    assert mapped.upstream_list == [task1]
+
+    assert mapped in dag.tasks
+    assert mapped.task_group == group
+
+    assert finish.upstream_list == [mapped]
+    assert mapped.downstream_list == [finish]

Reply via email to