[ 
https://issues.apache.org/jira/browse/AIRFLOW-3375?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16694810#comment-16694810
 ] 

ASF GitHub Bot commented on AIRFLOW-3375:
-----------------------------------------

Fokko closed pull request #4215: [AIRFLOW-3375] Support returning multiple 
tasks with BranchPythonOperator
URL: https://github.com/apache/incubator-airflow/pull/4215
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/operators/python_operator.py 
b/airflow/operators/python_operator.py
index 9b31838b0c..a92cb86642 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -114,14 +114,14 @@ def execute_callable(self):
 
 class BranchPythonOperator(PythonOperator, SkipMixin):
     """
-    Allows a workflow to "branch" or follow a single path following the
-    execution of this task.
+    Allows a workflow to "branch" or follow a path following the execution
+    of this task.
 
     It derives the PythonOperator and expects a Python function that returns
-    the task_id to follow. The task_id returned should point to a task
-    directly downstream from {self}. All other "branches" or
-    directly downstream tasks are marked with a state of ``skipped`` so that
-    these paths can't move forward. The ``skipped`` states are propageted
+    a single task_id or list of task_ids to follow. The task_id(s) returned
+    should point to a task directly downstream from {self}. All other 
"branches"
+    or directly downstream tasks are marked with a state of ``skipped`` so that
+    these paths can't move forward. The ``skipped`` states are propagated
     downstream to allow for the DAG state to fill up and the DAG run's state
     to be inferred.
 
@@ -133,13 +133,15 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
     """
     def execute(self, context):
         branch = super(BranchPythonOperator, self).execute(context)
+        if isinstance(branch, str):
+            branch = [branch]
         self.log.info("Following branch %s", branch)
         self.log.info("Marking other directly downstream tasks as skipped")
 
         downstream_tasks = context['task'].downstream_list
         self.log.debug("Downstream task_ids %s", downstream_tasks)
 
-        skip_tasks = [t for t in downstream_tasks if t.task_id != branch]
+        skip_tasks = [t for t in downstream_tasks if t.task_id not in branch]
         if downstream_tasks:
             self.skip(context['dag_run'], context['ti'].execution_date, 
skip_tasks)
 
diff --git a/docs/concepts.rst b/docs/concepts.rst
index 2896010248..8753958af3 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -500,8 +500,8 @@ that happened in an upstream task. One way to do this is by 
using the
 ``BranchPythonOperator``.
 
 The ``BranchPythonOperator`` is much like the PythonOperator except that it
-expects a python_callable that returns a task_id. The task_id returned
-is followed, and all of the other paths are skipped.
+expects a python_callable that returns a task_id (or list of task_ids). The
+task_id returned is followed, and all of the other paths are skipped.
 The task_id returned by the Python function has to be referencing a task
 directly downstream from the BranchPythonOperator task.
 
diff --git a/tests/operators/test_python_operator.py 
b/tests/operators/test_python_operator.py
index afc2a1383a..dd830b899c 100644
--- a/tests/operators/test_python_operator.py
+++ b/tests/operators/test_python_operator.py
@@ -183,15 +183,9 @@ def setUp(self):
                            'owner': 'airflow',
                            'start_date': DEFAULT_DATE},
                        schedule_interval=INTERVAL)
-        self.branch_op = BranchPythonOperator(task_id='make_choice',
-                                              dag=self.dag,
-                                              python_callable=lambda: 
'branch_1')
 
         self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
-        self.branch_1.set_upstream(self.branch_op)
         self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
-        self.branch_2.set_upstream(self.branch_op)
-        self.dag.clear()
 
     def tearDown(self):
         super(BranchOperatorTest, self).tearDown()
@@ -206,6 +200,13 @@ def tearDown(self):
 
     def test_without_dag_run(self):
         """This checks the defensive against non existent tasks in a dag run"""
+        self.branch_op = BranchPythonOperator(task_id='make_choice',
+                                              dag=self.dag,
+                                              python_callable=lambda: 
'branch_1')
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2.set_upstream(self.branch_op)
+        self.dag.clear()
+
         self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         session = Session()
@@ -226,7 +227,48 @@ def test_without_dag_run(self):
             else:
                 raise
 
+    def test_branch_list_without_dag_run(self):
+        """This checks if the BranchPythonOperator supports branching off to a 
list of tasks."""
+        self.branch_op = BranchPythonOperator(task_id='make_choice',
+                                              dag=self.dag,
+                                              python_callable=lambda: 
['branch_1', 'branch_2'])
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2.set_upstream(self.branch_op)
+        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
+        self.branch_3.set_upstream(self.branch_op)
+        self.dag.clear()
+
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        session = Session()
+        tis = session.query(TI).filter(
+            TI.dag_id == self.dag.dag_id,
+            TI.execution_date == DEFAULT_DATE
+        )
+        session.close()
+
+        expected = {
+            "make_choice": State.SUCCESS,
+            "branch_1": State.NONE,
+            "branch_2": State.NONE,
+            "branch_3": State.SKIPPED,
+        }
+
+        for ti in tis:
+            if ti.task_id in expected:
+                self.assertEquals(ti.state, expected[ti.task_id])
+            else:
+                raise
+
     def test_with_dag_run(self):
+        self.branch_op = BranchPythonOperator(task_id='make_choice',
+                                              dag=self.dag,
+                                              python_callable=lambda: 
'branch_1')
+
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2.set_upstream(self.branch_op)
+        self.dag.clear()
+
         dr = self.dag.create_dagrun(
             run_id="manual__",
             start_date=timezone.utcnow(),


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Support returning multiple tasks with BranchPythonOperator
> ----------------------------------------------------------
>
>                 Key: AIRFLOW-3375
>                 URL: https://issues.apache.org/jira/browse/AIRFLOW-3375
>             Project: Apache Airflow
>          Issue Type: Improvement
>            Reporter: Bas Harenslak
>            Assignee: Bas Harenslak
>            Priority: Major
>             Fix For: 2.0.0
>
>
> I hit a case where I'm using the BranchPythonOperator and want to branch to 
> multiple tasks, so I added support for returning a list of task ids.
> Both a single task id (string type) and list of task ids are supported.
> PR: https://github.com/apache/incubator-airflow/pull/4215



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to