This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bb06c56403 Allow Task Group Ids to be passed as branches in BranchMixIn (#38883)
bb06c56403 is described below

commit bb06c56403896969ecf25b3eecdb2a00d5437dce
Author: Bora Berke Sahin <[email protected]>
AuthorDate: Mon Jun 3 10:32:27 2024 +0300

    Allow Task Group Ids to be passed as branches in BranchMixIn (#38883)
    
    * Allow `Task Group Id`s to be passed as branches in BranchMixIn
---
 airflow/models/skipmixin.py             |  5 ++--
 airflow/operators/branch.py             | 39 ++++++++++++++++++++++++++----
 airflow/operators/datetime.py           |  8 +++----
 airflow/operators/python.py             | 26 ++++++++++----------
 airflow/operators/weekday.py            |  6 +++--
 tests/operators/test_branch_operator.py | 42 +++++++++++++++++++++++++++++++++
 tests/operators/test_python.py          | 16 +++++++++----
 7 files changed, 113 insertions(+), 29 deletions(-)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index fc3097ce43..3c89deda12 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -175,13 +175,12 @@ class SkipMixin(LoggingMixin):
         branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
         newly added tasks should be skipped when they are cleared.
         """
-        self.log.info("Following branch %s", branch_task_ids)
         if isinstance(branch_task_ids, str):
             branch_task_id_set = {branch_task_ids}
         elif isinstance(branch_task_ids, Iterable):
             branch_task_id_set = set(branch_task_ids)
             invalid_task_ids_type = {
-                (bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str)
+                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
             }
             if invalid_task_ids_type:
                 raise AirflowException(
@@ -196,6 +195,8 @@ class SkipMixin(LoggingMixin):
                 f"but got {type(branch_task_ids).__name__!r}."
             )
 
+        self.log.info("Following branch %s", branch_task_id_set)
+
         dag_run = ti.get_dagrun()
         if TYPE_CHECKING:
             assert isinstance(dag_run, DagRun)
diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index 74586590db..0085bfa5af 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -25,6 +25,8 @@ from airflow.models.baseoperator import BaseOperator
 from airflow.models.skipmixin import SkipMixin
 
 if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.context import Context
 
 
@@ -34,9 +36,36 @@ class BranchMixIn(SkipMixin):
     def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
         """Implement the handling of branching including logging."""
         self.log.info("Branch into %s", branches_to_execute)
-        self.skip_all_except(context["ti"], branches_to_execute)
+        branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
+        self.skip_all_except(context["ti"], branch_task_ids)
         return branches_to_execute
 
+    def _expand_task_group_roots(
+        self, ti: TaskInstance | TaskInstancePydantic, branches_to_execute: str | Iterable[str]
+    ) -> Iterable[str]:
+        """Expand any task group into its root task ids."""
+        if TYPE_CHECKING:
+            assert ti.task
+
+        task = ti.task
+        dag = task.dag
+        if TYPE_CHECKING:
+            assert dag
+
+        if branches_to_execute is None:
+            return
+        elif isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
+            branches_to_execute = [branches_to_execute]
+
+        for branch in branches_to_execute:
+            if branch in dag.task_group_dict:
+                tg = dag.task_group_dict[branch]
+                root_ids = [root.task_id for root in tg.roots]
+                self.log.info("Expanding task group %s into %s", tg.group_id, root_ids)
+                yield from root_ids
+            else:
+                yield branch
+
 
 class BaseBranchOperator(BaseOperator, BranchMixIn):
     """
@@ -44,10 +73,12 @@ class BaseBranchOperator(BaseOperator, BranchMixIn):
 
     Users should create a subclass from this operator and implement the function
     `choose_branch(self, context)`. This should run whatever business logic
-    is needed to determine the branch, and return either the task_id for
-    a single task (as a str) or a list of task_ids.
+    is needed to determine the branch, and return one of the following:
+    - A single task_id (as a str)
+    - A single task_group_id (as a str)
+    - A list containing a combination of task_ids and task_group_ids
 
-    The operator will continue with the returned task_id(s), and all other
+    The operator will continue with the returned task_id(s) and/or task_group_id(s), and all other
     tasks directly downstream of this operator will be skipped.
     """
 
diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py
index 4e2638341a..b4bf061c37 100644
--- a/airflow/operators/datetime.py
+++ b/airflow/operators/datetime.py
@@ -37,10 +37,10 @@ class BranchDateTimeOperator(BaseBranchOperator):
     True branch will be returned when ``datetime.datetime.now()`` falls below
     ``target_upper`` and above ``target_lower``.
 
-    :param follow_task_ids_if_true: task id or task ids to follow if
-        ``datetime.datetime.now()`` falls above target_lower and below ``target_upper``.
-    :param follow_task_ids_if_false: task id or task ids to follow if
-        ``datetime.datetime.now()`` falls below target_lower or above ``target_upper``.
+    :param follow_task_ids_if_true: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+        to follow if ``datetime.datetime.now()`` falls above target_lower and below target_upper.
+    :param follow_task_ids_if_false: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+        to follow if ``datetime.datetime.now()`` falls below target_lower or above target_upper.
     :param target_lower: target lower bound.
     :param target_upper: target upper bound.
     :param use_task_logical_date: If ``True``, uses task's logical date to compare with targets.
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 554a27a444..fe9923529a 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -261,12 +261,13 @@ class BranchPythonOperator(PythonOperator, BranchMixIn):
     A workflow can "branch" or follow a path after the execution of this task.
 
     It derives the PythonOperator and expects a Python function that returns
-    a single task_id or list of task_ids to follow. The task_id(s) returned
-    should point to a task directly downstream from {self}. All other "branches"
-    or directly downstream tasks are marked with a state of ``skipped`` so that
-    these paths can't move forward. The ``skipped`` states are propagated
-    downstream to allow for the DAG state to fill up and the DAG run's state
-    to be inferred.
+    a single task_id, a single task_group_id, or a list of task_ids and/or
+    task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
+    should point to a task or task group directly downstream from {self}. All
+    other "branches" or directly downstream tasks are marked with a state of
+    ``skipped`` so that these paths can't move forward. The ``skipped`` states
+    are propagated downstream to allow for the DAG state to fill up and
+    the DAG run's state to be inferred.
     """
 
     def execute(self, context: Context) -> Any:
@@ -861,12 +862,13 @@ class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
     A workflow can "branch" or follow a path after the execution of this task in a virtual environment.
 
     It derives the PythonVirtualenvOperator and expects a Python function that returns
-    a single task_id or list of task_ids to follow. The task_id(s) returned
-    should point to a task directly downstream from {self}. All other "branches"
-    or directly downstream tasks are marked with a state of ``skipped`` so that
-    these paths can't move forward. The ``skipped`` states are propagated
-    downstream to allow for the DAG state to fill up and the DAG run's state
-    to be inferred.
+    a single task_id, a single task_group_id, or a list of task_ids and/or
+    task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
+    should point to a task or task group directly downstream from {self}. All
+    other "branches" or directly downstream tasks are marked with a state of
+    ``skipped`` so that these paths can't move forward. The ``skipped`` states
+    are propagated downstream to allow for the DAG state to fill up and
+    the DAG run's state to be inferred.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
diff --git a/airflow/operators/weekday.py b/airflow/operators/weekday.py
index 1a4a5f1f35..cd941d93af 100644
--- a/airflow/operators/weekday.py
+++ b/airflow/operators/weekday.py
@@ -73,8 +73,10 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
         # add downstream dependencies as you would do with any branch operator
         weekend_check >> [workday, weekend]
 
-    :param follow_task_ids_if_true: task id or task ids to follow if criteria met
-    :param follow_task_ids_if_false: task id or task ids to follow if criteria does not met
+    :param follow_task_ids_if_true: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+        to follow if criteria met.
+    :param follow_task_ids_if_false: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+        to follow if criteria not met.
     :param week_day: Day of the week to check (full name). Optionally, a set
         of days can also be provided using a set. Example values:
 
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index 0bb5e318a7..4c127db8dd 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -29,6 +29,7 @@ from airflow.operators.empty import EmptyOperator
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import State
+from airflow.utils.task_group import TaskGroup
 from airflow.utils.types import DagRunType
 
 pytestmark = pytest.mark.db_test
@@ -47,6 +48,11 @@ class ChooseBranchOneTwo(BaseBranchOperator):
         return ["branch_1", "branch_2"]
 
 
+class ChooseBranchThree(BaseBranchOperator):
+    def choose_branch(self, context):
+        return ["branch_3"]
+
+
 class TestBranchOperator:
     @classmethod
     def setup_class(cls):
@@ -191,3 +197,39 @@ class TestBranchOperator:
         for ti in tis:
             if ti.task_id == "make_choice":
                 assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
+
+    def test_with_dag_run_task_groups(self):
+        self.branch_op = ChooseBranchThree(task_id="make_choice", dag=self.dag)
+        self.branch_3 = TaskGroup("branch_3", dag=self.dag)
+        _ = EmptyOperator(task_id="task_1", dag=self.dag, task_group=self.branch_3)
+        _ = EmptyOperator(task_id="task_2", dag=self.dag, task_group=self.branch_3)
+
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2.set_upstream(self.branch_op)
+        self.branch_3.set_upstream(self.branch_op)
+
+        self.dag.clear()
+
+        dagrun = self.dag.create_dagrun(
+            run_type=DagRunType.MANUAL,
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dagrun.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                assert ti.state == State.SUCCESS
+            elif ti.task_id == "branch_1":
+                assert ti.state == State.SKIPPED
+            elif ti.task_id == "branch_2":
+                assert ti.state == State.SKIPPED
+            elif ti.task_id == "branch_3.task_1":
+                assert ti.state == State.NONE
+            elif ti.task_id == "branch_3.task_2":
+                assert ti.state == State.NONE
+            else:
+                raise Exception
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 3d7b23c415..e419b7de19 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -463,7 +463,10 @@ class TestBranchOperator(BasePythonTest):
             return 5
 
         ti = self.create_ti(f)
-        with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"):
+        with pytest.raises(
+            AirflowException,
+            match="'branch_task_ids' expected all task IDs are strings.",
+        ):
             ti.run()
 
     def test_raise_exception_on_invalid_task_id(self):
@@ -1440,14 +1443,14 @@ class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
             else:
                 raise RuntimeError
 
-        with pytest.raises(AirflowException, match="but got 'bool'"):
+        with pytest.raises(AirflowException, match=r"Invalid tasks found: {\((True|False), 'bool'\)}"):
             self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})
 
     def test_return_false(self):
         def f():
             return False
 
-        with pytest.raises(AirflowException, match="but got 'bool'"):
+        with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."):
             self.run_as_task(f)
 
     def test_context(self):
@@ -1468,7 +1471,7 @@ class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
         def f():
             return False
 
-        with pytest.raises(AirflowException, match="but got 'bool'"):
+        with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."):
             self.run_as_task(f, do_not_use_caching=True)
 
     def test_with_dag_run(self):
@@ -1581,7 +1584,10 @@ class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
             return 5
 
         ti = self.create_ti(f)
-        with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"):
+        with pytest.raises(
+            AirflowException,
+            match="'branch_task_ids' expected all task IDs are strings.",
+        ):
             ti.run()
 
     def test_raise_exception_on_invalid_task_id(self):

Reply via email to