This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ef0a6beb4b Fix ``BranchPythonOperator`` failure when callable returns None (#54991)
0ef0a6beb4b is described below

commit 0ef0a6beb4bbd635c9094723dc93194efd0239b7
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Aug 27 19:56:02 2025 +0100

    Fix ``BranchPythonOperator`` failure when callable returns None (#54991)
    
    `BranchPythonOperator` now properly handles callables that return None
    by skipping all downstream tasks, instead of throwing an execution error.
    This restores the expected behavior for users who rely on None returns
    to skip branches conditionally.
    
    Fixes #54340
---
 .../airflow/providers/standard/operators/branch.py | 16 +++++++++----
 .../tests/unit/standard/operators/test_python.py   | 27 ++++++++++++++++++++++
 2 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/providers/standard/src/airflow/providers/standard/operators/branch.py b/providers/standard/src/airflow/providers/standard/operators/branch.py
index 6d670c179cd..cc405b365c9 100644
--- a/providers/standard/src/airflow/providers/standard/operators/branch.py
+++ b/providers/standard/src/airflow/providers/standard/operators/branch.py
@@ -37,11 +37,17 @@ if TYPE_CHECKING:
 class BranchMixIn(SkipMixin):
     """Utility helper which handles the branching as one-liner."""
 
-    def do_branch(self, context: Context, branches_to_execute: str | 
Iterable[str]) -> str | Iterable[str]:
+    def do_branch(
+        self, context: Context, branches_to_execute: str | Iterable[str] | None
+    ) -> str | Iterable[str] | None:
         """Implement the handling of branching including logging."""
         self.log.info("Branch into %s", branches_to_execute)
-        branch_task_ids = self._expand_task_group_roots(context["ti"], 
branches_to_execute)
-        self.skip_all_except(context["ti"], branch_task_ids)
+        if branches_to_execute is None:
+            # When None is returned, skip all downstream tasks
+            self.skip_all_except(context["ti"], None)
+        else:
+            branch_task_ids = self._expand_task_group_roots(context["ti"], 
branches_to_execute)
+            self.skip_all_except(context["ti"], branch_task_ids)
         return branches_to_execute
 
     def _expand_task_group_roots(
@@ -86,13 +92,13 @@ class BaseBranchOperator(BaseOperator, BranchMixIn):
 
     inherits_from_skipmixin = True
 
-    def choose_branch(self, context: Context) -> str | Iterable[str]:
+    def choose_branch(self, context: Context) -> str | Iterable[str] | None:
         """
         Abstract method to choose which branch to run.
 
         Subclasses should implement this, running whatever logic is
         necessary to choose a branch and returning a task_id or list of
-        task_ids.
+        task_ids. If None is returned, all downstream tasks will be skipped.
 
         :param context: Context dictionary as passed to execute()
         """
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index 63bcd03aa31..e0c9d70245a 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -552,6 +552,33 @@ class TestBranchOperator(BasePythonTest):
         ):
             ti.run()
 
+    def test_none_return_value_should_skip_all_downstream(self):
+        """Test that returning None from callable should skip all downstream 
tasks."""
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, serialized=True):
+
+            def return_none():
+                return None
+
+            branch_op = self.opcls(task_id=self.task_id, 
python_callable=return_none, **self.default_kwargs())
+            branch_op >> [self.branch_1, self.branch_2]
+
+        dr = self.dag_maker.create_dagrun()
+        if AIRFLOW_V_3_0_1:
+            from airflow.exceptions import DownstreamTasksSkipped
+
+            with pytest.raises(DownstreamTasksSkipped) as dts:
+                self.dag_maker.run_ti(self.task_id, dr)
+
+            # When None is returned, all downstream tasks should be skipped
+            expected_skipped = {("branch_1", -1), ("branch_2", -1)}
+            assert set(dts.value.tasks) == expected_skipped
+        else:
+            self.dag_maker.run_ti(self.task_id, dr)
+            self.assert_expected_task_states(
+                dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, 
"branch_2": State.SKIPPED}
+            )
+
     @pytest.mark.parametrize(
         "choice,expected_states",
         [

Reply via email to