This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bb06c56403 Allow Task Group Ids to be passed as branches in BranchMixIn (#38883)
bb06c56403 is described below
commit bb06c56403896969ecf25b3eecdb2a00d5437dce
Author: Bora Berke Sahin <[email protected]>
AuthorDate: Mon Jun 3 10:32:27 2024 +0300
Allow Task Group Ids to be passed as branches in BranchMixIn (#38883)
* Allow `Task Group Id`s to be passed as branches in BranchMixIn
---
airflow/models/skipmixin.py | 5 ++--
airflow/operators/branch.py | 39 ++++++++++++++++++++++++++----
airflow/operators/datetime.py | 8 +++----
airflow/operators/python.py | 26 ++++++++++----------
airflow/operators/weekday.py | 6 +++--
tests/operators/test_branch_operator.py | 42 +++++++++++++++++++++++++++++++++
tests/operators/test_python.py | 16 +++++++++----
7 files changed, 113 insertions(+), 29 deletions(-)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index fc3097ce43..3c89deda12 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -175,13 +175,12 @@ class SkipMixin(LoggingMixin):
branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
newly added tasks should be skipped when they are cleared.
"""
- self.log.info("Following branch %s", branch_task_ids)
if isinstance(branch_task_ids, str):
branch_task_id_set = {branch_task_ids}
elif isinstance(branch_task_ids, Iterable):
branch_task_id_set = set(branch_task_ids)
invalid_task_ids_type = {
- (bti, type(bti).__name__) for bti in branch_task_ids if not
isinstance(bti, str)
+ (bti, type(bti).__name__) for bti in branch_task_id_set if not
isinstance(bti, str)
}
if invalid_task_ids_type:
raise AirflowException(
@@ -196,6 +195,8 @@ class SkipMixin(LoggingMixin):
f"but got {type(branch_task_ids).__name__!r}."
)
+ self.log.info("Following branch %s", branch_task_id_set)
+
dag_run = ti.get_dagrun()
if TYPE_CHECKING:
assert isinstance(dag_run, DagRun)
diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index 74586590db..0085bfa5af 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -25,6 +25,8 @@ from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
if TYPE_CHECKING:
+ from airflow.models import TaskInstance
+ from airflow.serialization.pydantic.taskinstance import
TaskInstancePydantic
from airflow.utils.context import Context
@@ -34,9 +36,36 @@ class BranchMixIn(SkipMixin):
def do_branch(self, context: Context, branches_to_execute: str |
Iterable[str]) -> str | Iterable[str]:
"""Implement the handling of branching including logging."""
self.log.info("Branch into %s", branches_to_execute)
- self.skip_all_except(context["ti"], branches_to_execute)
+ branch_task_ids = self._expand_task_group_roots(context["ti"],
branches_to_execute)
+ self.skip_all_except(context["ti"], branch_task_ids)
return branches_to_execute
+ def _expand_task_group_roots(
+ self, ti: TaskInstance | TaskInstancePydantic, branches_to_execute:
str | Iterable[str]
+ ) -> Iterable[str]:
+ """Expand any task group into its root task ids."""
+ if TYPE_CHECKING:
+ assert ti.task
+
+ task = ti.task
+ dag = task.dag
+ if TYPE_CHECKING:
+ assert dag
+
+ if branches_to_execute is None:
+ return
+ elif isinstance(branches_to_execute, str) or not
isinstance(branches_to_execute, Iterable):
+ branches_to_execute = [branches_to_execute]
+
+ for branch in branches_to_execute:
+ if branch in dag.task_group_dict:
+ tg = dag.task_group_dict[branch]
+ root_ids = [root.task_id for root in tg.roots]
+ self.log.info("Expanding task group %s into %s", tg.group_id,
root_ids)
+ yield from root_ids
+ else:
+ yield branch
+
class BaseBranchOperator(BaseOperator, BranchMixIn):
"""
@@ -44,10 +73,12 @@ class BaseBranchOperator(BaseOperator, BranchMixIn):
Users should create a subclass from this operator and implement the function
`choose_branch(self, context)`. This should run whatever business logic
- is needed to determine the branch, and return either the task_id for
- a single task (as a str) or a list of task_ids.
+ is needed to determine the branch, and return one of the following:
+ - A single task_id (as a str)
+ - A single task_group_id (as a str)
+ - A list containing a combination of task_ids and task_group_ids
- The operator will continue with the returned task_id(s), and all other
+ The operator will continue with the returned task_id(s) and/or task_group_id(s), and all other
tasks directly downstream of this operator will be skipped.
"""
diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py
index 4e2638341a..b4bf061c37 100644
--- a/airflow/operators/datetime.py
+++ b/airflow/operators/datetime.py
@@ -37,10 +37,10 @@ class BranchDateTimeOperator(BaseBranchOperator):
True branch will be returned when ``datetime.datetime.now()`` falls below
``target_upper`` and above ``target_lower``.
- :param follow_task_ids_if_true: task id or task ids to follow if
- ``datetime.datetime.now()`` falls above target_lower and below ``target_upper``.
- :param follow_task_ids_if_false: task id or task ids to follow if
- ``datetime.datetime.now()`` falls below target_lower or above ``target_upper``.
+ :param follow_task_ids_if_true: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+ to follow if ``datetime.datetime.now()`` falls above ``target_lower`` and below ``target_upper``.
+ :param follow_task_ids_if_false: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+ to follow if ``datetime.datetime.now()`` falls below ``target_lower`` or above ``target_upper``.
:param target_lower: target lower bound.
:param target_upper: target upper bound.
:param use_task_logical_date: If ``True``, uses task's logical date to
compare with targets.
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 554a27a444..fe9923529a 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -261,12 +261,13 @@ class BranchPythonOperator(PythonOperator, BranchMixIn):
A workflow can "branch" or follow a path after the execution of this task.
It derives the PythonOperator and expects a Python function that returns
- a single task_id or list of task_ids to follow. The task_id(s) returned
- should point to a task directly downstream from {self}. All other "branches"
- or directly downstream tasks are marked with a state of ``skipped`` so that
- these paths can't move forward. The ``skipped`` states are propagated
- downstream to allow for the DAG state to fill up and the DAG run's state
- to be inferred.
+ a single task_id, a single task_group_id, or a list of task_ids and/or
+ task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
+ should point to a task or task group directly downstream from this operator.
+ All other "branches" or directly downstream tasks are marked with a state of
+ ``skipped`` so that these paths can't move forward. The ``skipped`` states
+ are propagated downstream to allow for the DAG state to fill up and
+ the DAG run's state to be inferred.
"""
def execute(self, context: Context) -> Any:
@@ -861,12 +862,13 @@ class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
A workflow can "branch" or follow a path after the execution of this task
in a virtual environment.
It derives the PythonVirtualenvOperator and expects a Python function that returns
- a single task_id or list of task_ids to follow. The task_id(s) returned
- should point to a task directly downstream from {self}. All other "branches"
- or directly downstream tasks are marked with a state of ``skipped`` so that
- these paths can't move forward. The ``skipped`` states are propagated
- downstream to allow for the DAG state to fill up and the DAG run's state
- to be inferred.
+ a single task_id, a single task_group_id, or a list of task_ids and/or
+ task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
+ should point to a task or task group directly downstream from this operator.
+ All other "branches" or directly downstream tasks are marked with a state of
+ ``skipped`` so that these paths can't move forward. The ``skipped`` states
+ are propagated downstream to allow for the DAG state to fill up and
+ the DAG run's state to be inferred.
.. seealso::
For more information on how to use this operator, take a look at the guide:
diff --git a/airflow/operators/weekday.py b/airflow/operators/weekday.py
index 1a4a5f1f35..cd941d93af 100644
--- a/airflow/operators/weekday.py
+++ b/airflow/operators/weekday.py
@@ -73,8 +73,10 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
# add downstream dependencies as you would do with any branch operator
weekend_check >> [workday, weekend]
- :param follow_task_ids_if_true: task id or task ids to follow if criteria met
- :param follow_task_ids_if_false: task id or task ids to follow if criteria does not met
+ :param follow_task_ids_if_true: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+ to follow if the criteria are met.
+ :param follow_task_ids_if_false: task_id, task_group_id, or a list of task_ids and/or task_group_ids
+ to follow if the criteria are not met.
:param week_day: Day of the week to check (full name). Optionally, a set
of days can also be provided using a set. Example values:
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index 0bb5e318a7..4c127db8dd 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -29,6 +29,7 @@ from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
+from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
pytestmark = pytest.mark.db_test
@@ -47,6 +48,11 @@ class ChooseBranchOneTwo(BaseBranchOperator):
return ["branch_1", "branch_2"]
+class ChooseBranchThree(BaseBranchOperator):
+ def choose_branch(self, context):
+ return ["branch_3"]
+
+
class TestBranchOperator:
@classmethod
def setup_class(cls):
@@ -191,3 +197,39 @@ class TestBranchOperator:
for ti in tis:
if ti.task_id == "make_choice":
assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
+
+ def test_with_dag_run_task_groups(self):
+ self.branch_op = ChooseBranchThree(task_id="make_choice", dag=self.dag)
+ self.branch_3 = TaskGroup("branch_3", dag=self.dag)
+ _ = EmptyOperator(task_id="task_1", dag=self.dag,
task_group=self.branch_3)
+ _ = EmptyOperator(task_id="task_2", dag=self.dag,
task_group=self.branch_3)
+
+ self.branch_1.set_upstream(self.branch_op)
+ self.branch_2.set_upstream(self.branch_op)
+ self.branch_3.set_upstream(self.branch_op)
+
+ self.dag.clear()
+
+ dagrun = self.dag.create_dagrun(
+ run_type=DagRunType.MANUAL,
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dagrun.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.SKIPPED
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.SKIPPED
+ elif ti.task_id == "branch_3.task_1":
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_3.task_2":
+ assert ti.state == State.NONE
+ else:
+ raise Exception
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 3d7b23c415..e419b7de19 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -463,7 +463,10 @@ class TestBranchOperator(BasePythonTest):
return 5
ti = self.create_ti(f)
- with pytest.raises(AirflowException, match="must be either None, a
task ID, or an Iterable of IDs"):
+ with pytest.raises(
+ AirflowException,
+ match="'branch_task_ids' expected all task IDs are strings.",
+ ):
ti.run()
def test_raise_exception_on_invalid_task_id(self):
@@ -1440,14 +1443,14 @@ class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
else:
raise RuntimeError
- with pytest.raises(AirflowException, match="but got 'bool'"):
+ with pytest.raises(AirflowException, match=r"Invalid tasks found:
{\((True|False), 'bool'\)}"):
self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})
def test_return_false(self):
def f():
return False
- with pytest.raises(AirflowException, match="but got 'bool'"):
+ with pytest.raises(AirflowException, match=r"Invalid tasks found:
{\(False, 'bool'\)}."):
self.run_as_task(f)
def test_context(self):
@@ -1468,7 +1471,7 @@ class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
def f():
return False
- with pytest.raises(AirflowException, match="but got 'bool'"):
+ with pytest.raises(AirflowException, match=r"Invalid tasks found:
{\(False, 'bool'\)}."):
self.run_as_task(f, do_not_use_caching=True)
def test_with_dag_run(self):
@@ -1581,7 +1584,10 @@ class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
return 5
ti = self.create_ti(f)
- with pytest.raises(AirflowException, match="must be either None, a
task ID, or an Iterable of IDs"):
+ with pytest.raises(
+ AirflowException,
+ match="'branch_task_ids' expected all task IDs are strings.",
+ ):
ti.run()
def test_raise_exception_on_invalid_task_id(self):