This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d4e69b38c Add BranchPythonVirtualenvOperator (#33356)
9d4e69b38c is described below

commit 9d4e69b38c01219cd0b9bcc3c60b9c77f484d141
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Sep 11 11:41:32 2023 +0200

    Add BranchPythonVirtualenvOperator (#33356)
---
 airflow/operators/branch.py    | 16 ++++++++++++----
 airflow/operators/python.py    | 32 ++++++++++++++++++++++----------
 tests/operators/test_python.py | 27 +++++++++++++++++++--------
 3 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index 066ee52187..4288775f6b 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -27,7 +27,17 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class BaseBranchOperator(BaseOperator, SkipMixin):
+class BranchMixIn(SkipMixin):
+    """Utility helper which handles the branching as one-liner."""
+
+    def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
+        """Implement the handling of branching including logging."""
+        self.log.info("Branch into %s", branches_to_execute)
+        self.skip_all_except(context["ti"], branches_to_execute)
+        return branches_to_execute
+
+
+class BaseBranchOperator(BaseOperator, BranchMixIn):
     """
     A base class for creating operators with branching functionality, like to BranchPythonOperator.
 
@@ -53,6 +63,4 @@ class BaseBranchOperator(BaseOperator, SkipMixin):
         raise NotImplementedError
 
     def execute(self, context: Context):
-        branches_to_execute = self.choose_branch(context)
-        self.skip_all_except(context["ti"], branches_to_execute)
-        return branches_to_execute
+        return self.do_branch(context, self.choose_branch(context))
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 4382884178..4dd8727519 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -46,6 +46,7 @@ from airflow.exceptions import (
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.skipmixin import SkipMixin
 from airflow.models.taskinstance import _CURRENT_CONTEXT
+from airflow.operators.branch import BranchMixIn
 from airflow.utils.context import context_copy_partial, context_merge
 from airflow.utils.operator_helpers import KeywordParameters
 from airflow.utils.process_utils import execute_in_subprocess
@@ -211,7 +212,7 @@ class PythonOperator(BaseOperator):
         return self.python_callable(*self.op_args, **self.op_kwargs)
 
 
-class BranchPythonOperator(PythonOperator, SkipMixin):
+class BranchPythonOperator(PythonOperator, BranchMixIn):
     """
     A workflow can "branch" or follow a path after the execution of this task.
 
@@ -225,10 +226,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
     """
 
     def execute(self, context: Context) -> Any:
-        branch = super().execute(context)
-        self.log.info("Branch callable return %s", branch)
-        self.skip_all_except(context["ti"], branch)
-        return branch
+        return self.do_branch(context, super().execute(context))
 
 
 class ShortCircuitOperator(PythonOperator, SkipMixin):
@@ -625,6 +623,23 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
             yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
 
 
+class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
+    """
+    A workflow can "branch" or follow a path after the execution of this task in a virtualenv.
+
+    It derives the PythonVirtualenvOperator and expects a Python function that returns
+    a single task_id or list of task_ids to follow. The task_id(s) returned
+    should point to a task directly downstream from {self}. All other "branches"
+    or directly downstream tasks are marked with a state of ``skipped`` so that
+    these paths can't move forward. The ``skipped`` states are propagated
+    downstream to allow for the DAG state to fill up and the DAG run's state
+    to be inferred.
+    """
+
+    def execute(self, context: Context) -> Any:
+        return self.do_branch(context, super().execute(context))
+
+
 class ExternalPythonOperator(_BasePythonVirtualenvOperator):
     """
     Run a function in a virtualenv that is not re-created.
@@ -792,7 +807,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
             return None
 
 
-class BranchExternalPythonOperator(ExternalPythonOperator, SkipMixin):
+class BranchExternalPythonOperator(ExternalPythonOperator, BranchMixIn):
     """
     A workflow can "branch" or follow a path after the execution of this task.
 
@@ -802,10 +817,7 @@ class BranchExternalPythonOperator(ExternalPythonOperator, SkipMixin):
     """
 
     def execute(self, context: Context) -> Any:
-        branch = super().execute(context)
-        self.log.info("Branch callable return %s", branch)
-        self.skip_all_except(context["ti"], branch)
-        return branch
+        return self.do_branch(context, super().execute(context))
 
 
 def get_current_context() -> Context:
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 85c2dc976f..d20949c618 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -42,6 +42,7 @@ from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import (
     BranchExternalPythonOperator,
     BranchPythonOperator,
+    BranchPythonVirtualenvOperator,
     ExternalPythonOperator,
     PythonOperator,
     PythonVirtualenvOperator,
@@ -1121,19 +1122,12 @@ class TestExternalPythonOperator(BaseTestPythonVirtualenvOperator):
             task._read_result(path=mock.Mock())
 
 
-class TestBranchExternalPythonOperator(BaseTestPythonVirtualenvOperator):
-    opcls = BranchExternalPythonOperator
-
+class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
     @pytest.fixture(autouse=True)
     def setup_tests(self):
         self.branch_1 = EmptyOperator(task_id="branch_1")
         self.branch_2 = EmptyOperator(task_id="branch_2")
 
-    @staticmethod
-    def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
-        kwargs["python"] = sys.executable
-        return kwargs
-
     def test_with_args(self):
         def f(a, b, c=False, d=False):
             if a == 0 and b == 1 and c and not d:
@@ -1280,6 +1274,23 @@ class TestBranchExternalPythonOperator(BaseTestPythonVirtualenvOperator):
             ti.run()
 
 
+class TestBranchPythonVirtualenvOperator(BaseTestBranchPythonVirtualenvOperator):
+    opcls = BranchPythonVirtualenvOperator
+
+    @staticmethod
+    def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+        return kwargs
+
+
+class TestBranchExternalPythonOperator(BaseTestBranchPythonVirtualenvOperator):
+    opcls = BranchExternalPythonOperator
+
+    @staticmethod
+    def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+        kwargs["python"] = sys.executable
+        return kwargs
+
+
 class TestCurrentContext:
     def test_current_context_no_context_raise(self):
         with pytest.raises(AirflowException):

Reply via email to