This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9d4e69b38c Add BranchPythonVirtualenvOperator (#33356)
9d4e69b38c is described below
commit 9d4e69b38c01219cd0b9bcc3c60b9c77f484d141
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Sep 11 11:41:32 2023 +0200
Add BranchPythonVirtualenvOperator (#33356)
---
airflow/operators/branch.py | 16 ++++++++++++----
airflow/operators/python.py | 32 ++++++++++++++++++++++----------
tests/operators/test_python.py | 27 +++++++++++++++++++--------
3 files changed, 53 insertions(+), 22 deletions(-)
diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index 066ee52187..4288775f6b 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -27,7 +27,17 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
-class BaseBranchOperator(BaseOperator, SkipMixin):
+class BranchMixIn(SkipMixin):
+ """Utility helper which handles the branching as one-liner."""
+
+ def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
+ """Implement the handling of branching including logging."""
+ self.log.info("Branch into %s", branches_to_execute)
+ self.skip_all_except(context["ti"], branches_to_execute)
+ return branches_to_execute
+
+
+class BaseBranchOperator(BaseOperator, BranchMixIn):
"""
A base class for creating operators with branching functionality, like to BranchPythonOperator.
@@ -53,6 +63,4 @@ class BaseBranchOperator(BaseOperator, SkipMixin):
raise NotImplementedError
def execute(self, context: Context):
- branches_to_execute = self.choose_branch(context)
- self.skip_all_except(context["ti"], branches_to_execute)
- return branches_to_execute
+ return self.do_branch(context, self.choose_branch(context))
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 4382884178..4dd8727519 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -46,6 +46,7 @@ from airflow.exceptions import (
from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
+from airflow.operators.branch import BranchMixIn
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
@@ -211,7 +212,7 @@ class PythonOperator(BaseOperator):
return self.python_callable(*self.op_args, **self.op_kwargs)
-class BranchPythonOperator(PythonOperator, SkipMixin):
+class BranchPythonOperator(PythonOperator, BranchMixIn):
"""
A workflow can "branch" or follow a path after the execution of this task.
@@ -225,10 +226,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
"""
def execute(self, context: Context) -> Any:
- branch = super().execute(context)
- self.log.info("Branch callable return %s", branch)
- self.skip_all_except(context["ti"], branch)
- return branch
+ return self.do_branch(context, super().execute(context))
class ShortCircuitOperator(PythonOperator, SkipMixin):
@@ -625,6 +623,23 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
+class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
+ """
+ A workflow can "branch" or follow a path after the execution of this task in a virtualenv.
+
+ It derives the PythonVirtualenvOperator and expects a Python function that returns
+ a single task_id or list of task_ids to follow. The task_id(s) returned
+ should point to a task directly downstream from {self}. All other "branches"
+ or directly downstream tasks are marked with a state of ``skipped`` so that
+ these paths can't move forward. The ``skipped`` states are propagated
+ downstream to allow for the DAG state to fill up and the DAG run's state
+ to be inferred.
+ """
+
+ def execute(self, context: Context) -> Any:
+ return self.do_branch(context, super().execute(context))
+
+
class ExternalPythonOperator(_BasePythonVirtualenvOperator):
"""
Run a function in a virtualenv that is not re-created.
@@ -792,7 +807,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
return None
-class BranchExternalPythonOperator(ExternalPythonOperator, SkipMixin):
+class BranchExternalPythonOperator(ExternalPythonOperator, BranchMixIn):
"""
A workflow can "branch" or follow a path after the execution of this task.
@@ -802,10 +817,7 @@ class BranchExternalPythonOperator(ExternalPythonOperator, SkipMixin):
"""
def execute(self, context: Context) -> Any:
- branch = super().execute(context)
- self.log.info("Branch callable return %s", branch)
- self.skip_all_except(context["ti"], branch)
- return branch
+ return self.do_branch(context, super().execute(context))
def get_current_context() -> Context:
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 85c2dc976f..d20949c618 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -42,6 +42,7 @@ from airflow.operators.empty import EmptyOperator
from airflow.operators.python import (
BranchExternalPythonOperator,
BranchPythonOperator,
+ BranchPythonVirtualenvOperator,
ExternalPythonOperator,
PythonOperator,
PythonVirtualenvOperator,
@@ -1121,19 +1122,12 @@ class TestExternalPythonOperator(BaseTestPythonVirtualenvOperator):
task._read_result(path=mock.Mock())
-class TestBranchExternalPythonOperator(BaseTestPythonVirtualenvOperator):
- opcls = BranchExternalPythonOperator
-
+class BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
@pytest.fixture(autouse=True)
def setup_tests(self):
self.branch_1 = EmptyOperator(task_id="branch_1")
self.branch_2 = EmptyOperator(task_id="branch_2")
- @staticmethod
- def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
- kwargs["python"] = sys.executable
- return kwargs
-
def test_with_args(self):
def f(a, b, c=False, d=False):
if a == 0 and b == 1 and c and not d:
@@ -1280,6 +1274,23 @@ class TestBranchExternalPythonOperator(BaseTestPythonVirtualenvOperator):
ti.run()
+class TestBranchPythonVirtualenvOperator(BaseTestBranchPythonVirtualenvOperator):
+ opcls = BranchPythonVirtualenvOperator
+
+ @staticmethod
+ def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+ return kwargs
+
+
+class TestBranchExternalPythonOperator(BaseTestBranchPythonVirtualenvOperator):
+ opcls = BranchExternalPythonOperator
+
+ @staticmethod
+ def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+ kwargs["python"] = sys.executable
+ return kwargs
+
+
class TestCurrentContext:
def test_current_context_no_context_raise(self):
with pytest.raises(AirflowException):