This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f65aee8f3 Add version of `chain` which doesn't require matched lists (#31927)
8f65aee8f3 is described below

commit 8f65aee8f3c4dfdd6c4195d97f57f4267d37c209
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Jun 20 22:26:59 2023 -0700

    Add version of `chain` which doesn't require matched lists (#31927)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/models/baseoperator.py    | 32 ++++++++++++++++++++++++
 tests/models/test_baseoperator.py | 51 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 84aa9a17eb..79d4637387 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -85,6 +85,7 @@ from airflow.triggers.base import BaseTrigger
 from airflow.utils import timezone
 from airflow.utils.context import Context
 from airflow.utils.decorators import fixup_decorator_warning_stack
+from airflow.utils.edgemodifier import EdgeModifier
 from airflow.utils.helpers import validate_key
 from airflow.utils.operator_resources import Resources
 from airflow.utils.session import NEW_SESSION, provide_session
@@ -1838,6 +1839,37 @@ def cross_downstream(
         task.set_downstream(to_tasks)
 
 
+def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
+    """
+    Helper to simplify task dependency definition.
+
+    E.g.: suppose you want precedence like so::
+
+            ╭─op2─╮ ╭─op4─╮
+        op1─┤     ├─├─op5─┤─op7
+            ╰-op3─╯ ╰-op6─╯
+
+    Then you can accomplish like so::
+
+        chain_linear(
+            op1,
+            [op2, op3],
+            [op4, op5, op6],
+            op7
+        )
+
+    :param elements: a list of operators / lists of operators
+    """
+    prev_elem = None
+    for curr_elem in elements:
+        if isinstance(curr_elem, EdgeModifier):
+            raise ValueError("Labels are not supported by chain_linear")
+        if prev_elem is not None:
+            for task in prev_elem:
+                task >> curr_elem
+        prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
+
+
 # pyupgrade assumes all type annotations can be lazily evaluated, but this is
 # not the case for attrs-decorated classes, since cattrs needs to evaluate the
 # annotation expressions at runtime, and Python before 3.9.0 does not lazily
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 418ffdba72..82e9dd71c3 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -31,7 +31,13 @@ from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
 from airflow.lineage.entities import File
 from airflow.models import DAG
-from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
+from airflow.models.baseoperator import (
+    BaseOperator,
+    BaseOperatorMeta,
+    chain,
+    chain_linear,
+    cross_downstream,
+)
 from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
@@ -515,6 +521,49 @@ class TestBaseOperator:
         assert [op2] == tgop3.get_direct_relatives(upstream=False)
         assert [op2] == tgop4.get_direct_relatives(upstream=False)
 
+    def test_chain_linear(self):
+        dag = DAG(dag_id="test_chain_linear", start_date=datetime.now())
+
+        t1, t2, t3, t4, t5, t6, t7 = (BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 8))
+        chain_linear(t1, [t2, t3, t4], [t5, t6], t7)
+
+        assert set(t1.get_direct_relatives(upstream=False)) == {t2, t3, t4}
+        assert set(t2.get_direct_relatives(upstream=False)) == {t5, t6}
+        assert set(t3.get_direct_relatives(upstream=False)) == {t5, t6}
+        assert set(t7.get_direct_relatives(upstream=True)) == {t5, t6}
+
+        t1, t2, t3, t4, t5, t6 = (
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(1, 7)
+        )
+        chain_linear(t1, [t2, t3], [t4, t5], t6)
+
+        assert set(t1.operator.get_direct_relatives(upstream=False)) == {t2.operator, t3.operator}
+        assert set(t2.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
+        assert set(t3.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
+        assert set(t6.operator.get_direct_relatives(upstream=True)) == {t4.operator, t5.operator}
+
+        # Begin test for `TaskGroups`
+        tg1, tg2 = (TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3))
+        op1, op2 = (BaseOperator(task_id=f"task{i}", dag=dag) for i in range(1, 3))
+        tgop1, tgop2 = (
+            BaseOperator(task_id=f"task_group_task{i}", task_group=tg1, dag=dag) for i in range(1, 3)
+        )
+        tgop3, tgop4 = (
+            BaseOperator(task_id=f"task_group_task{i}", task_group=tg2, dag=dag) for i in range(1, 3)
+        )
+        chain_linear(op1, tg1, tg2, op2)
+
+        assert set(op1.get_direct_relatives(upstream=False)) == {tgop1, tgop2}
+        assert set(tgop1.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
+        assert set(tgop2.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
+        assert set(tgop3.get_direct_relatives(upstream=False)) == {op2}
+        assert set(tgop4.get_direct_relatives(upstream=False)) == {op2}
+
+        t1, t2 = (BaseOperator(task_id=f"t-{i}", dag=dag) for i in range(1, 3))
+        with pytest.raises(ValueError, match="Labels are not supported"):
+            chain_linear(t1, Label("hi"), t2)
+
     def test_chain_not_support_type(self):
         dag = DAG(dag_id="test_chain", start_date=datetime.now())
         [op1, op2] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 3)]

Reply via email to