This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 444016cc59b635cfa937238b3e5ba1e86d6f0f28
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Oct 24 16:25:41 2024 +0100

    fix more tests [skip-ci]
---
 .pre-commit-config.yaml                            |  1 +
 airflow/models/dag.py                              | 25 ++-----
 airflow/models/taskmixin.py                        | 24 +-----
 .../src/airflow/sdk/definitions/baseoperator.py    | 15 ++--
 task_sdk/src/airflow/sdk/definitions/dag.py        | 86 +++++++++++++++-------
 task_sdk/src/airflow/sdk/definitions/node.py       |  7 ++
 task_sdk/tests/defintions/test_dag.py              | 34 +++++++++
 tests/models/test_dag.py                           | 31 +++-----
 8 files changed, 128 insertions(+), 95 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7d9cd6b2292..ad4b2529b86 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1190,6 +1190,7 @@ repos:
           ^providers/src/airflow/providers/ |
           ^(providers/)?tests/ |
           task_sdk/src/airflow/sdk/definitions/dag.py$ |
+          task_sdk/src/airflow/sdk/definitions/node.py$ |
           ^dev/.*\.py$ |
           ^scripts/.*\.py$ |
           ^docker_tests/.*$ |
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index fc0928ddcc1..d133eb43e68 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -298,7 +298,7 @@ if TYPE_CHECKING:
 else:
 
     def dag(dag_id: str = "", **kwargs):
-        return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, 
__warnings_stacklevel_delta=3)
+        return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, 
__warnings_stacklevel_delta=3, **kwargs)
 
 
 @functools.total_ordering
@@ -419,15 +419,14 @@ class DAG(TaskSDKDag, LoggingMixin):
 
     partial: bool = False
     last_loaded: datetime | None = attrs.field(factory=timezone.utcnow)
-    on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
-    on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
-
-    has_on_success_callback: bool = attrs.field(init=False)
-    has_on_failure_callback: bool = attrs.field(init=False)
 
     default_view: str = airflow_conf.get_mandatory_value("webserver", 
"dag_default_view").lower()
     orientation: str = airflow_conf.get_mandatory_value("webserver", 
"dag_orientation")
 
+    # this will only be set at serialization time
+    # it's only use is for determining the relative fileloc based only on the 
serialize dag
+    _processor_dags_folder: str | None = attrs.field(init=False, default=None)
+
     # Override the default from parent class to use config
     max_consecutive_failed_dag_runs: int = attrs.field()
 
@@ -435,17 +434,9 @@ class DAG(TaskSDKDag, LoggingMixin):
     def _max_consecutive_failed_dag_runs_default(self):
         return airflow_conf.getint("core", 
"max_consecutive_failed_dag_runs_per_dag")
 
-    # this will only be set at serialization time
-    # it's only use is for determining the relative fileloc based only on the 
serialize dag
-    _processor_dags_folder: str | None = attrs.field(init=False, default=None)
-
-    @has_on_success_callback.default
-    def _has_on_success_callback(self) -> bool:
-        return self.on_success_callback is not None
-
-    @has_on_failure_callback.default
-    def _has_on_failure_callback(self) -> bool:
-        return self.on_failure_callback is not None
+    def validate(self):
+        super().validate()
+        self.validate_executor_field()
 
     def validate_executor_field(self):
         for task in self.tasks:
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index c3ab50e2e0b..fa76a3815cb 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -16,33 +16,13 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from airflow.models.operator import Operator
-    from airflow.serialization.enums import DagAttributeTypes
     from airflow.typing_compat import TypeAlias
 
 import airflow.sdk.definitions.mixins
 import airflow.sdk.definitions.node
 
 DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-
-
-class DAGNode(airflow.sdk.definitions.node.DAGNode):
-    """
-    A base class for a node in the graph of a workflow.
-
-    A node may be an Operator or a Task Group, either mapped or unmapped.
-    """
-
-    def get_direct_relatives(self, upstream: bool = False) -> 
Iterable[Operator]:
-        """Get list of the direct relatives to the current task, upstream or 
downstream."""
-        if upstream:
-            return self.upstream_list
-        else:
-            return self.downstream_list
-
-    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
-        """Serialize a task group's content; used by TaskGroupSerialization."""
-        raise NotImplementedError()
+DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py 
b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index d8e07c44b71..5796a6bbc1f 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -57,6 +57,7 @@ from airflow.task.priority_strategy import (
     validate_and_load_priority_weight_strategy,
 )
 from airflow.utils import timezone
+from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import AttributeRemoved
 
@@ -854,11 +855,11 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             )
             self.template_fields = [self.template_fields]
 
-        self._is_setup = False
-        self._is_teardown = False
-        # TODO: Task-SDK
-        # if SetupTeardownContext.active:
-        #     SetupTeardownContext.update_context_map(self)
+        self.is_setup = False
+        self.is_teardown = False
+
+        if SetupTeardownContext.active:
+            SetupTeardownContext.update_context_map(self)
 
         validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
 
@@ -944,7 +945,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         result = cls.__new__(cls)
         memo[id(self)] = result
 
-        shallow_copy = cls.shallow_copy_attrs + 
cls._base_operator_shallow_copy_attrs
+        shallow_copy = tuple(cls.shallow_copy_attrs) + 
cls._base_operator_shallow_copy_attrs
 
         for k, v in self.__dict__.items():
             if k not in shallow_copy:
@@ -1173,8 +1174,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
                     "_BaseOperator__instantiated",
                     "_BaseOperator__init_kwargs",
                     "_BaseOperator__from_mapped",
-                    "_is_setup",
-                    "_is_teardown",
                     "_on_failure_fail_dagrun",
                 }
                 | {  # Class level defaults need to be added to this list
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py 
b/task_sdk/src/airflow/sdk/definitions/dag.py
index e784639f8c9..111f51ce855 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -25,7 +25,7 @@ import os
 import sys
 import weakref
 from collections import abc
-from collections.abc import Collection, Iterable, Iterator, MutableSet
+from collections.abc import Collection, Iterable, MutableSet
 from datetime import datetime, timedelta
 from inspect import signature
 from re import Pattern
@@ -177,6 +177,20 @@ def _convert_access_control(value, self_: DAG):
         return value
 
 
+def _convert_doc_md(doc_md: str | None) -> str | None:
+    if doc_md is None:
+        return doc_md
+
+    if doc_md.endswith(".md"):
+        try:
+            with open(doc_md) as fh:
+                return fh.read()
+        except FileNotFoundError:
+            return doc_md
+
+    return doc_md
+
+
 def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]):
     i = iter(fields)
     f = next(i)
@@ -351,16 +365,17 @@ class DAG:
     )
     # sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None
     catchup: bool = attrs.field(default=True, converter=bool)
-    # on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
-    # on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
-    doc_md: str | None = None
+    on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
+    on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
+    doc_md: str | None = attrs.field(default=None, converter=_convert_doc_md)
     params: ParamsDict = attrs.field(
         # mypy doesn't really like passing the Converter object
         default=None,
         converter=attrs.Converter(_convert_params, takes_self=True),  # type: 
ignore[misc, call-overload]
     )
     access_control: dict[str, dict[str, Collection[str]]] | dict[str, 
Collection[str]] | None = attrs.field(
-        default=None, converter=attrs.Converter(_convert_access_control, 
takes_self=True)
+        default=None,
+        converter=attrs.Converter(_convert_access_control, takes_self=True),  
# type: ignore[misc, call-overload]
     )
     is_paused_upon_creation: bool | None = None
     jinja_environment_kwargs: dict | None = None
@@ -368,7 +383,7 @@ class DAG:
     tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags)
     owner_links: dict[str, str] = attrs.field(factory=dict)
     auto_register: bool = attrs.field(default=True, converter=bool)
-    fail_stop: bool = attrs.field(default=True, converter=bool)
+    fail_stop: bool = attrs.field(default=False, converter=bool)
     dag_display_name: str = 
attrs.field(validator=attrs.validators.instance_of(str))
 
     task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
@@ -380,6 +395,9 @@ class DAG:
 
     edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, 
factory=dict)
 
+    has_on_success_callback: bool = attrs.field(init=False)
+    has_on_failure_callback: bool = attrs.field(init=False)
+
     def __attrs_post_init__(self):
         from airflow.utils import timezone
 
@@ -481,6 +499,23 @@ class DAG:
         if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
             raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} 
characters")
 
+    @max_active_runs.validator
+    def _validate_max_active_runs(self, _, max_active_runs):
+        if self.timetable.active_runs_limit is not None:
+            if self.timetable.active_runs_limit < self.max_active_runs:
+                raise ValueError(
+                    f"Invalid max_active_runs: {type(self.timetable).__name__} 
"
+                    f"requires max_active_runs <= 
{self.timetable.active_runs_limit}"
+                )
+
+    @has_on_success_callback.default
+    def _has_on_success_callback(self) -> bool:
+        return self.on_success_callback is not None
+
+    @has_on_failure_callback.default
+    def _has_on_failure_callback(self) -> bool:
+        return self.on_failure_callback is not None
+
     def __repr__(self):
         return f"<DAG: {self.dag_id}>"
 
@@ -522,18 +557,6 @@ class DAG:
 
         _ = DagContext.pop()
 
-    def get_doc_md(self, doc_md: str | None) -> str | None:
-        if doc_md is None:
-            return doc_md
-
-        if doc_md.endswith(".md"):
-            try:
-                return open(doc_md).read()
-            except FileNotFoundError:
-                return doc_md
-
-        return doc_md
-
     def validate(self):
         """
         Validate the DAG has a coherent setup.
@@ -543,6 +566,10 @@ class DAG:
         self.timetable.validate()
         self.validate_setup_teardown()
 
+        # We validate owner links on set, but since it's a dict it could be 
mutated without calling the
+        # setter. Validate again here
+        self._validate_owner_links(None, self.owner_links)
+
     def validate_setup_teardown(self):
         """
         Validate that setup and teardown tasks are configured properly.
@@ -966,20 +993,23 @@ class DAG:
         """
         self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = 
info
 
-    def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
-        """
-        Parse a given link, and verifies if it's a valid URL, or a 'mailto' 
link.
+    @owner_links.validator
+    def _validate_owner_links(self, _, owner_links):
+        wrong_links = {}
 
-        Returns an iterator of invalid (owner, link) pairs.
-        """
-        for owner, link in self.owner_links.items():
+        for owner, link in owner_links.items():
             result = urlsplit(link)
             if result.scheme == "mailto":
                 # netloc is not existing for 'mailto' link, so we are checking 
that the path is parsed
                 if not result.path:
-                    yield result.path, link
+                    wrong_links[result.path] = link
             elif not result.scheme or not result.netloc:
-                yield owner, link
+                wrong_links[owner] = link
+        if wrong_links:
+            raise ValueError(
+                "Wrong link format was used for the owner. Use a valid link \n"
+                f"Bad formatted links are: {wrong_links}"
+            )
 
 
 if TYPE_CHECKING:
@@ -1003,8 +1033,8 @@ if TYPE_CHECKING:
         dagrun_timeout: timedelta | None = None,
         # sla_miss_callback: Any = None,
         catchup: bool = ...,
-        # on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
-        # on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
+        on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
+        on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
         doc_md: str | None = None,
         params: ParamsDict | None = None,
         access_control: dict[str, dict[str, Collection[str]]] | dict[str, 
Collection[str]] | None = None,
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py 
b/task_sdk/src/airflow/sdk/definitions/node.py
index 7b877bbaf48..b3b519a07f0 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -211,6 +211,13 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
         else:
             return self.downstream_task_ids
 
+    def get_direct_relatives(self, upstream: bool = False) -> 
Iterable[Operator]:
+        """Get list of the direct relatives to the current task, upstream or 
downstream."""
+        if upstream:
+            return self.upstream_list
+        else:
+            return self.downstream_list
+
     def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
         """Serialize a task group's content; used by TaskGroupSerialization."""
         raise NotImplementedError()
diff --git a/task_sdk/tests/defintions/test_dag.py 
b/task_sdk/tests/defintions/test_dag.py
index c3b3bbdce4d..b2481a49b6a 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -248,6 +248,39 @@ class TestDag:
         # Make sure we don't affect the original!
         assert task.task_group.upstream_group_ids is not 
copied_task.task_group.upstream_group_ids
 
+    def test_dag_owner_links(self):
+        dag = DAG(
+            "dag",
+            schedule=None,
+            start_date=DEFAULT_DATE,
+            owner_links={"owner1": "https://mylink.com", "owner2": 
"mailto:[email protected]"},
+        )
+
+        assert dag.owner_links == {"owner1": "https://mylink.com", "owner2": 
"mailto:[email protected]"}
+
+        # Check wrong formatted owner link
+        with pytest.raises(ValueError, match="Wrong link format"):
+            DAG("dag", schedule=None, start_date=DEFAULT_DATE, 
owner_links={"owner1": "my-bad-link"})
+
+        dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE)
+        dag.owner_links["owner1"] = "my-bad-link"
+        with pytest.raises(ValueError, match="Wrong link format"):
+            dag.validate()
+
+    def test_continuous_schedule_linmits_max_active_runs(self):
+        from airflow.timetables.simple import ContinuousTimetable
+
+        dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=1)
+        assert isinstance(dag.timetable, ContinuousTimetable)
+        assert dag.max_active_runs == 1
+
+        dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=0)
+        assert isinstance(dag.timetable, ContinuousTimetable)
+        assert dag.max_active_runs == 0
+
+        with pytest.raises(ValueError, match="ContinuousTimetable requires 
max_active_runs <= 1"):
+            dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=25)
+
 
 # Test some of the arg valiadtion. This is not all the validations we perform, 
just some of them.
 @pytest.mark.parametrize(
@@ -255,6 +288,7 @@ class TestDag:
     [
         pytest.param("max_consecutive_failed_dag_runs", "not_an_int", 
id="max_consecutive_failed_dag_runs"),
         pytest.param("dagrun_timeout", "not_an_int", id="dagrun_timeout"),
+        pytest.param("max_active_runs", "not_an_int", id="max_active_runs"),
     ],
 )
 def test_invalid_type_for_args(attr: str, value: Any):
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index c126e2d3039..4c1d8a67960 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -76,7 +76,6 @@ from airflow.templates import NativeEnvironment, 
SandboxedEnvironment
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, 
Timetable
 from airflow.timetables.simple import (
     AssetTriggeredTimetable,
-    ContinuousTimetable,
     NullTimetable,
     OnceTimetable,
 )
@@ -1683,7 +1682,15 @@ class TestDag:
         with dag:
             check_task_2(check_task())
 
-        dag.test()
+        dr = dag.test()
+
+        ti1 = dr.get_task_instance("check_task")
+        ti2 = dr.get_task_instance("check_task_2")
+
+        assert ti1
+        assert ti2
+        assert ti1.state == TaskInstanceState.FAILED
+        assert ti2.state == TaskInstanceState.UPSTREAM_FAILED
 
         mock_handle_object_1.assert_called_with("task check_task failed...")
         mock_handle_object_2.assert_called_with("dag 
test_local_testing_conn_file run failed...")
@@ -2121,22 +2128,6 @@ my_postgres_conn:
         orm_dag_owners = session.query(DagOwnerAttributes).all()
         assert not orm_dag_owners
 
-        # Check wrong formatted owner link
-        with pytest.raises(AirflowException):
-            DAG("dag", schedule=None, start_date=DEFAULT_DATE, 
owner_links={"owner1": "my-bad-link"})
-
-    def test_continuous_schedule_linmits_max_active_runs(self):
-        dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=1)
-        assert isinstance(dag.timetable, ContinuousTimetable)
-        assert dag.max_active_runs == 1
-
-        dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=0)
-        assert isinstance(dag.timetable, ContinuousTimetable)
-        assert dag.max_active_runs == 0
-
-        with pytest.raises(AirflowException):
-            dag = DAG("continuous", start_date=DEFAULT_DATE, 
schedule="@continuous", max_active_runs=25)
-
 
 class TestDagModel:
     def _clean(self):
@@ -2854,7 +2845,7 @@ def test_set_task_group_state(run_id, execution_date, 
session, dag_maker):
     }
 
 
-def test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker):
+def test_dag_teardowns_property_lists_all_teardown_tasks():
     @setup
     def setup_task():
         return 1
@@ -2875,7 +2866,7 @@ def 
test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker):
     def mytask():
         return 1
 
-    with dag_maker() as dag:
+    with DAG("dag") as dag:
         t1 = setup_task()
         t2 = teardown_task()
         t3 = teardown_task2()

Reply via email to