This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new df49ad179b Ensure setup/teardown work on a previously decorated
function (#30216)
df49ad179b is described below
commit df49ad179bddcdb098b3eccbf9bb6361cfbafc36
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Mar 24 18:01:34 2023 +0100
Ensure setup/teardown work on a previously decorated function (#30216)
* Ensure setup/teardown work on a previously decorated function (e.g.
task.docker)
* Apply suggestions from code review
---
.../cncf/kubernetes/python_kubernetes_script.py | 30 +++++++++-----
airflow/utils/decorators.py | 30 +++++++++-----
tests/decorators/test_external_python.py | 48 +++++++++++++++++++++-
tests/decorators/test_python_virtualenv.py | 48 +++++++++++++++++++++-
.../cncf/kubernetes/decorators/test_kubernetes.py | 46 ++++++++++++++++++++-
tests/providers/docker/decorators/test_docker.py | 45 +++++++++++++++++++-
6 files changed, 221 insertions(+), 26 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
index 785daf6e56..fdf0336d09 100644
--- a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
@@ -39,19 +39,27 @@ def _balance_parens(after_decorator):
def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
    """
    Remove @task.kubernetes or similar, as well as @setup and @teardown.

    For each decorator text only the first occurrence is removed, followed by
    its parenthesized arguments (if the next character is ``(``) and one
    trailing newline. Everything after that first occurrence is preserved,
    including later occurrences of the same text (e.g. inside string literals).

    :param python_source: python source code
    :param task_decorator_name: the task decorator name
    :return: the python source with the decorators removed
    """

    def _remove_task_decorator(py_source: str, decorator_name: str) -> str:
        if decorator_name not in py_source:
            return py_source
        # partition() splits on the first occurrence only; split() + split[1]
        # would silently drop everything after a second occurrence.
        before_decorator, _, after_decorator = py_source.partition(decorator_name)
        # Slicing ([:1]) instead of indexing avoids IndexError when the
        # decorator is the last text in the source.
        if after_decorator[:1] == "(":
            # _balance_parens is defined elsewhere in this module; presumably it
            # drops the decorator's call arguments up to the matching ")".
            after_decorator = _balance_parens(after_decorator)
        if after_decorator[:1] == "\n":
            after_decorator = after_decorator[1:]
        return before_decorator + after_decorator

    decorators = ["@setup", "@teardown", task_decorator_name]
    for decorator in decorators:
        python_source = _remove_task_decorator(python_source, decorator)
    return python_source
def write_python_script(
diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
index 35dc23e85a..645e65a637 100644
--- a/airflow/utils/decorators.py
+++ b/airflow/utils/decorators.py
@@ -58,19 +58,27 @@ def apply_defaults(func: T) -> T:
def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
    """
    Remove @task or similar decorators, as well as @setup and @teardown.

    For each decorator text only the first occurrence is removed, followed by
    its parenthesized arguments (if the next character is ``(``) and one
    trailing newline. Everything after that first occurrence is preserved,
    including later occurrences of the same text (e.g. inside string literals).

    :param python_source: The python source code
    :param task_decorator_name: the decorator name
    :return: the python source with the decorators removed
    """

    def _remove_task_decorator(py_source: str, decorator_name: str) -> str:
        if decorator_name not in py_source:
            return py_source
        # partition() splits on the first occurrence only; split() + split[1]
        # would silently drop everything after a second occurrence.
        before_decorator, _, after_decorator = py_source.partition(decorator_name)
        # Slicing ([:1]) instead of indexing avoids IndexError when the
        # decorator is the last text in the source.
        if after_decorator[:1] == "(":
            # _balance_parens is defined later in this module; presumably it
            # drops the decorator's call arguments up to the matching ")".
            after_decorator = _balance_parens(after_decorator)
        if after_decorator[:1] == "\n":
            after_decorator = after_decorator[1:]
        return before_decorator + after_decorator

    decorators = ["@setup", "@teardown", task_decorator_name]
    for decorator in decorators:
        python_source = _remove_task_decorator(python_source, decorator)
    return python_source
def _balance_parens(after_decorator):
diff --git a/tests/decorators/test_external_python.py
b/tests/decorators/test_external_python.py
index a022a37d7f..84c46127c1 100644
--- a/tests/decorators/test_external_python.py
+++ b/tests/decorators/test_external_python.py
@@ -27,7 +27,7 @@ from tempfile import TemporaryDirectory
import pytest
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -125,3 +125,49 @@ class TestExternalPythonDecorator:
ret = f(datetime.datetime.utcnow())
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
def test_marking_external_python_task_as_setup(self, dag_maker, venv_python):
    """An external_python task wrapped by @setup is flagged as setup and still runs."""

    @setup
    @task.external_python(python=venv_python)
    def f():
        return 1

    with dag_maker() as dag:
        xcom_arg = f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_setup
    xcom_arg.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
def test_marking_external_python_task_as_teardown(self, dag_maker, venv_python):
    """An external_python task wrapped by @teardown is flagged as teardown and still runs."""

    @teardown
    @task.external_python(python=venv_python)
    def f():
        return 1

    with dag_maker() as dag:
        xcom_arg = f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_teardown
    xcom_arg.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
@pytest.mark.parametrize("on_failure_fail_dagrun", (True, False))
def test_marking_external_python_task_as_teardown_with_on_failure_fail(
    self, dag_maker, on_failure_fail_dagrun, venv_python
):
    """@teardown(on_failure_fail_dagrun=...) on an external_python task records the flag."""

    @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
    @task.external_python(python=venv_python)
    def f():
        return 1

    with dag_maker() as dag:
        xcom_arg = f()

    children = dag.task_group.children
    assert len(children) == 1
    teardown_op = children["f"]
    assert teardown_op._is_teardown
    assert teardown_op._on_failure_fail_dagrun is on_failure_fail_dagrun
    xcom_arg.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/decorators/test_python_virtualenv.py
b/tests/decorators/test_python_virtualenv.py
index 88121c5db3..ec37c3586e 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -23,7 +23,7 @@ from subprocess import CalledProcessError
import pytest
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -176,3 +176,49 @@ class TestPythonVirtualenvDecorator:
ret = f(datetime.datetime.utcnow())
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
def test_marking_virtualenv_python_task_as_setup(self, dag_maker):
    """A virtualenv task wrapped by @setup is flagged as setup and still runs."""

    @setup
    @task.virtualenv
    def f():
        return 1

    with dag_maker() as dag:
        xcom_arg = f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_setup
    xcom_arg.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
def test_marking_virtualenv_python_task_as_teardown(self, dag_maker):
    """A virtualenv task wrapped by @teardown is flagged as teardown and still runs."""

    @teardown
    @task.virtualenv
    def f():
        return 1

    with dag_maker() as dag:
        xcom_arg = f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_teardown
    xcom_arg.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
@pytest.mark.parametrize("on_failure_fail_dagrun", (True, False))
def test_marking_virtualenv_python_task_as_teardown_with_on_failure_fail(
    self, dag_maker, on_failure_fail_dagrun
):
    """@teardown(on_failure_fail_dagrun=...) on a virtualenv task records the flag."""

    @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
    @task.virtualenv
    def f():
        return 1

    with dag_maker() as dag:
        xcom_arg = f()

    children = dag.task_group.children
    assert len(children) == 1
    teardown_op = children["f"]
    assert teardown_op._is_teardown
    assert teardown_op._on_failure_fail_dagrun is on_failure_fail_dagrun
    xcom_arg.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 853d056e23..dc1b9b2bc8 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -22,7 +22,7 @@ from unittest import mock
import pytest
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2021, 9, 1)
@@ -159,3 +159,47 @@ def test_kubernetes_with_input_output(
# Second container is xcom image
assert containers[1].image == XCOM_IMAGE
assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
+
+
def test_kubernetes_with_marked_as_setup(
    dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
) -> None:
    """A task.kubernetes task wrapped by @setup is registered as a setup task."""
    with dag_maker(session=session) as dag:

        @setup
        @task.kubernetes(
            image="python:3.10-slim-buster",
            in_cluster=False,
            cluster_context="default",
            config_file="/tmp/fake_file",
        )
        def f():
            return {"key1": "value1", "key2": "value2"}

        f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_setup
+
+
def test_kubernetes_with_marked_as_teardown(
    dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
) -> None:
    """A task.kubernetes task wrapped by @teardown is registered as a teardown task."""
    with dag_maker(session=session) as dag:

        @teardown
        @task.kubernetes(
            image="python:3.10-slim-buster",
            in_cluster=False,
            cluster_context="default",
            config_file="/tmp/fake_file",
        )
        def f():
            return {"key1": "value1", "key2": "value2"}

        f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_teardown
diff --git a/tests/providers/docker/decorators/test_docker.py
b/tests/providers/docker/decorators/test_docker.py
index 2ed4dc512c..8d6de4d220 100644
--- a/tests/providers/docker/decorators/test_docker.py
+++ b/tests/providers/docker/decorators/test_docker.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import pytest
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.models.dag import DAG
@@ -144,3 +144,46 @@ class TestDockerDecorator:
ret.operator.run(start_date=dr.execution_date,
end_date=dr.execution_date)
ti = dr.get_task_instances()[0]
assert ti.state == expected_state
+
def test_setup_decorator_with_decorated_docker_task(self, dag_maker):
    """A task.docker task wrapped by @setup is registered as a setup task."""

    @setup
    @task.docker(image="python:3.9-slim", auto_remove="force")
    def f():
        pass

    with dag_maker() as dag:
        f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_setup
+
def test_teardown_decorator_with_decorated_docker_task(self, dag_maker):
    """A task.docker task wrapped by @teardown is registered as a teardown task."""

    @teardown
    @task.docker(image="python:3.9-slim", auto_remove="force")
    def f():
        pass

    with dag_maker() as dag:
        f()

    children = dag.task_group.children
    assert len(children) == 1
    assert children["f"]._is_teardown
+
@pytest.mark.parametrize("on_failure_fail_dagrun", (True, False))
def test_teardown_decorator_with_decorated_docker_task_and_on_failure_fail_arg(
    self, dag_maker, on_failure_fail_dagrun
):
    """@teardown(on_failure_fail_dagrun=...) on a task.docker task records the flag."""

    @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
    @task.docker(image="python:3.9-slim", auto_remove="force")
    def f():
        pass

    with dag_maker() as dag:
        f()

    children = dag.task_group.children
    assert len(children) == 1
    teardown_op = children["f"]
    assert teardown_op._is_teardown
    assert teardown_op._on_failure_fail_dagrun is on_failure_fail_dagrun