This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d1d891a178e8d6b9993190c12d416962e616255a
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Sat Oct 19 12:46:54 2024 +0100

    [skip ci]
---
 airflow/models/abstractoperator.py                 | 118 -------------
 airflow/models/baseoperator.py                     |  41 -----
 airflow/models/dag.py                              |  10 +-
 airflow/serialization/schema.json                  |  10 +-
 airflow/serialization/serialized_objects.py        |  22 +--
 airflow/utils/decorators.py                        |   9 +-
 .../airflow/sdk/definitions/abstractoperator.py    | 106 +++++++++++
 .../src/airflow/sdk/definitions/baseoperator.py    |  18 ++
 task_sdk/src/airflow/sdk/definitions/dag.py        | 196 +++++++++------------
 task_sdk/tests/defintions/test_dag.py              |  72 +++++++-
 tests/models/test_dag.py                           |  67 -------
 11 files changed, 314 insertions(+), 355 deletions(-)

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index a29ef09f270..a27d7e26fd1 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -106,22 +106,6 @@ class AbstractOperator(Templater, TaskSDKAbstractOperator):
     trigger_rule: TriggerRule
     weight_rule: PriorityWeightStrategy
 
-    @property
-    def is_setup(self) -> bool:
-        raise NotImplementedError()
-
-    @is_setup.setter
-    def is_setup(self, value: bool) -> None:
-        raise NotImplementedError()
-
-    @property
-    def is_teardown(self) -> bool:
-        raise NotImplementedError()
-
-    @is_teardown.setter
-    def is_teardown(self, value: bool) -> None:
-        raise NotImplementedError()
-
     @property
     def on_failure_fail_dagrun(self):
         """
@@ -211,108 +195,6 @@ class AbstractOperator(Templater, TaskSDKAbstractOperator):
             else:
                 setattr(parent, attr_name, rendered_content)
 
-    def as_setup(self):
-        self.is_setup = True
-        return self
-
-    def as_teardown(
-        self,
-        *,
-        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
-        on_failure_fail_dagrun=NOTSET,
-    ):
-        self.is_teardown = True
-        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
-        if on_failure_fail_dagrun is not NOTSET:
-            self.on_failure_fail_dagrun = on_failure_fail_dagrun
-        if not isinstance(setups, ArgNotSet):
-            setups = [setups] if isinstance(setups, DependencyMixin) else setups
-            for s in setups:
-                s.is_setup = True
-                s >> self
-        return self
-
-    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
-        """
-        Get a flat set of relative IDs, upstream or downstream.
-
-        Will recurse each relative found in the direction specified.
-
-        :param upstream: Whether to look for upstream or downstream relatives.
-        """
-        dag = self.get_dag()
-        if not dag:
-            return set()
-
-        relatives: set[str] = set()
-
-        # This is intentionally implemented as a loop, instead of calling
-        # get_direct_relative_ids() recursively, since Python has significant
-        # limitation on stack level, and a recursive implementation can blow up
-        # if a DAG contains very long routes.
-        task_ids_to_trace = self.get_direct_relative_ids(upstream)
-        while task_ids_to_trace:
-            task_ids_to_trace_next: set[str] = set()
-            for task_id in task_ids_to_trace:
-                if task_id in relatives:
-                    continue
-                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
-                relatives.add(task_id)
-            task_ids_to_trace = task_ids_to_trace_next
-
-        return relatives
-
-    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
-        """Get a flat list of relatives, either upstream or downstream."""
-        dag = self.get_dag()
-        if not dag:
-            return set()
-        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]
-
-    def get_upstreams_follow_setups(self) -> Iterable[Operator]:
-        """All upstreams and, for each upstream setup, its respective 
teardowns."""
-        for task in self.get_flat_relatives(upstream=True):
-            yield task
-            if task.is_setup:
-                for t in task.downstream_list:
-                    if t.is_teardown and t != self:
-                        yield t
-
-    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
-        """
-        Only *relevant* upstream setups and their teardowns.
-
-        This method is meant to be used when we are clearing the task (non-upstream) and we need
-        to add in the *relevant* setups and their teardowns.
-
-        Relevant in this case means, the setup has a teardown that is downstream of ``self``,
-        or the setup has no teardowns.
-        """
-        downstream_teardown_ids = {
-            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
-        }
-        for task in self.get_flat_relatives(upstream=True):
-            if not task.is_setup:
-                continue
-            has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown)
-            # if task has no teardowns or has teardowns downstream of self
-            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
-                yield task
-                for t in task.downstream_list:
-                    if t.is_teardown and t != self:
-                        yield t
-
-    def get_upstreams_only_setups(self) -> Iterable[Operator]:
-        """
-        Return relevant upstream setups.
-
-        This method is meant to be used when we are checking task dependencies where we need
-        to wait for all the upstream setups to complete before we can run the task.
-        """
-        for task in self.get_upstreams_only_setups_and_teardowns():
-            if task.is_setup:
-                yield task
-
     def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
         """
         Return mapped nodes that are direct dependencies of the current task.
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 0dc533bef7c..a9628873e5e 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -28,7 +28,6 @@ import contextlib
 import copy
 import functools
 import logging
-import sys
 from datetime import datetime, timedelta
 from functools import wraps
 from threading import local
@@ -879,46 +878,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
         else:
             return self.downstream_list
 
-    @property
-    def is_setup(self) -> bool:
-        """
-        Whether the operator is a setup task.
-
-        :meta private:
-        """
-        return self._is_setup
-
-    @is_setup.setter
-    def is_setup(self, value: bool) -> None:
-        """
-        Setter for is_setup property.
-
-        :meta private:
-        """
-        if self.is_teardown and value:
-            raise ValueError(f"Cannot mark task '{self.task_id}' as setup; 
task is already a teardown.")
-        self._is_setup = value
-
-    @property
-    def is_teardown(self) -> bool:
-        """
-        Whether the operator is a teardown task.
-
-        :meta private:
-        """
-        return self._is_teardown
-
-    @is_teardown.setter
-    def is_teardown(self, value: bool) -> None:
-        """
-        Setter for is_teardown property.
-
-        :meta private:
-        """
-        if self.is_setup and value:
-            raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; 
task is already a setup.")
-        self._is_teardown = value
-
     @staticmethod
     def xcom_push(
         context: Any,
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 691a64ed290..642b4c2f7e1 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -100,7 +100,7 @@ from airflow.models.taskinstance import (
     clear_task_instances,
 )
 from airflow.models.tasklog import LogTemplate
-from airflow.sdk import DAG as TaskSDKDag, dag as dag
+from airflow.sdk import DAG as TaskSDKDag, dag as task_sdk_dag_decorator
 from airflow.secrets.local_filesystem import LocalFilesystemBackend
 from airflow.security import permissions
 from airflow.settings import json
@@ -296,6 +296,14 @@ def _create_orm_dagrun(
     return run
 
 
+if TYPE_CHECKING:
+    dag = task_sdk_dag_decorator
+else:
+
+    def dag(dag_id: str = "", **kwargs):
+        return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3)
+
+
 @functools.total_ordering
 @attrs.define(hash=False, repr=False, eq=False)
 class DAG(TaskSDKDag, LoggingMixin):
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index fe1e63c4903..e313e2c7af7 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -137,7 +137,7 @@
       "type": "object",
       "properties": {
         "params": { "$ref": "#/definitions/params" },
-        "_dag_id": { "type": "string" },
+        "dag_id": { "type": "string" },
         "tasks": {  "$ref": "#/definitions/tasks" },
         "timezone": { "$ref": "#/definitions/timezone" },
         "owner_links": { "type": "object" },
@@ -157,10 +157,10 @@
             ]
         },
         "orientation": { "type" : "string"},
-        "_dag_display_property_value": { "type" : "string"},
+        "dag_display_name": { "type" : "string"},
         "_description": { "type" : "string"},
         "_concurrency": { "type" : "number"},
-        "_max_active_tasks": { "type" : "number"},
+        "max_active_tasks": { "type" : "number"},
         "max_active_runs": { "type" : "number"},
         "max_consecutive_failed_dag_runs": { "type" : "number"},
         "default_args": { "$ref": "#/definitions/dict" },
@@ -175,7 +175,7 @@
         "has_on_failure_callback":  { "type": "boolean" },
         "render_template_as_native_obj":  { "type": "boolean" },
         "tags": { "type": "array" },
-        "_task_group": {"anyOf": [
+        "task_group": {"anyOf": [
           { "type": "null" },
           { "$ref": "#/definitions/task_group" }
         ]},
@@ -183,7 +183,7 @@
         "dag_dependencies": { "$ref": "#/definitions/dag_dependencies" }
       },
       "required": [
-        "_dag_id",
+        "dag_id",
         "fileloc",
         "tasks"
       ],
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 8b674b2aa0f..9aafaf1f54a 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -920,10 +920,11 @@ class BaseSerialization:
         to account for the case where the default value of the field is None but has the
         ``field = field or {}`` set.
         """
-        if attrname in cls._CONSTRUCTOR_PARAMS and (
-            cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
-        ):
-            return True
+        if attrname in cls._CONSTRUCTOR_PARAMS:
+            if cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []]):
+                return True
+            if cls._CONSTRUCTOR_PARAMS[attrname] is attrs.NOTHING and value is None:
+                return True
         return False
 
     @classmethod
@@ -1613,7 +1614,7 @@ class SerializedDAG(DAG, BaseSerialization):
             ]
             dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag))
             serialized_dag["dag_dependencies"] = [x.__dict__ for x in 
sorted(dag_deps)]
-            serialized_dag["_task_group"] = 
TaskGroupSerialization.serialize_task_group(dag.task_group)
+            serialized_dag["task_group"] = 
TaskGroupSerialization.serialize_task_group(dag.task_group)
 
             # Edge info in the JSON exactly matches our internal structure
             serialized_dag["edge_info"] = dag.edge_info
@@ -1633,7 +1634,7 @@ class SerializedDAG(DAG, BaseSerialization):
     @classmethod
     def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
         """Deserializes a DAG from a JSON object."""
-        dag = SerializedDAG(dag_id=encoded_dag["_dag_id"], schedule=None)
+        dag = SerializedDAG(dag_id=encoded_dag["dag_id"], schedule=None)
 
         for k, v in encoded_dag.items():
             if k == "_downstream_task_ids":
@@ -1668,16 +1669,17 @@ class SerializedDAG(DAG, BaseSerialization):
                 v = set(v)
             # else use v as it is
 
-            setattr(dag, k, v)
+            object.__setattr__(dag, k, v)
 
         # Set _task_group
-        if "_task_group" in encoded_dag:
-            dag.task_group = TaskGroupSerialization.deserialize_task_group(
-                encoded_dag["_task_group"],
+        if "task_group" in encoded_dag:
+            tg = TaskGroupSerialization.deserialize_task_group(
+                encoded_dag["task_group"],
                 None,
                 dag.task_dict,
                 dag,
             )
+            object.__setattr__(dag, "task_group", tg)
         else:
             # This must be old data that had no task_group. Create a root TaskGroup and add
             # all tasks to it.
diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
index e299999423e..78044e4e357 100644
--- a/airflow/utils/decorators.py
+++ b/airflow/utils/decorators.py
@@ -69,8 +69,9 @@ def _balance_parens(after_decorator):
 
 
 class _autostacklevel_warn:
-    def __init__(self):
+    def __init__(self, delta):
         self.warnings = __import__("warnings")
+        self.delta = delta
 
     def __getattr__(self, name):
         return getattr(self.warnings, name)
@@ -79,11 +80,11 @@ class _autostacklevel_warn:
         return dir(self.warnings)
 
     def warn(self, message, category=None, stacklevel=1, source=None):
-        self.warnings.warn(message, category, stacklevel + 2, source)
+        self.warnings.warn(message, category, stacklevel + self.delta, source)
 
 
-def fixup_decorator_warning_stack(func):
+def fixup_decorator_warning_stack(func, delta: int = 2):
     if func.__globals__.get("warnings") is sys.modules["warnings"]:
         # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
         # `warnings.warn` to ignore the decorator.
-        func.__globals__["warnings"] = _autostacklevel_warn()
+        func.__globals__["warnings"] = _autostacklevel_warn(delta)
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index 54b1e30ab81..6f90ae7f118 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -31,6 +31,7 @@ from typing import (
 )
 
 from airflow.sdk.definitions.node import DAGNode
+from airflow.sdk.definitions.mixins import DependencyMixin
 from airflow.utils.log.secrets_masker import redact
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
     import jinja2  # Slow import.
 
     from airflow.models.baseoperatorlink import BaseOperatorLink
+    from airflow.models.operator import Operator
     from airflow.sdk.definitions.baseoperator import BaseOperator
     from airflow.sdk.definitions.dag import DAG
 
@@ -99,6 +101,8 @@ class AbstractOperator(DAGNode):
     trigger_rule: TriggerRule
     _needs_expansion: bool | None = None
     _on_failure_fail_dagrun = False
+    is_setup: bool = False
+    is_teardown: bool = False
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
         (
@@ -163,3 +167,105 @@ class AbstractOperator(DAGNode):
             # "task_group_id.task_id" -> "task_id"
             return self.task_id[len(tg.node_id) + 1 :]
         return self.task_id
+
+    def as_setup(self):
+        self.is_setup = True
+        return self
+
+    def as_teardown(
+        self,
+        *,
+        setups: BaseOperator | Iterable[BaseOperator] | None = None,
+        on_failure_fail_dagrun: bool | None = None,
+    ):
+        self.is_teardown = True
+        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
+        if on_failure_fail_dagrun is not None:
+            self.on_failure_fail_dagrun = on_failure_fail_dagrun
+        if setups is not None:
+            setups = [setups] if isinstance(setups, DependencyMixin) else setups
+            for s in setups:
+                s.is_setup = True
+                s >> self
+        return self
+
+    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
+        """
+        Get a flat set of relative IDs, upstream or downstream.
+
+        Will recurse each relative found in the direction specified.
+
+        :param upstream: Whether to look for upstream or downstream relatives.
+        """
+        dag = self.get_dag()
+        if not dag:
+            return set()
+
+        relatives: set[str] = set()
+
+        # This is intentionally implemented as a loop, instead of calling
+        # get_direct_relative_ids() recursively, since Python has significant
+        # limitation on stack level, and a recursive implementation can blow up
+        # if a DAG contains very long routes.
+        task_ids_to_trace = self.get_direct_relative_ids(upstream)
+        while task_ids_to_trace:
+            task_ids_to_trace_next: set[str] = set()
+            for task_id in task_ids_to_trace:
+                if task_id in relatives:
+                    continue
+                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
+                relatives.add(task_id)
+            task_ids_to_trace = task_ids_to_trace_next
+
+        return relatives
+
+    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
+        """Get a flat list of relatives, either upstream or downstream."""
+        dag = self.get_dag()
+        if not dag:
+            return set()
+        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]
+
+    def get_upstreams_follow_setups(self) -> Iterable[Operator]:
+        """All upstreams and, for each upstream setup, its respective 
teardowns."""
+        for task in self.get_flat_relatives(upstream=True):
+            yield task
+            if task.is_setup:
+                for t in task.downstream_list:
+                    if t.is_teardown and t != self:
+                        yield t
+
+    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
+        """
+        Only *relevant* upstream setups and their teardowns.
+
+        This method is meant to be used when we are clearing the task (non-upstream) and we need
+        to add in the *relevant* setups and their teardowns.
+
+        Relevant in this case means, the setup has a teardown that is downstream of ``self``,
+        or the setup has no teardowns.
+        """
+        downstream_teardown_ids = {
+            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
+        }
+        for task in self.get_flat_relatives(upstream=True):
+            if not task.is_setup:
+                continue
+            has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown)
+            # if task has no teardowns or has teardowns downstream of self
+            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
+                yield task
+                for t in task.downstream_list:
+                    if t.is_teardown and t != self:
+                        yield t
+
+    def get_upstreams_only_setups(self) -> Iterable[Operator]:
+        """
+        Return relevant upstream setups.
+
+        This method is meant to be used when we are checking task dependencies where we need
+        to wait for all the upstream setups to complete before we can run the task.
+        """
+        for task in self.get_upstreams_only_setups_and_teardowns():
+            if task.is_setup:
+                yield task
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index eb7e57907eb..57a85988987 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -557,6 +557,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     logger_name: str | None = None
     allow_nested_operators: bool = True
 
+    is_setup: bool = False
+    is_teardown: bool = False
+
     template_fields: Collection[str] = ()
     template_ext: Sequence[str] = ()
 
@@ -1041,6 +1044,21 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         return Resources(**resources)
 
+    def _convert_is_setup(self, value: bool) -> bool:
+        """
+        Setter for is_setup property.
+
+        :meta private:
+        """
+        if self.is_teardown and value:
+            raise ValueError(f"Cannot mark task '{self.task_id}' as setup; 
task is already a teardown.")
+        return value
+
+    def _convert_is_teardown(self, value: bool) -> bool:
+        if self.is_setup and value:
+            raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; 
task is already a setup.")
+        return value
+
     @property
     def task_display_name(self) -> str:
         return self._task_display_name or self.task_id
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 0ad82e52455..21579e87356 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -947,6 +947,7 @@ class DAG:
                 "has_on_failure_callback",
                 "auto_register",
                 "fail_stop",
+                "schedule",
             }
             cls.__serialized_fields = frozenset(vars(DAG(dag_id="test", schedule=None))) - exclusion_list
         return cls.__serialized_fields
@@ -984,114 +985,93 @@ class DAG:
                 yield owner, link
 
 
-# NOTE: Please keep the list of arguments in sync with DAG.__init__.
-# Only exception: dag_id here should have a default value, but not in DAG.
-def dag(
-    dag_id: str = "",
-    description: str | None = None,
-    schedule: ScheduleArg = None,
-    start_date: datetime | None = None,
-    end_date: datetime | None = None,
-    template_searchpath: str | Iterable[str] | None = None,
-    template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
-    user_defined_macros: dict | None = None,
-    user_defined_filters: dict | None = None,
-    default_args: dict | None = None,
-    max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
-    max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
-    max_consecutive_failed_dag_runs: int = airflow_conf.getint(
-        "core", "max_consecutive_failed_dag_runs_per_dag"
-    ),
-    dagrun_timeout: timedelta | None = None,
-    sla_miss_callback: Any = None,
-    catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
-    on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
-    on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
-    doc_md: str | None = None,
-    params: abc.MutableMapping | None = None,
-    access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None,
-    is_paused_upon_creation: bool | None = None,
-    jinja_environment_kwargs: dict | None = None,
-    render_template_as_native_obj: bool = False,
-    tags: Collection[str] | None = None,
-    owner_links: dict[str, str] | None = None,
-    auto_register: bool = True,
-    fail_stop: bool = False,
-    dag_display_name: str | None = None,
-) -> Callable[[Callable], Callable[..., DAG]]:
-    """
-    Python dag decorator which wraps a function into an Airflow DAG.
+if TYPE_CHECKING:
+    # NOTE: Please keep the list of arguments in sync with DAG.__init__.
+    # Only exception: dag_id here should have a default value, but not in DAG.
+    def dag(
+        dag_id: str = "",
+        *,
+        description: str | None = None,
+        schedule: ScheduleArg = None,
+        start_date: datetime | None = None,
+        end_date: datetime | None = None,
+        template_searchpath: str | Iterable[str] | None = None,
+        template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
+        user_defined_macros: dict | None = None,
+        user_defined_filters: dict | None = None,
+        default_args: dict | None = None,
+        max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
+        max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
+        max_consecutive_failed_dag_runs: int = airflow_conf.getint(
+            "core", "max_consecutive_failed_dag_runs_per_dag"
+        ),
+        dagrun_timeout: timedelta | None = None,
+        sla_miss_callback: Any = None,
+        catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
+        on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
+        on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
+        doc_md: str | None = None,
+        params: abc.MutableMapping | None = None,
+        access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None,
+        is_paused_upon_creation: bool | None = None,
+        jinja_environment_kwargs: dict | None = None,
+        render_template_as_native_obj: bool = False,
+        tags: Collection[str] | None = None,
+        owner_links: dict[str, str] | None = None,
+        auto_register: bool = True,
+        fail_stop: bool = False,
+        dag_display_name: str | None = None,
+    ) -> Callable[[Callable], Callable[..., DAG]]:
+        """
+        Python dag decorator which wraps a function into an Airflow DAG.
 
-    Accepts kwargs for operator kwarg. Can be used to parameterize DAGs.
+        Accepts kwargs for operator kwarg. Can be used to parameterize DAGs.
 
-    :param dag_args: Arguments for DAG object
-    :param dag_kwargs: Kwargs for DAG object.
-    """
+        :param dag_args: Arguments for DAG object
+        :param dag_kwargs: Kwargs for DAG object.
+        """
+else:
 
-    def wrapper(f: Callable) -> Callable[..., DAG]:
-        @functools.wraps(f)
-        def factory(*args, **kwargs):
-            # Generate signature for decorated function and bind the arguments when called
-            # we do this to extract parameters, so we can annotate them on the DAG object.
-            # In addition, this fails if we are missing any args/kwargs with TypeError as expected.
-            f_sig = signature(f).bind(*args, **kwargs)
-            # Apply defaults to capture default values if set.
-            f_sig.apply_defaults()
-
-            # Initialize DAG with bound arguments
-            with DAG(
-                dag_id or f.__name__,
-                description=description,
-                start_date=start_date,
-                end_date=end_date,
-                template_searchpath=template_searchpath,
-                template_undefined=template_undefined,
-                user_defined_macros=user_defined_macros,
-                user_defined_filters=user_defined_filters,
-                default_args=default_args,
-                max_active_tasks=max_active_tasks,
-                max_active_runs=max_active_runs,
-                max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs,
-                dagrun_timeout=dagrun_timeout,
-                sla_miss_callback=sla_miss_callback,
-                catchup=catchup,
-                on_success_callback=on_success_callback,
-                on_failure_callback=on_failure_callback,
-                doc_md=doc_md,
-                params=params,
-                access_control=access_control,
-                is_paused_upon_creation=is_paused_upon_creation,
-                jinja_environment_kwargs=jinja_environment_kwargs,
-                render_template_as_native_obj=render_template_as_native_obj,
-                tags=tags,
-                schedule=schedule,
-                owner_links=owner_links,
-                auto_register=auto_register,
-                fail_stop=fail_stop,
-                dag_display_name=dag_display_name,
-            ) as dag_obj:
-                # Set DAG documentation from function documentation if it exists and doc_md is not set.
-                if f.__doc__ and not dag_obj.doc_md:
-                    dag_obj.doc_md = f.__doc__
-
-                # Generate DAGParam for each function arg/kwarg and replace it for calling the function.
-                # All args/kwargs for function will be DAGParam object and replaced on execution time.
-                f_kwargs = {}
-                for name, value in f_sig.arguments.items():
-                    f_kwargs[name] = dag_obj.param(name, value)
-
-                # set file location to caller source path
-                back = sys._getframe().f_back
-                dag_obj.fileloc = back.f_code.co_filename if back else ""
-
-                # Invoke function to create operators in the DAG scope.
-                f(**f_kwargs)
-
-            # Return dag object such that it's accessible in Globals.
-            return dag_obj
-
-        # Ensure that warnings from inside DAG() are emitted from the caller, not here
-        fixup_decorator_warning_stack(factory)
-        return factory
-
-    return wrapper
+    def dag(dag_id="", __DAG_class=DAG, __warnings_stacklevel_delta=2, 
**decorator_kwargs):
+        # TODO: Task-SDK: remove __DAG_class
+        # __DAG_class is a temporary hack to allow the dag decorator in airflow.models.dag to continue to
+        # return SchedulerDag objects
+        DAG = __DAG_class
+
+        def wrapper(f: Callable) -> Callable[..., DAG]:
+            @functools.wraps(f)
+            def factory(*args, **kwargs):
+                # Generate signature for decorated function and bind the arguments when called
+                # we do this to extract parameters, so we can annotate them on the DAG object.
+                # In addition, this fails if we are missing any args/kwargs with TypeError as expected.
+                f_sig = signature(f).bind(*args, **kwargs)
+                # Apply defaults to capture default values if set.
+                f_sig.apply_defaults()
+
+                # Initialize DAG with bound arguments
+                with DAG(dag_id or f.__name__, **decorator_kwargs) as dag_obj:
+                    # Set DAG documentation from function documentation if it exists and doc_md is not set.
+                    if f.__doc__ and not dag_obj.doc_md:
+                        dag_obj.doc_md = f.__doc__
+
+                    # Generate DAGParam for each function arg/kwarg and replace it for calling the function.
+                    # All args/kwargs for function will be DAGParam object and replaced on execution time.
+                    f_kwargs = {}
+                    for name, value in f_sig.arguments.items():
+                        f_kwargs[name] = dag_obj.param(name, value)
+
+                    # set file location to caller source path
+                    back = sys._getframe().f_back
+                    dag_obj.fileloc = back.f_code.co_filename if back else ""
+
+                    # Invoke function to create operators in the DAG scope.
+                    f(**f_kwargs)
+
+                # Return dag object such that it's accessible in Globals.
+                return dag_obj
+
+            # Ensure that warnings from inside DAG() are emitted from the caller, not here
+            fixup_decorator_warning_stack(factory)
+            return factory
+
+        return wrapper
diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py
index dd4ded2f4fe..2300e97f07e 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import weakref
 from datetime import datetime, timedelta, timezone
 
 import pytest
@@ -23,7 +24,7 @@ import pytest
 from airflow.exceptions import DuplicateTaskIdFound
 from airflow.models.param import Param, ParamsDict
 from airflow.sdk.definitions.baseoperator import BaseOperator
-from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.dag import DAG, dag as dag_decorator
 
 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
@@ -245,3 +246,72 @@ class TestDag:
 
         # Make sure we don't affect the original!
         assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids
+
+
+class TestDagDecorator:
+    DEFAULT_ARGS = {
+        "owner": "test",
+        "depends_on_past": True,
+        "start_date": datetime.now(tz=timezone.utc),
+        "retries": 1,
+        "retry_delay": timedelta(minutes=1),
+    }
+    VALUE = 42
+
+    def test_fileloc(self):
+        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
+        def noop_pipeline(): ...
+
+        dag = noop_pipeline()
+        assert isinstance(dag, DAG)
+        assert dag.dag_id == "noop_pipeline"
+        assert dag.fileloc == __file__
+
+    def test_set_dag_id(self):
+        """Test that checks you can set dag_id from decorator."""
+
+        @dag_decorator("test", schedule=None, default_args=self.DEFAULT_ARGS)
+        def noop_pipeline(): ...
+
+        dag = noop_pipeline()
+        assert isinstance(dag, DAG)
+        assert dag.dag_id == "test"
+
+    def test_default_dag_id(self):
+        """Test that @dag uses function name as default dag id."""
+
+        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
+        def noop_pipeline(): ...
+
+        dag = noop_pipeline()
+        assert isinstance(dag, DAG)
+        assert dag.dag_id == "noop_pipeline"
+
+    @pytest.mark.parametrize(
+        argnames=["dag_doc_md", "expected_doc_md"],
+        argvalues=[
+            pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"),
+            pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"),
+        ],
+    )
+    def test_documentation_added(self, dag_doc_md, expected_doc_md):
+        """Test that @dag uses function docs as doc_md for DAG object if 
doc_md is not explicitly set."""
+
+        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md)
+        def noop_pipeline():
+            """Regular DAG documentation"""
+
+        dag = noop_pipeline()
+        assert isinstance(dag, DAG)
+        assert dag.dag_id == "noop_pipeline"
+        assert dag.doc_md == expected_doc_md
+
+    def test_fails_if_arg_not_set(self):
+        """Test that @dag decorated function fails if positional argument is 
not set"""
+
+        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
+        def noop_pipeline(value): ...
+
+        # Test that if arg is not passed it raises a type error as expected.
+        with pytest.raises(TypeError):
+            noop_pipeline()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 0b0b3c340f7..be8dde52331 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -23,7 +23,6 @@ import logging
 import os
 import pickle
 import re
-import weakref
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -92,7 +91,6 @@ from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.plugins.priority_weight_strategy import (
     FactorPriorityWeightStrategy,
-    NotRegisteredPriorityWeightStrategy,
     StaticTestPriorityWeightStrategy,
     TestPriorityWeightStrategyPlugin,
 )
@@ -2517,54 +2515,6 @@ class TestDagDecorator:
     def teardown_method(self):
         clear_db_runs()
 
-    def test_fileloc(self):
-        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
-        def noop_pipeline(): ...
-
-        dag = noop_pipeline()
-        assert isinstance(dag, DAG)
-        assert dag.dag_id == "noop_pipeline"
-        assert dag.fileloc == __file__
-
-    def test_set_dag_id(self):
-        """Test that checks you can set dag_id from decorator."""
-
-        @dag_decorator("test", schedule=None, default_args=self.DEFAULT_ARGS)
-        def noop_pipeline(): ...
-
-        dag = noop_pipeline()
-        assert isinstance(dag, DAG)
-        assert dag.dag_id == "test"
-
-    def test_default_dag_id(self):
-        """Test that @dag uses function name as default dag id."""
-
-        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
-        def noop_pipeline(): ...
-
-        dag = noop_pipeline()
-        assert isinstance(dag, DAG)
-        assert dag.dag_id == "noop_pipeline"
-
-    @pytest.mark.parametrize(
-        argnames=["dag_doc_md", "expected_doc_md"],
-        argvalues=[
-            pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"),
-            pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"),
-        ],
-    )
-    def test_documentation_added(self, dag_doc_md, expected_doc_md):
-        """Test that @dag uses function docs as doc_md for DAG object if 
doc_md is not explicitly set."""
-
-        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md)
-        def noop_pipeline():
-            """Regular DAG documentation"""
-
-        dag = noop_pipeline()
-        assert isinstance(dag, DAG)
-        assert dag.dag_id == "noop_pipeline"
-        assert dag.doc_md == expected_doc_md
-
     def test_documentation_template_rendered(self):
         """Test that @dag uses function docs as doc_md for DAG object"""
 
@@ -2577,7 +2527,6 @@ class TestDagDecorator:
             """
 
         dag = noop_pipeline()
-        assert isinstance(dag, DAG)
         assert dag.dag_id == "noop_pipeline"
         assert "Regular DAG documentation" in dag.doc_md
 
@@ -2597,25 +2546,9 @@ class TestDagDecorator:
         def markdown_docs(): ...
 
         dag = markdown_docs()
-        assert isinstance(dag, DAG)
         assert dag.dag_id == "test-dag"
         assert dag.doc_md == raw_content
 
-    def test_fails_if_arg_not_set(self):
-        """Test that @dag decorated function fails if positional argument is 
not set"""
-
-        @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
-        def noop_pipeline(value):
-            @task_decorator
-            def return_num(num):
-                return num
-
-            return_num(value)
-
-        # Test that if arg is not passed it raises a type error as expected.
-        with pytest.raises(TypeError):
-            noop_pipeline()
-
     def test_dag_param_resolves(self):
         """Test that dag param is correctly resolved by operator"""
 

