BasPH closed pull request #4356: [AIRFLOW-3556] Add cross join set downstream function URL: https://github.com/apache/incubator-airflow/pull/4356
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 328147c1cf..5f8c88879c 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -169,6 +169,37 @@ def chain(*tasks): up_task.set_downstream(down_task) +def cross_downstream(from_tasks, to_tasks): + """ + Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks. + E.g.: cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6]) + Is equivalent to: + + t1 --> t4 + \ / + t2 -X> t5 + / \ + t3 --> t6 + + t1.set_downstream(t4) + t1.set_downstream(t5) + t1.set_downstream(t6) + t2.set_downstream(t4) + t2.set_downstream(t5) + t2.set_downstream(t6) + t3.set_downstream(t4) + t3.set_downstream(t5) + t3.set_downstream(t6) + + :param from_tasks: List of tasks to start from. + :type from_tasks: List[airflow.models.BaseOperator] + :param to_tasks: List of tasks to set as downstream dependencies. + :type to_tasks: List[airflow.models.BaseOperator] + """ + for task in from_tasks: + task.set_downstream(to_tasks) + + def pprinttable(rows): """Returns a pretty ascii table from tuples diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 4cb3e1a1fc..837a79acba 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -20,11 +20,16 @@ import logging import multiprocessing import os -import psutil import signal import time import unittest +from datetime import datetime + +import psutil +import six +from airflow import DAG +from airflow.operators.dummy_operator import DummyOperator from airflow.utils import helpers @@ -210,6 +215,16 @@ def test_is_container(self): # Pass an object that is not iter nor a string. 
self.assertFalse(helpers.is_container(10)) + def test_cross_downstream(self): + """Test if all dependencies between tasks are all set correctly.""" + dag = DAG(dag_id="test_dag", start_date=datetime.now()) + start_tasks = [DummyOperator(task_id="t{i}".format(i=i), dag=dag) for i in range(1, 4)] + end_tasks = [DummyOperator(task_id="t{i}".format(i=i), dag=dag) for i in range(4, 7)] + helpers.cross_downstream(from_tasks=start_tasks, to_tasks=end_tasks) + + for start_task in start_tasks: + six.assertCountEqual(self, start_task.get_direct_relatives(upstream=False), end_tasks) + if __name__ == '__main__': unittest.main() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services