This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c76555930a Refactor setup/teardown decorator so they return
_TaskDecorator (#30342)
c76555930a is described below
commit c76555930aee9692d2a839b9c7b9e2220717b8a0
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Mar 28 19:15:07 2023 +0100
Refactor setup/teardown decorator so they return _TaskDecorator (#30342)
* Refactor setup/teardown decorator so they return _TaskDecorator
Prior to this, the setup/teardown decorators returned ordinary functions.
Returning _TaskDecorator, as @task does, will be useful for future work.
* fixup! Refactor setup/teardown decorator so they return _TaskDecorator
* Fix tests
* Update tests/decorators/test_setup_teardown.py
Co-authored-by: Daniel Standish
<[email protected]>
* Fix on_failure_fail_dagrun default
---------
Co-authored-by: Daniel Standish
<[email protected]>
Co-authored-by: Jed Cunningham <[email protected]>
---
airflow/decorators/base.py | 12 ++-
airflow/decorators/setup_teardown.py | 41 ++++------
airflow/models/baseoperator.py | 23 ++----
airflow/utils/setup_teardown.py | 55 -------------
airflow/utils/task_group.py | 5 --
tests/decorators/test_setup_teardown.py | 92 +++++++++++++++++----
tests/serialization/test_dag_serialization.py | 1 +
tests/utils/test_setup_teardown.py | 113 --------------------------
8 files changed, 112 insertions(+), 230 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index bc83c7786b..3ece688dd9 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -301,6 +301,9 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
decorator_name: str = attr.ib(repr=False, default="task")
_airflow_is_task_decorator: ClassVar[bool] = True
+ _is_setup: ClassVar[bool] = False
+ _is_teardown: ClassVar[bool] = False
+ _on_failure_fail_dagrun: ClassVar[bool] = False
@multiple_outputs.default
def _infer_multiple_outputs(self):
@@ -341,6 +344,9 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
multiple_outputs=self.multiple_outputs,
**self.kwargs,
)
+ op._is_setup = self._is_setup
+ op._is_teardown = self._is_teardown
+ op._on_failure_fail_dagrun = self._on_failure_fail_dagrun
op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst,
op.doc_yaml]
# Set the task's doc_md to the function's docstring if it exists and
no other doc* args are set.
if self.function.__doc__ and not any(op_doc_attrs):
@@ -473,7 +479,11 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn,
OperatorSubclass]:
- return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
+ result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
+ setattr(result, "_is_setup", self._is_setup)
+ setattr(result, "_is_teardown", self._is_teardown)
+ setattr(result, "_on_failure_fail_dagrun",
self._on_failure_fail_dagrun)
+ return result
@attr.define(kw_only=True, repr=False)
diff --git a/airflow/decorators/setup_teardown.py
b/airflow/decorators/setup_teardown.py
index 8480bde636..8fccdb9fe5 100644
--- a/airflow/decorators/setup_teardown.py
+++ b/airflow/decorators/setup_teardown.py
@@ -16,39 +16,34 @@
# under the License.
from __future__ import annotations
-import functools
import types
from typing import Callable
+from airflow import AirflowException
from airflow.decorators import python_task
-from airflow.utils.setup_teardown import SetupTeardownContext
+from airflow.decorators.task_group import _TaskGroupFactory
-def setup_task(python_callable: Callable) -> Callable:
+def setup_task(func: Callable) -> Callable:
# Using FunctionType here since _TaskDecorator is also a callable
- if isinstance(python_callable, types.FunctionType):
- python_callable = python_task(python_callable)
+ if isinstance(func, types.FunctionType):
+ func = python_task(func)
+ if isinstance(func, _TaskGroupFactory):
+ raise AirflowException("Task groups cannot be marked as setup or
teardown.")
+ func._is_setup = True # type: ignore[attr-defined]
+ return func
- @functools.wraps(python_callable)
- def wrapper(*args, **kwargs):
- with SetupTeardownContext.setup():
- return python_callable(*args, **kwargs)
- return wrapper
-
-
-def teardown_task(_func=None, *, on_failure_fail_dagrun: bool | None = None)
-> Callable:
- def teardown(python_callable: Callable) -> Callable:
+def teardown_task(_func=None, *, on_failure_fail_dagrun: bool = False) ->
Callable:
+ def teardown(func: Callable) -> Callable:
# Using FunctionType here since _TaskDecorator is also a callable
- if isinstance(python_callable, types.FunctionType):
- python_callable = python_task(python_callable)
-
- @functools.wraps(python_callable)
- def wrapper(*args, **kwargs) -> Callable:
- with
SetupTeardownContext.teardown(on_failure_fail_dagrun=on_failure_fail_dagrun):
- return python_callable(*args, **kwargs)
-
- return wrapper
+ if isinstance(func, types.FunctionType):
+ func = python_task(func)
+ if isinstance(func, _TaskGroupFactory):
+ raise AirflowException("Task groups cannot be marked as setup or
teardown.")
+ func._is_teardown = True # type: ignore[attr-defined]
+ func._on_failure_fail_dagrun = on_failure_fail_dagrun # type:
ignore[attr-defined]
+ return func
if _func is None:
return teardown
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d78152d990..557dc634d8 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -88,7 +88,6 @@ from airflow.utils.decorators import
fixup_decorator_warning_stack
from airflow.utils.helpers import validate_key
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -920,27 +919,19 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
)
self.template_fields = [self.template_fields]
- if SetupTeardownContext.is_setup:
- self._is_setup = True
- elif SetupTeardownContext.is_teardown:
- self._is_teardown = True
- if SetupTeardownContext.on_failure_fail_dagrun:
- self._on_failure_fail_dagrun = True
-
@classmethod
def as_setup(cls, *args, **kwargs):
- from airflow.utils.setup_teardown import SetupTeardownContext
-
- with SetupTeardownContext.setup():
- return cls(*args, **kwargs)
+ op = cls(*args, **kwargs)
+ op._is_setup = True
+ return op
@classmethod
def as_teardown(cls, *args, **kwargs):
- from airflow.utils.setup_teardown import SetupTeardownContext
-
on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
- with
SetupTeardownContext.teardown(on_failure_fail_dagrun=on_failure_fail_dagrun):
- return cls(*args, **kwargs)
+ op = cls(*args, **kwargs)
+ op._is_teardown = True
+ op._on_failure_fail_dagrun = on_failure_fail_dagrun
+ return op
def __eq__(self, other):
if type(self) is type(other):
diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py
deleted file mode 100644
index e421678e46..0000000000
--- a/airflow/utils/setup_teardown.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from contextlib import contextmanager
-
-from airflow.exceptions import AirflowException
-
-
-class SetupTeardownContext:
- """Track whether the next added task is a setup or teardown task"""
-
- is_setup: bool = False
- is_teardown: bool = False
- on_failure_fail_dagrun: bool = False
-
- @classmethod
- @contextmanager
- def setup(cls):
- if cls.is_setup or cls.is_teardown:
- raise AirflowException("You cannot mark a setup or teardown task
as setup or teardown again.")
-
- cls.is_setup = True
- try:
- yield
- finally:
- cls.is_setup = False
-
- @classmethod
- @contextmanager
- def teardown(cls, *, on_failure_fail_dagrun=False):
- if cls.is_setup or cls.is_teardown:
- raise AirflowException("You cannot mark a setup or teardown task
as setup or teardown again.")
-
- cls.is_teardown = True
- cls.on_failure_fail_dagrun = on_failure_fail_dagrun
- try:
- yield
- finally:
- cls.is_teardown = False
- cls.on_failure_fail_dagrun = False
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index b0611305d4..9b94a55952 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -38,7 +38,6 @@ from airflow.exceptions import (
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
-from airflow.utils.setup_teardown import SetupTeardownContext
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -158,10 +157,6 @@ class TaskGroup(DAGNode):
self.upstream_task_ids = set()
self.downstream_task_ids = set()
- if SetupTeardownContext.is_setup or SetupTeardownContext.is_teardown:
- # TODO: This might not be the ideal place to check this.
- raise AirflowException("Task groups cannot be marked as setup or
teardown.")
-
def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
if self._group_id is None:
return
diff --git a/tests/decorators/test_setup_teardown.py
b/tests/decorators/test_setup_teardown.py
index 38793c863b..151ba1d0d8 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -92,37 +92,40 @@ class TestSetupTearDownTask:
assert teardown_task._is_teardown
def test_setup_taskgroup_decorator(self, dag_maker):
- @setup
- @task_group
- def mygroup():
- @task
- def mytask():
- print("I am a setup task")
-
- mytask()
-
with dag_maker():
with pytest.raises(
expected_exception=AirflowException,
match="Task groups cannot be marked as setup or teardown.",
):
+
+ @setup
+ @task_group
+ def mygroup():
+ @task
+ def mytask():
+ print("I am a setup task")
+
+ mytask()
+
mygroup()
def test_teardown_taskgroup_decorator(self, dag_maker):
- @teardown
- @task_group
- def mygroup():
- @task
- def mytask():
- print("I am a teardown task")
-
- mytask()
with dag_maker():
with pytest.raises(
expected_exception=AirflowException,
match="Task groups cannot be marked as setup or teardown.",
):
+
+ @teardown
+ @task_group
+ def mygroup():
+ @task
+ def mytask():
+ print("I am a teardown task")
+
+ mytask()
+
mygroup()
@pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
@@ -153,3 +156,58 @@ class TestSetupTearDownTask:
assert teardown_task._is_teardown
assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
assert len(dag.task_group.children) == 1
+
+ def test_setup_task_can_be_overriden(self, dag_maker):
+ @setup
+ def mytask():
+ print("I am a setup task")
+
+ with dag_maker() as dag:
+ mytask.override(task_id="mytask2")()
+ assert len(dag.task_group.children) == 1
+ setup_task = dag.task_group.children["mytask2"]
+ assert setup_task._is_setup
+
+ def test_setup_teardown_mixed_up_in_a_dag(self, dag_maker):
+ @setup
+ def setuptask():
+ print("setup")
+
+ @setup
+ def setuptask2():
+ print("setup")
+
+ @teardown
+ def teardowntask():
+ print("teardown")
+
+ @teardown
+ def teardowntask2():
+ print("teardown")
+
+ @task()
+ def mytask():
+ print("mytask")
+
+ @task()
+ def mytask2():
+ print("mytask")
+
+ with dag_maker() as dag:
+ setuptask()
+ teardowntask()
+ setuptask2()
+ teardowntask2()
+ mytask()
+ mytask2()
+
+ assert len(dag.task_group.children) == 6
+ assert [x for x in dag.tasks if not x.downstream_list] # no deps have
been set
+ assert dag.task_group.children["setuptask"]._is_setup
+ assert dag.task_group.children["teardowntask"]._is_teardown
+ assert dag.task_group.children["setuptask2"]._is_setup
+ assert dag.task_group.children["teardowntask2"]._is_teardown
+ assert dag.task_group.children["mytask"]._is_setup is False
+ assert dag.task_group.children["mytask"]._is_teardown is False
+ assert dag.task_group.children["mytask2"]._is_setup is False
+ assert dag.task_group.children["mytask2"]._is_teardown is False
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index 96cf19aa25..e514caf623 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -570,6 +570,7 @@ class TestStringifiedDAGs:
"on_retry_callback",
# Checked separately
"resources",
+ "_on_failure_fail_dagrun",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, MappedOperator)
diff --git a/tests/utils/test_setup_teardown.py
b/tests/utils/test_setup_teardown.py
deleted file mode 100644
index 5907994bbc..0000000000
--- a/tests/utils/test_setup_teardown.py
+++ /dev/null
@@ -1,113 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.exceptions import AirflowException
-from airflow.utils.setup_teardown import SetupTeardownContext
-
-
-class TestSetupTearDownContext:
- def test_setup(self):
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- with SetupTeardownContext.setup():
- assert SetupTeardownContext.is_setup is True
- assert SetupTeardownContext.is_teardown is False
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_teardown(self):
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- with SetupTeardownContext.setup():
- assert SetupTeardownContext.is_setup is True
- assert SetupTeardownContext.is_teardown is False
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_setup_exception(self):
- """Ensure context is reset even if an exception happens"""
- with pytest.raises(Exception, match="Hello"):
- with SetupTeardownContext.setup():
- raise Exception("Hello")
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_teardown_exception(self):
- """Ensure context is reset even if an exception happens"""
- with pytest.raises(Exception, match="Hello"):
- with SetupTeardownContext.teardown():
- raise Exception("Hello")
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_setup_block_nested(self):
- with SetupTeardownContext.setup():
- with pytest.raises(
- AirflowException,
- match=("You cannot mark a setup or teardown task as setup or
teardown again."),
- ):
- with SetupTeardownContext.setup():
- raise Exception("This should not be reached")
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_teardown_block_nested(self):
- with SetupTeardownContext.teardown():
- with pytest.raises(
- AirflowException,
- match=("You cannot mark a setup or teardown task as setup or
teardown again."),
- ):
- with SetupTeardownContext.teardown():
- raise Exception("This should not be reached")
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_teardown_nested_in_setup_blocked(self):
- with SetupTeardownContext.setup():
- with pytest.raises(
- AirflowException,
- match=("You cannot mark a setup or teardown task as setup or
teardown again."),
- ):
- with SetupTeardownContext.teardown():
- raise Exception("This should not be reached")
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False
-
- def test_setup_nested_in_teardown_blocked(self):
- with SetupTeardownContext.teardown():
- with pytest.raises(
- AirflowException,
- match=("You cannot mark a setup or teardown task as setup or
teardown again."),
- ):
- with SetupTeardownContext.setup():
- raise Exception("This should not be reached")
-
- assert SetupTeardownContext.is_setup is False
- assert SetupTeardownContext.is_teardown is False