[ https://issues.apache.org/jira/browse/AIRFLOW-3375?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16694810#comment-16694810 ]
ASF GitHub Bot commented on AIRFLOW-3375: ----------------------------------------- Fokko closed pull request #4215: [AIRFLOW-3375] Support returning multiple tasks with BranchPythonOperator URL: https://github.com/apache/incubator-airflow/pull/4215 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 9b31838b0c..a92cb86642 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -114,14 +114,14 @@ def execute_callable(self): class BranchPythonOperator(PythonOperator, SkipMixin): """ - Allows a workflow to "branch" or follow a single path following the - execution of this task. + Allows a workflow to "branch" or follow a path following the execution + of this task. It derives the PythonOperator and expects a Python function that returns - the task_id to follow. The task_id returned should point to a task - directly downstream from {self}. All other "branches" or - directly downstream tasks are marked with a state of ``skipped`` so that - these paths can't move forward. The ``skipped`` states are propageted + a single task_id or list of task_ids to follow. The task_id(s) returned + should point to a task directly downstream from {self}. All other "branches" + or directly downstream tasks are marked with a state of ``skipped`` so that + these paths can't move forward. The ``skipped`` states are propagated downstream to allow for the DAG state to fill up and the DAG run's state to be inferred. 
@@ -133,13 +133,15 @@ class BranchPythonOperator(PythonOperator, SkipMixin): """ def execute(self, context): branch = super(BranchPythonOperator, self).execute(context) + if isinstance(branch, str): + branch = [branch] self.log.info("Following branch %s", branch) self.log.info("Marking other directly downstream tasks as skipped") downstream_tasks = context['task'].downstream_list self.log.debug("Downstream task_ids %s", downstream_tasks) - skip_tasks = [t for t in downstream_tasks if t.task_id != branch] + skip_tasks = [t for t in downstream_tasks if t.task_id not in branch] if downstream_tasks: self.skip(context['dag_run'], context['ti'].execution_date, skip_tasks) diff --git a/docs/concepts.rst b/docs/concepts.rst index 2896010248..8753958af3 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -500,8 +500,8 @@ that happened in an upstream task. One way to do this is by using the ``BranchPythonOperator``. The ``BranchPythonOperator`` is much like the PythonOperator except that it -expects a python_callable that returns a task_id. The task_id returned -is followed, and all of the other paths are skipped. +expects a python_callable that returns a task_id (or list of task_ids). The +task_id returned is followed, and all of the other paths are skipped. The task_id returned by the Python function has to be referencing a task directly downstream from the BranchPythonOperator task. 
diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index afc2a1383a..dd830b899c 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -183,15 +183,9 @@ def setUp(self): 'owner': 'airflow', 'start_date': DEFAULT_DATE}, schedule_interval=INTERVAL) - self.branch_op = BranchPythonOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: 'branch_1') self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) - self.branch_1.set_upstream(self.branch_op) self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) - self.branch_2.set_upstream(self.branch_op) - self.dag.clear() def tearDown(self): super(BranchOperatorTest, self).tearDown() @@ -206,6 +200,13 @@ def tearDown(self): def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" + self.branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: 'branch_1') + self.branch_1.set_upstream(self.branch_op) + self.branch_2.set_upstream(self.branch_op) + self.dag.clear() + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) session = Session() @@ -226,7 +227,48 @@ def test_without_dag_run(self): else: raise + def test_branch_list_without_dag_run(self): + """This checks if the BranchPythonOperator supports branching off to a list of tasks.""" + self.branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: ['branch_1', 'branch_2']) + self.branch_1.set_upstream(self.branch_op) + self.branch_2.set_upstream(self.branch_op) + self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag) + self.branch_3.set_upstream(self.branch_op) + self.dag.clear() + + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + session = Session() + tis = session.query(TI).filter( + TI.dag_id == self.dag.dag_id, + TI.execution_date == DEFAULT_DATE + ) + session.close() + + expected 
= { + "make_choice": State.SUCCESS, + "branch_1": State.NONE, + "branch_2": State.NONE, + "branch_3": State.SKIPPED, + } + + for ti in tis: + if ti.task_id in expected: + self.assertEquals(ti.state, expected[ti.task_id]) + else: + raise + def test_with_dag_run(self): + self.branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: 'branch_1') + + self.branch_1.set_upstream(self.branch_op) + self.branch_2.set_upstream(self.branch_op) + self.dag.clear() + dr = self.dag.create_dagrun( run_id="manual__", start_date=timezone.utcnow(), ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Support returning multiple tasks with BranchPythonOperator > ---------------------------------------------------------- > > Key: AIRFLOW-3375 > URL: https://issues.apache.org/jira/browse/AIRFLOW-3375 > Project: Apache Airflow > Issue Type: Improvement > Reporter: Bas Harenslak > Assignee: Bas Harenslak > Priority: Major > Fix For: 2.0.0 > > > I hit a case where I'm using the BranchPythonOperator and want to branch to > multiple tasks, so I added support for returning a list of task ids. > Both a single task id (string type) and list of task ids are supported. > PR: https://github.com/apache/incubator-airflow/pull/4215 -- This message was sent by Atlassian JIRA (v7.6.3#76005)