This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new df49ad179b Ensure setup/teardown work on a previously decorated function (#30216)
df49ad179b is described below

commit df49ad179bddcdb098b3eccbf9bb6361cfbafc36
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Mar 24 18:01:34 2023 +0100

    Ensure setup/teardown work on a previously decorated function (#30216)
    
    * Ensure setup/teardown work on a previously decorated function (e.g. task.docker)
    
    * Apply suggestions from code review
---
 .../cncf/kubernetes/python_kubernetes_script.py    | 30 +++++++++-----
 airflow/utils/decorators.py                        | 30 +++++++++-----
 tests/decorators/test_external_python.py           | 48 +++++++++++++++++++++-
 tests/decorators/test_python_virtualenv.py         | 48 +++++++++++++++++++++-
 .../cncf/kubernetes/decorators/test_kubernetes.py  | 46 ++++++++++++++++++++-
 tests/providers/docker/decorators/test_docker.py   | 45 +++++++++++++++++++-
 6 files changed, 221 insertions(+), 26 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
index 785daf6e56..fdf0336d09 100644
--- a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
@@ -39,19 +39,27 @@ def _balance_parens(after_decorator):
 
 def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
     """
-    Removed @kubernetes_task
+    Removes @task.kubernetes or similar as well as @setup and @teardown
 
-    :param python_source:
+    :param python_source: python source code
+    :param task_decorator_name: the task decorator name
     """
-    if task_decorator_name not in python_source:
-        return python_source
-    split = python_source.split(task_decorator_name)
-    before_decorator, after_decorator = split[0], split[1]
-    if after_decorator[0] == "(":
-        after_decorator = _balance_parens(after_decorator)
-    if after_decorator[0] == "\n":
-        after_decorator = after_decorator[1:]
-    return before_decorator + after_decorator
+
+    def _remove_task_decorator(py_source, decorator_name):
+        if decorator_name not in py_source:
+            return py_source
+        # maxsplit=1: strip only the first match and keep the rest verbatim
+        before_decorator, after_decorator = py_source.split(decorator_name, 1)
+        if after_decorator.startswith("("):
+            after_decorator = _balance_parens(after_decorator)
+        if after_decorator.startswith("\n"):
+            after_decorator = after_decorator[1:]
+        return before_decorator + after_decorator
+
+    decorators = ["@setup", "@teardown", task_decorator_name]
+    for decorator in decorators:
+        python_source = _remove_task_decorator(python_source, decorator)
+    return python_source
 
 
 def write_python_script(
diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
index 35dc23e85a..645e65a637 100644
--- a/airflow/utils/decorators.py
+++ b/airflow/utils/decorators.py
@@ -58,19 +58,27 @@ def apply_defaults(func: T) -> T:
 
 def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
     """
-    Removed @task.
+    Removes @task or similar decorators as well as @setup and @teardown
 
-    :param python_source:
+    :param python_source: The python source code
+    :param task_decorator_name: the decorator name
     """
-    if task_decorator_name not in python_source:
-        return python_source
-    split = python_source.split(task_decorator_name)
-    before_decorator, after_decorator = split[0], split[1]
-    if after_decorator[0] == "(":
-        after_decorator = _balance_parens(after_decorator)
-    if after_decorator[0] == "\n":
-        after_decorator = after_decorator[1:]
-    return before_decorator + after_decorator
+
+    def _remove_task_decorator(py_source, decorator_name):
+        if decorator_name not in py_source:
+            return py_source
+        # maxsplit=1: strip only the first match and keep the rest verbatim
+        before_decorator, after_decorator = py_source.split(decorator_name, 1)
+        if after_decorator.startswith("("):
+            after_decorator = _balance_parens(after_decorator)
+        if after_decorator.startswith("\n"):
+            after_decorator = after_decorator[1:]
+        return before_decorator + after_decorator
+
+    decorators = ["@setup", "@teardown", task_decorator_name]
+    for decorator in decorators:
+        python_source = _remove_task_decorator(python_source, decorator)
+    return python_source
 
 
 def _balance_parens(after_decorator):
diff --git a/tests/decorators/test_external_python.py b/tests/decorators/test_external_python.py
index a022a37d7f..84c46127c1 100644
--- a/tests/decorators/test_external_python.py
+++ b/tests/decorators/test_external_python.py
@@ -27,7 +27,7 @@ from tempfile import TemporaryDirectory
 
 import pytest
 
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -125,3 +125,49 @@ class TestExternalPythonDecorator:
             ret = f(datetime.datetime.utcnow())
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_marking_external_python_task_as_setup(self, dag_maker, venv_python):
+        @setup
+        @task.external_python(python=venv_python)
+        def f():
+            return 1
+
+        with dag_maker() as dag:
+            ret = f()
+
+        assert len(dag.task_group.children) == 1
+        setup_task = dag.task_group.children["f"]
+        assert setup_task._is_setup
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_marking_external_python_task_as_teardown(self, dag_maker, venv_python):
+        @teardown
+        @task.external_python(python=venv_python)
+        def f():
+            return 1
+
+        with dag_maker() as dag:
+            ret = f()
+
+        assert len(dag.task_group.children) == 1
+        teardown_task = dag.task_group.children["f"]
+        assert teardown_task._is_teardown
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_marking_external_python_task_as_teardown_with_on_failure_fail(
+        self, dag_maker, on_failure_fail_dagrun, venv_python
+    ):
+        @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
+        @task.external_python(python=venv_python)
+        def f():
+            return 1
+
+        with dag_maker() as dag:
+            ret = f()
+
+        assert len(dag.task_group.children) == 1
+        teardown_task = dag.task_group.children["f"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py
index 88121c5db3..ec37c3586e 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -23,7 +23,7 @@ from subprocess import CalledProcessError
 
 import pytest
 
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -176,3 +176,49 @@ class TestPythonVirtualenvDecorator:
             ret = f(datetime.datetime.utcnow())
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_marking_virtualenv_python_task_as_setup(self, dag_maker):
+        @setup
+        @task.virtualenv
+        def f():
+            return 1
+
+        with dag_maker() as dag:
+            ret = f()
+
+        assert len(dag.task_group.children) == 1
+        setup_task = dag.task_group.children["f"]
+        assert setup_task._is_setup
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_marking_virtualenv_python_task_as_teardown(self, dag_maker):
+        @teardown
+        @task.virtualenv
+        def f():
+            return 1
+
+        with dag_maker() as dag:
+            ret = f()
+
+        assert len(dag.task_group.children) == 1
+        teardown_task = dag.task_group.children["f"]
+        assert teardown_task._is_teardown
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_marking_virtualenv_python_task_as_teardown_with_on_failure_fail(
+        self, dag_maker, on_failure_fail_dagrun
+    ):
+        @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
+        @task.virtualenv
+        def f():
+            return 1
+
+        with dag_maker() as dag:
+            ret = f()
+
+        assert len(dag.task_group.children) == 1
+        teardown_task = dag.task_group.children["f"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 853d056e23..dc1b9b2bc8 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -22,7 +22,7 @@ from unittest import mock
 
 import pytest
 
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2021, 9, 1)
@@ -159,3 +159,47 @@ def test_kubernetes_with_input_output(
     # Second container is xcom image
     assert containers[1].image == XCOM_IMAGE
     assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
+
+
+def test_kubernetes_with_marked_as_setup(
+    dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
+) -> None:
+    with dag_maker(session=session) as dag:
+
+        @setup
+        @task.kubernetes(
+            image="python:3.10-slim-buster",
+            in_cluster=False,
+            cluster_context="default",
+            config_file="/tmp/fake_file",
+        )
+        def f():
+            return {"key1": "value1", "key2": "value2"}
+
+        f()
+
+    assert len(dag.task_group.children) == 1
+    setup_task = dag.task_group.children["f"]
+    assert setup_task._is_setup
+
+
+def test_kubernetes_with_marked_as_teardown(
+    dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
+) -> None:
+    with dag_maker(session=session) as dag:
+
+        @teardown
+        @task.kubernetes(
+            image="python:3.10-slim-buster",
+            in_cluster=False,
+            cluster_context="default",
+            config_file="/tmp/fake_file",
+        )
+        def f():
+            return {"key1": "value1", "key2": "value2"}
+
+        f()
+
+    assert len(dag.task_group.children) == 1
+    teardown_task = dag.task_group.children["f"]
+    assert teardown_task._is_teardown
diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py
index 2ed4dc512c..8d6de4d220 100644
--- a/tests/providers/docker/decorators/test_docker.py
+++ b/tests/providers/docker/decorators/test_docker.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import pytest
 
-from airflow.decorators import task
+from airflow.decorators import setup, task, teardown
 from airflow.exceptions import AirflowException
 from airflow.models import TaskInstance
 from airflow.models.dag import DAG
@@ -144,3 +144,46 @@ class TestDockerDecorator:
             ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
             ti = dr.get_task_instances()[0]
             assert ti.state == expected_state
+
+    def test_setup_decorator_with_decorated_docker_task(self, dag_maker):
+        @setup
+        @task.docker(image="python:3.9-slim", auto_remove="force")
+        def f():
+            pass
+
+        with dag_maker() as dag:
+            f()
+
+        assert len(dag.task_group.children) == 1
+        setup_task = dag.task_group.children["f"]
+        assert setup_task._is_setup
+
+    def test_teardown_decorator_with_decorated_docker_task(self, dag_maker):
+        @teardown
+        @task.docker(image="python:3.9-slim", auto_remove="force")
+        def f():
+            pass
+
+        with dag_maker() as dag:
+            f()
+
+        assert len(dag.task_group.children) == 1
+        teardown_task = dag.task_group.children["f"]
+        assert teardown_task._is_teardown
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_teardown_decorator_with_decorated_docker_task_and_on_failure_fail_arg(
+        self, dag_maker, on_failure_fail_dagrun
+    ):
+        @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
+        @task.docker(image="python:3.9-slim", auto_remove="force")
+        def f():
+            pass
+
+        with dag_maker() as dag:
+            f()
+
+        assert len(dag.task_group.children) == 1
+        teardown_task = dag.task_group.children["f"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun

Reply via email to