This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 444016cc59b635cfa937238b3e5ba1e86d6f0f28
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Oct 24 16:25:41 2024 +0100

    fix more tests [skip-ci]
---
 .pre-commit-config.yaml                         |  1 +
 airflow/models/dag.py                           | 25 ++-----
 airflow/models/taskmixin.py                     | 24 +-----
 .../src/airflow/sdk/definitions/baseoperator.py | 15 ++--
 task_sdk/src/airflow/sdk/definitions/dag.py     | 86 +++++++++++++++-------
 task_sdk/src/airflow/sdk/definitions/node.py    |  7 ++
 task_sdk/tests/defintions/test_dag.py           | 34 +++++++++
 tests/models/test_dag.py                        | 31 +++-----
 8 files changed, 128 insertions(+), 95 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7d9cd6b2292..ad4b2529b86 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1190,6 +1190,7 @@ repos:
           ^providers/src/airflow/providers/ |
           ^(providers/)?tests/ |
           task_sdk/src/airflow/sdk/definitions/dag.py$ |
+          task_sdk/src/airflow/sdk/definitions/node.py$ |
           ^dev/.*\.py$ |
           ^scripts/.*\.py$ |
           ^docker_tests/.*$ |
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index fc0928ddcc1..d133eb43e68 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -298,7 +298,7 @@ if TYPE_CHECKING:
 else:

     def dag(dag_id: str = "", **kwargs):
-        return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3)
+        return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3, **kwargs)


 @functools.total_ordering
@@ -419,15 +419,14 @@ class DAG(TaskSDKDag, LoggingMixin):
     partial: bool = False
     last_loaded: datetime | None = attrs.field(factory=timezone.utcnow)

-    on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
-    on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
-
-    has_on_success_callback: bool = attrs.field(init=False)
-    has_on_failure_callback: bool = attrs.field(init=False)
-
     default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower()
     orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation")

+    # this will only be set at serialization time
+    # its only use is for determining the relative fileloc based only on the serialized dag
+    _processor_dags_folder: str | None = attrs.field(init=False, default=None)
+
     # Override the default from parent class to use config
     max_consecutive_failed_dag_runs: int = attrs.field()

@@ -435,17 +434,9 @@ class DAG(TaskSDKDag, LoggingMixin):
     @max_consecutive_failed_dag_runs.default
     def _max_consecutive_failed_dag_runs_default(self):
         return airflow_conf.getint("core", "max_consecutive_failed_dag_runs_per_dag")

-    # this will only be set at serialization time
-    # it's only use is for determining the relative fileloc based only on the serialize dag
-    _processor_dags_folder: str | None = attrs.field(init=False, default=None)
-
-    @has_on_success_callback.default
-    def _has_on_success_callback(self) -> bool:
-        return self.on_success_callback is not None
-
-    @has_on_failure_callback.default
-    def _has_on_failure_callback(self) -> bool:
-        return self.on_failure_callback is not None
+    def validate(self):
+        super().validate()
+        self.validate_executor_field()

     def validate_executor_field(self):
         for task in self.tasks:
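Worth calling out in the airflow/models/dag.py hunk above: the dag() shim previously dropped every keyword argument on the floor. A minimal sketch of the failure mode, with decorate as a hypothetical stand-in for task_sdk_dag_decorator:

# Minimal reproduction of the kwargs-swallowing bug fixed above.
# `decorate` is a hypothetical stand-in for task_sdk_dag_decorator.
def decorate(dag_id: str = "", **kwargs):
    return {"dag_id": dag_id, **kwargs}

def dag_buggy(dag_id: str = "", **kwargs):
    # Bug: **kwargs is accepted but never forwarded, so schedule=... etc. is lost.
    return decorate(dag_id)

def dag_fixed(dag_id: str = "", **kwargs):
    return decorate(dag_id, **kwargs)

assert dag_buggy("d", schedule=None) == {"dag_id": "d"}  # schedule silently dropped
assert dag_fixed("d", schedule=None) == {"dag_id": "d", "schedule": None}

The moved has_on_* flags and the callbacks they mirror reappear below in the Task SDK DAG class.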
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index c3ab50e2e0b..fa76a3815cb 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -16,33 +16,13 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Iterable
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
-    from airflow.models.operator import Operator
-    from airflow.serialization.enums import DagAttributeTypes
     from airflow.typing_compat import TypeAlias

 import airflow.sdk.definitions.mixins
 import airflow.sdk.definitions.node

 DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-
-
-class DAGNode(airflow.sdk.definitions.node.DAGNode):
-    """
-    A base class for a node in the graph of a workflow.
-
-    A node may be an Operator or a Task Group, either mapped or unmapped.
-    """
-
-    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
-        """Get list of the direct relatives to the current task, upstream or downstream."""
-        if upstream:
-            return self.upstream_list
-        else:
-            return self.downstream_list
-
-    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
-        """Serialize a task group's content; used by TaskGroupSerialization."""
-        raise NotImplementedError()
+DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index d8e07c44b71..5796a6bbc1f 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -57,6 +57,7 @@ from airflow.task.priority_strategy import (
     validate_and_load_priority_weight_strategy,
 )
 from airflow.utils import timezone
+from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import AttributeRemoved

@@ -854,11 +855,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
             )
             self.template_fields = [self.template_fields]

-        self._is_setup = False
-        self._is_teardown = False
-        # TODO: Task-SDK
-        # if SetupTeardownContext.active:
-        #     SetupTeardownContext.update_context_map(self)
+        self.is_setup = False
+        self.is_teardown = False
+
+        if SetupTeardownContext.active:
+            SetupTeardownContext.update_context_map(self)

         validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)

@@ -944,7 +945,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         result = cls.__new__(cls)
         memo[id(self)] = result

-        shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
+        shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs

         for k, v in self.__dict__.items():
             if k not in shallow_copy:
@@ -1173,8 +1174,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         "_BaseOperator__instantiated",
         "_BaseOperator__init_kwargs",
         "_BaseOperator__from_mapped",
-        "_is_setup",
-        "_is_teardown",
         "_on_failure_fail_dagrun",
     } | {  # Class level defaults need to be added to this list
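Two details in the baseoperator.py hunk deserve a note: the SetupTeardownContext registration is un-commented, and __deepcopy__ now coerces shallow_copy_attrs with tuple(), presumably because subclasses may declare it as a list and list + tuple raises TypeError. A self-contained sketch of the shallow/deep split (class and attribute names simplified, not Airflow's):

import copy
import threading

class Node:
    # One attr list is a tuple, the other may be a list on subclasses;
    # tuple() coercion makes the concatenation type-safe.
    _base_shallow_copy_attrs: tuple[str, ...] = ("_log",)
    shallow_copy_attrs: list[str] = ["session"]

    def __init__(self, session, payload):
        self._log = None
        self.session = session
        self.payload = payload

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        shallow = tuple(cls.shallow_copy_attrs) + cls._base_shallow_copy_attrs
        for k, v in self.__dict__.items():
            # Deep-copy everything except the attrs flagged for shallow copy.
            result.__dict__[k] = v if k in shallow else copy.deepcopy(v, memo)
        return result

n = Node(session=threading.Lock(), payload={"a": [1, 2]})
m = copy.deepcopy(n)
assert m.session is n.session        # unpicklable lock is shared, not copied
assert m.payload is not n.payload    # regular data is deep-copied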
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index e784639f8c9..111f51ce855 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -25,7 +25,7 @@ import os
 import sys
 import weakref
 from collections import abc
-from collections.abc import Collection, Iterable, Iterator, MutableSet
+from collections.abc import Collection, Iterable, MutableSet
 from datetime import datetime, timedelta
 from inspect import signature
 from re import Pattern
@@ -177,6 +177,20 @@ def _convert_access_control(value, self_: DAG):
     return value


+def _convert_doc_md(doc_md: str | None) -> str | None:
+    if doc_md is None:
+        return doc_md
+
+    if doc_md.endswith(".md"):
+        try:
+            with open(doc_md) as fh:
+                return fh.read()
+        except FileNotFoundError:
+            return doc_md
+
+    return doc_md
+
+
 def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]):
     i = iter(fields)
     f = next(i)
@@ -351,16 +365,17 @@ class DAG:
     )
     # sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None
     catchup: bool = attrs.field(default=True, converter=bool)
-    # on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
-    # on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
-    doc_md: str | None = None
+    on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
+    on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
+    doc_md: str | None = attrs.field(default=None, converter=_convert_doc_md)
     params: ParamsDict = attrs.field(
         # mypy doesn't really like passing the Converter object
         default=None,
         converter=attrs.Converter(_convert_params, takes_self=True),  # type: ignore[misc, call-overload]
     )
     access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = attrs.field(
-        default=None, converter=attrs.Converter(_convert_access_control, takes_self=True)
+        default=None,
+        converter=attrs.Converter(_convert_access_control, takes_self=True),  # type: ignore[misc, call-overload]
     )
     is_paused_upon_creation: bool | None = None
     jinja_environment_kwargs: dict | None = None
@@ -368,7 +383,7 @@ class DAG:
     tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags)
     owner_links: dict[str, str] = attrs.field(factory=dict)
     auto_register: bool = attrs.field(default=True, converter=bool)
-    fail_stop: bool = attrs.field(default=True, converter=bool)
+    fail_stop: bool = attrs.field(default=False, converter=bool)
     dag_display_name: str = attrs.field(validator=attrs.validators.instance_of(str))

     task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
@@ -380,6 +395,9 @@ class DAG:

     edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)

+    has_on_success_callback: bool = attrs.field(init=False)
+    has_on_failure_callback: bool = attrs.field(init=False)
+
     def __attrs_post_init__(self):
         from airflow.utils import timezone

@@ -481,6 +499,23 @@ class DAG:
         if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
             raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} characters")

+    @max_active_runs.validator
+    def _validate_max_active_runs(self, _, max_active_runs):
+        if self.timetable.active_runs_limit is not None:
+            if self.timetable.active_runs_limit < self.max_active_runs:
+                raise ValueError(
+                    f"Invalid max_active_runs: {type(self.timetable).__name__} "
+                    f"requires max_active_runs <= {self.timetable.active_runs_limit}"
+                )
+
+    @has_on_success_callback.default
+    def _has_on_success_callback(self) -> bool:
+        return self.on_success_callback is not None
+
+    @has_on_failure_callback.default
+    def _has_on_failure_callback(self) -> bool:
+        return self.on_failure_callback is not None
+
     def __repr__(self):
         return f"<DAG: {self.dag_id}>"

@@ -522,18 +557,6 @@ class DAG:

         _ = DagContext.pop()

-    def get_doc_md(self, doc_md: str | None) -> str | None:
-        if doc_md is None:
-            return doc_md
-
-        if doc_md.endswith(".md"):
-            try:
-                return open(doc_md).read()
-            except FileNotFoundError:
-                return doc_md
-
-        return doc_md
-
     def validate(self):
         """
         Validate the DAG has a coherent setup.
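The get_doc_md() method becomes the _convert_doc_md attrs converter, so a *.md path is inlined once when the field is set rather than on every access. A standalone sketch of the converter pattern (the Doc class is hypothetical):

import attrs

def _read_md(value: str | None) -> str | None:
    # If the value looks like a path to a markdown file, inline its contents;
    # otherwise (or if the file is missing) keep the raw string.
    if value is not None and value.endswith(".md"):
        try:
            with open(value) as fh:
                return fh.read()
        except FileNotFoundError:
            pass
    return value

@attrs.define
class Doc:
    doc_md: str | None = attrs.field(default=None, converter=_read_md)

assert Doc("hello").doc_md == "hello"            # plain string passes through
assert Doc("missing.md").doc_md == "missing.md"  # missing file: keep the path (assuming no such file)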
@@ -543,6 +566,10 @@ class DAG:
         self.timetable.validate()
         self.validate_setup_teardown()

+        # We validate owner links on set, but since it's a dict it could be mutated without calling the
+        # setter. Validate again here
+        self._validate_owner_links(None, self.owner_links)
+
     def validate_setup_teardown(self):
         """
         Validate that setup and teardown tasks are configured properly.
@@ -966,20 +993,23 @@ class DAG:
         """
         self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info

-    def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
-        """
-        Parse a given link, and verifies if it's a valid URL, or a 'mailto' link.
+    @owner_links.validator
+    def _validate_owner_links(self, _, owner_links):
+        wrong_links = {}

-        Returns an iterator of invalid (owner, link) pairs.
-        """
-        for owner, link in self.owner_links.items():
+        for owner, link in owner_links.items():
             result = urlsplit(link)
             if result.scheme == "mailto":
                 # netloc is not existing for 'mailto' link, so we are checking that the path is parsed
                 if not result.path:
-                    yield result.path, link
+                    wrong_links[result.path] = link
             elif not result.scheme or not result.netloc:
-                yield owner, link
+                wrong_links[owner] = link
+        if wrong_links:
+            raise ValueError(
+                "Wrong link format was used for the owner. Use a valid link \n"
+                f"Bad formatted links are: {wrong_links}"
+            )


 if TYPE_CHECKING:
@@ -1003,8 +1033,8 @@ if TYPE_CHECKING:
         dagrun_timeout: timedelta | None = None,
         # sla_miss_callback: Any = None,
         catchup: bool = ...,
-        # on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
-        # on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
+        on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
+        on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
         doc_md: str | None = None,
         params: ParamsDict | None = None,
         access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None,
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py
index 7b877bbaf48..b3b519a07f0 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -211,6 +211,13 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
         else:
             return self.downstream_task_ids

+    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
+        """Get list of the direct relatives to the current task, upstream or downstream."""
+        if upstream:
+            return self.upstream_list
+        else:
+            return self.downstream_list
+
     def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
         """Serialize a task group's content; used by TaskGroupSerialization."""
         raise NotImplementedError()
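iter_invalid_owner_links becomes an attrs validator that raises at construction time, and validate() re-runs it because the dict can be mutated in place afterwards. A self-contained sketch of that validate-then-revalidate pattern (names hypothetical, not Airflow's):

from urllib.parse import urlsplit
import attrs

def _is_valid(link: str) -> bool:
    r = urlsplit(link)
    if r.scheme == "mailto":
        return bool(r.path)  # mailto links carry the address in .path, not .netloc
    return bool(r.scheme and r.netloc)

def _check_links(instance, attribute, links: dict[str, str]) -> None:
    bad = {owner: link for owner, link in links.items() if not _is_valid(link)}
    if bad:
        raise ValueError(f"Wrong link format: {bad}")

@attrs.define
class Owners:
    links: dict[str, str] = attrs.field(factory=dict, validator=_check_links)

    def validate(self) -> None:
        # attrs validators only run on assignment; re-run manually because
        # the dict itself can be mutated without triggering the validator.
        _check_links(self, None, self.links)

o = Owners({"a": "https://example.com"})  # ok at construction
o.links["b"] = "not-a-link"               # mutation bypasses the validator...
try:
    o.validate()                          # ...but validate() catches it
except ValueError as e:
    print(e)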
diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py
index c3b3bbdce4d..b2481a49b6a 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -248,6 +248,39 @@ class TestDag:
         # Make sure we don't affect the original!
         assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids

+    def test_dag_owner_links(self):
+        dag = DAG(
+            "dag",
+            schedule=None,
+            start_date=DEFAULT_DATE,
+            owner_links={"owner1": "https://mylink.com", "owner2": "mailto:[email protected]"},
+        )
+
+        assert dag.owner_links == {"owner1": "https://mylink.com", "owner2": "mailto:[email protected]"}
+
+        # Check wrong formatted owner link
+        with pytest.raises(ValueError, match="Wrong link format"):
+            DAG("dag", schedule=None, start_date=DEFAULT_DATE, owner_links={"owner1": "my-bad-link"})
+
+        dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE)
+        dag.owner_links["owner1"] = "my-bad-link"
+        with pytest.raises(ValueError, match="Wrong link format"):
+            dag.validate()
+
+    def test_continuous_schedule_limits_max_active_runs(self):
+        from airflow.timetables.simple import ContinuousTimetable
+
+        dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=1)
+        assert isinstance(dag.timetable, ContinuousTimetable)
+        assert dag.max_active_runs == 1
+
+        dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=0)
+        assert isinstance(dag.timetable, ContinuousTimetable)
+        assert dag.max_active_runs == 0
+
+        with pytest.raises(ValueError, match="ContinuousTimetable requires max_active_runs <= 1"):
+            dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=25)
+

 # Test some of the arg validation. This is not all the validations we perform, just some of them.
 @pytest.mark.parametrize(
@@ -255,6 +288,7 @@
     [
         pytest.param("max_consecutive_failed_dag_runs", "not_an_int", id="max_consecutive_failed_dag_runs"),
         pytest.param("dagrun_timeout", "not_an_int", id="dagrun_timeout"),
+        pytest.param("max_active_runs", "not_an_int", id="max_active_runs"),
     ],
 )
 def test_invalid_type_for_args(attr: str, value: Any):
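The new test_continuous_schedule_limits_max_active_runs exercises the max_active_runs validator added to the Task SDK DAG above. A toy sketch of a field validator enforcing a limit derived from another attribute (the Timetable class here is a stand-in, not Airflow's):

import attrs

@attrs.define
class Timetable:
    # ContinuousTimetable-style cap; None means "no limit".
    active_runs_limit: int | None = None

@attrs.define
class Dag:
    timetable: Timetable
    max_active_runs: int = attrs.field(default=16)

    @max_active_runs.validator
    def _check(self, attribute, value):
        # `timetable` is declared first, so it is already set when this runs.
        limit = self.timetable.active_runs_limit
        if limit is not None and value > limit:
            raise ValueError(
                f"Invalid max_active_runs: {type(self.timetable).__name__} "
                f"requires max_active_runs <= {limit}"
            )

Dag(Timetable(active_runs_limit=1), max_active_runs=1)  # ok
Dag(Timetable(active_runs_limit=1), max_active_runs=0)  # ok: 0 <= 1
try:
    Dag(Timetable(active_runs_limit=1), max_active_runs=25)
except ValueError as e:
    print(e)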
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index c126e2d3039..4c1d8a67960 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -76,7 +76,6 @@ from airflow.templates import NativeEnvironment, SandboxedEnvironment
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
 from airflow.timetables.simple import (
     AssetTriggeredTimetable,
-    ContinuousTimetable,
     NullTimetable,
     OnceTimetable,
 )
@@ -1683,7 +1682,15 @@ class TestDag:
         with dag:
             check_task_2(check_task())

-        dag.test()
+        dr = dag.test()
+
+        ti1 = dr.get_task_instance("check_task")
+        ti2 = dr.get_task_instance("check_task_2")
+
+        assert ti1
+        assert ti2
+        assert ti1.state == TaskInstanceState.FAILED
+        assert ti2.state == TaskInstanceState.UPSTREAM_FAILED

         mock_handle_object_1.assert_called_with("task check_task failed...")
         mock_handle_object_2.assert_called_with("dag test_local_testing_conn_file run failed...")
@@ -2121,22 +2128,6 @@ my_postgres_conn:
         orm_dag_owners = session.query(DagOwnerAttributes).all()
         assert not orm_dag_owners

-        # Check wrong formatted owner link
-        with pytest.raises(AirflowException):
-            DAG("dag", schedule=None, start_date=DEFAULT_DATE, owner_links={"owner1": "my-bad-link"})
-
-    def test_continuous_schedule_linmits_max_active_runs(self):
-        dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=1)
-        assert isinstance(dag.timetable, ContinuousTimetable)
-        assert dag.max_active_runs == 1
-
-        dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=0)
-        assert isinstance(dag.timetable, ContinuousTimetable)
-        assert dag.max_active_runs == 0
-
-        with pytest.raises(AirflowException):
-            dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=25)
-

 class TestDagModel:
     def _clean(self):
@@ -2854,7 +2845,7 @@ def test_set_task_group_state(run_id, execution_date, session, dag_maker):
     }


-def test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker):
+def test_dag_teardowns_property_lists_all_teardown_tasks():
     @setup
     def setup_task():
         return 1
@@ -2875,7 +2866,7 @@ def test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker):
     def mytask():
         return 1

-    with dag_maker() as dag:
+    with DAG("dag") as dag:
         t1 = setup_task()
         t2 = teardown_task()
         t3 = teardown_task2()
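The last hunk swaps the dag_maker fixture for a plain `with DAG(...)` block now that the Task SDK DAG is usable directly in tests; relatedly, the baseoperator.py change above restores SetupTeardownContext registration at construction time. A toy register-on-init sketch of that context pattern (all names hypothetical, not Airflow's API):

class Context:
    """Toy stand-in for a SetupTeardownContext-like registry."""
    active: bool = False
    registered: list["Op"] = []

    @classmethod
    def update_context_map(cls, op: "Op") -> None:
        cls.registered.append(op)

class Op:
    def __init__(self, name: str) -> None:
        self.name = name
        self.is_setup = False
        self.is_teardown = False
        # Mirror of the restored BaseOperator behaviour: self-register
        # with the registry whenever a context is active.
        if Context.active:
            Context.update_context_map(self)

Context.active = True
t1, t2 = Op("t1"), Op("t2")
assert [op.name for op in Context.registered] == ["t1", "t2"]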
