BasPH closed pull request #4356: [AIRFLOW-3556] Add cross join set downstream function
URL: https://github.com/apache/incubator-airflow/pull/4356
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 328147c1cf..5f8c88879c 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -169,6 +169,37 @@ def chain(*tasks):
         up_task.set_downstream(down_task)
 
 
+def cross_downstream(from_tasks, to_tasks):
+    r"""
+    Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
+    E.g.: cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
+    Is equivalent to:
+
+    t1 --> t4
+       \ /
+    t2 -X> t5
+       / \
+    t3 --> t6
+
+    t1.set_downstream(t4)
+    t1.set_downstream(t5)
+    t1.set_downstream(t6)
+    t2.set_downstream(t4)
+    t2.set_downstream(t5)
+    t2.set_downstream(t6)
+    t3.set_downstream(t4)
+    t3.set_downstream(t5)
+    t3.set_downstream(t6)
+
+    :param from_tasks: List of tasks to start from.
+    :type from_tasks: List[airflow.models.BaseOperator]
+    :param to_tasks: List of tasks to set as downstream dependencies.
+    :type to_tasks: List[airflow.models.BaseOperator]
+    """
+    for task in from_tasks:
+        task.set_downstream(to_tasks)
+
+
 def pprinttable(rows):
     """Returns a pretty ascii table from tuples
 
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 4cb3e1a1fc..837a79acba 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -20,11 +20,16 @@
 import logging
 import multiprocessing
 import os
-import psutil
 import signal
 import time
 import unittest
+from datetime import datetime
+
+import psutil
+import six
 
+from airflow import DAG
+from airflow.operators.dummy_operator import DummyOperator
 from airflow.utils import helpers
 
 
@@ -210,6 +215,16 @@ def test_is_container(self):
         # Pass an object that is not iter nor a string.
         self.assertFalse(helpers.is_container(10))
 
+    def test_cross_downstream(self):
+        """Test that every start task gets every end task as a direct downstream."""
+        dag = DAG(dag_id="test_dag", start_date=datetime.now())
+        start_tasks = [DummyOperator(task_id="t{i}".format(i=i), dag=dag) for i in range(1, 4)]
+        end_tasks = [DummyOperator(task_id="t{i}".format(i=i), dag=dag) for i in range(4, 7)]
+        helpers.cross_downstream(from_tasks=start_tasks, to_tasks=end_tasks)
+
+        for start_task in start_tasks:
+            six.assertCountEqual(self, start_task.get_direct_relatives(upstream=False), end_tasks)
+
 
 if __name__ == '__main__':
     unittest.main()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to