This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8f65aee8f3 Add version of `chain` which doesn't require matched lists (#31927)
8f65aee8f3 is described below
commit 8f65aee8f3c4dfdd6c4195d97f57f4267d37c209
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Jun 20 22:26:59 2023 -0700
Add version of `chain` which doesn't require matched lists (#31927)
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/models/baseoperator.py | 32 ++++++++++++++++++++++++
tests/models/test_baseoperator.py | 51 ++++++++++++++++++++++++++++++++++++++-
2 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 84aa9a17eb..79d4637387 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -85,6 +85,7 @@ from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.decorators import fixup_decorator_warning_stack
+from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
@@ -1838,6 +1839,37 @@ def cross_downstream(
task.set_downstream(to_tasks)
+def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
+    """
+    Simplify task dependency definition by linking each element to the next.
+
+    Every task in one element is set upstream of the whole following element,
+    so adjacent lists do not need to have the same length.
+
+    E.g.: suppose you want precedence like so::
+
+            ╭─op2─╮ ╭─op4─╮
+        op1─┤     ├─├─op5─┤─op7
+            ╰-op3─╯ ╰-op6─╯
+
+    Then you can accomplish like so::
+
+        chain_linear(
+            op1,
+            [op2, op3],
+            [op4, op5, op6],
+            op7
+        )
+
+    :param elements: a list of operators / lists of operators
+    :raises ValueError: if any element is an ``EdgeModifier`` (labels are
+        not supported)
+    """
+    prev_elem = None
+    for curr_elem in elements:
+        if isinstance(curr_elem, EdgeModifier):
+            raise ValueError("Labels are not supported by chain_linear")
+        if prev_elem is not None:
+            for task in prev_elem:
+                task >> curr_elem
+        # Wrap a single dependency in a list so the next pass can iterate over it.
+        prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
+
+
# pyupgrade assumes all type annotations can be lazily evaluated, but this is
# not the case for attrs-decorated classes, since cattrs needs to evaluate the
# annotation expressions at runtime, and Python before 3.9.0 does not lazily
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 418ffdba72..82e9dd71c3 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -31,7 +31,13 @@ from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
from airflow.lineage.entities import File
from airflow.models import DAG
-from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
+from airflow.models.baseoperator import (
+ BaseOperator,
+ BaseOperatorMeta,
+ chain,
+ chain_linear,
+ cross_downstream,
+)
from airflow.utils.context import Context
from airflow.utils.edgemodifier import Label
from airflow.utils.task_group import TaskGroup
@@ -515,6 +521,49 @@ class TestBaseOperator:
assert [op2] == tgop3.get_direct_relatives(upstream=False)
assert [op2] == tgop4.get_direct_relatives(upstream=False)
+    def test_chain_linear(self):
+        dag = DAG(dag_id="test_chain_linear", start_date=datetime.now())
+
+        # Plain operators: adjacent lists of different lengths are fully linked.
+        t1, t2, t3, t4, t5, t6, t7 = (BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 8))
+        chain_linear(t1, [t2, t3, t4], [t5, t6], t7)
+
+        assert set(t1.get_direct_relatives(upstream=False)) == {t2, t3, t4}
+        assert set(t2.get_direct_relatives(upstream=False)) == {t5, t6}
+        assert set(t3.get_direct_relatives(upstream=False)) == {t5, t6}
+        assert set(t7.get_direct_relatives(upstream=True)) == {t5, t6}
+
+        # Results of calling decorated tasks: relations are checked on `.operator`.
+        t1, t2, t3, t4, t5, t6 = (
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(1, 7)
+        )
+        chain_linear(t1, [t2, t3], [t4, t5], t6)
+
+        assert set(t1.operator.get_direct_relatives(upstream=False)) == {t2.operator, t3.operator}
+        assert set(t2.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
+        assert set(t3.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
+        assert set(t6.operator.get_direct_relatives(upstream=True)) == {t4.operator, t5.operator}
+
+        # Begin test for `TaskGroups`
+        tg1, tg2 = (TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3))
+        op1, op2 = (BaseOperator(task_id=f"task{i}", dag=dag) for i in range(1, 3))
+        tgop1, tgop2 = (
+            BaseOperator(task_id=f"task_group_task{i}", task_group=tg1, dag=dag) for i in range(1, 3)
+        )
+        tgop3, tgop4 = (
+            BaseOperator(task_id=f"task_group_task{i}", task_group=tg2, dag=dag) for i in range(1, 3)
+        )
+        chain_linear(op1, tg1, tg2, op2)
+
+        assert set(op1.get_direct_relatives(upstream=False)) == {tgop1, tgop2}
+        assert set(tgop1.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
+        assert set(tgop2.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
+        assert set(tgop3.get_direct_relatives(upstream=False)) == {op2}
+        assert set(tgop4.get_direct_relatives(upstream=False)) == {op2}
+
+        # Labels are rejected outright instead of being attached to an edge.
+        t1, t2 = (BaseOperator(task_id=f"t-{i}", dag=dag) for i in range(1, 3))
+        with pytest.raises(ValueError, match="Labels are not supported"):
+            chain_linear(t1, Label("hi"), t2)
+
def test_chain_not_support_type(self):
dag = DAG(dag_id="test_chain", start_date=datetime.now())
[op1, op2] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1,
3)]