This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 762f0212c10350c6445ba6f354af8dd357325b24
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Sun Oct 13 15:53:28 2024 +0100

    Start replacing DAG class with the task-sdk DAG class instead
    
    I'm not sure if this approach is right, but it is _a_ path forward that lets
    us keep main working.
    
    So far this has been tested just with `uv run pytest
    tests/models/test_dag.py -k 'test_dag_as_context_manager'` and no further
---
 airflow/models/baseoperator.py                     |  27 +-
 airflow/models/dag.py                              | 826 +--------------------
 airflow/utils/task_group.py                        |  30 +-
 task_sdk/pyproject.toml                            |   1 -
 task_sdk/src/airflow/sdk/__init__.py               |  28 +-
 .../src/airflow/sdk/definitions/baseoperator.py    |   2 +-
 .../src/airflow/sdk/definitions/contextmanager.py  |   2 +-
 task_sdk/src/airflow/sdk/definitions/dag.py        |  39 +-
 task_sdk/src/airflow/sdk/definitions/node.py       |   9 +-
 .../definitions/{task_group.py => taskgroup.py}    |   0
 task_sdk/src/airflow/sdk/types.py                  |   4 +-
 task_sdk/tests/test_hello.py                       |  23 -
 uv.lock                                            |  11 -
 13 files changed, 85 insertions(+), 917 deletions(-)
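
For reference, the behaviour this change must preserve is the DAG
context-manager protocol exercised by the one tested case above. A minimal
sketch (the `with DAG(...)` form is existing Airflow API; the assertions are
illustrative):

    from airflow.models.baseoperator import BaseOperator
    from airflow.models.dag import DAG

    # __enter__ pushes the DAG onto DagContext; operators created inside the
    # block find it via DagContext.get_current() and attach themselves to it.
    with DAG(dag_id="example_dag") as dag:
        task = BaseOperator(task_id="hello")

    assert task.dag is dag
    assert "hello" in dag.task_dict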

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 0c3d119be19..39d56187371 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -277,14 +277,13 @@ def partial(
     allow_nested_operators: bool = True,
     **kwargs,
 ) -> OperatorPartial:
-    from airflow.models.dag import DagContext
-    from airflow.utils.task_group import TaskGroupContext
+    from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
 
     validate_mapping_kwargs(operator_class, "partial", kwargs)
 
-    dag = dag or DagContext.get_current_dag()
+    dag = dag or DagContext.get_current()
     if dag:
-        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+        task_group = task_group or TaskGroupContext.get_current(dag)
     if task_group:
         task_id = task_group.child_id(task_id)
 
@@ -453,8 +452,7 @@ class BaseOperatorMeta(abc.ABCMeta):
 
         @wraps(func)
         def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
-            from airflow.models.dag import DagContext
-            from airflow.utils.task_group import TaskGroupContext
+            from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
 
             if args:
                 raise AirflowException("Use keyword arguments when initializing operators")
@@ -464,10 +462,10 @@ class BaseOperatorMeta(abc.ABCMeta):
                 getattr(self, "_BaseOperator__from_mapped", False),
             )
 
-            dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag()
+            dag: DAG | None = kwargs.get("dag") or DagContext.get_current()
             task_group: TaskGroup | None = kwargs.get("task_group")
             if dag and not task_group:
-                task_group = TaskGroupContext.get_current_task_group(dag)
+                task_group = TaskGroupContext.get_current(dag)
 
             default_args, merged_params = get_merged_defaults(
                 dag=dag,
@@ -922,8 +920,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         allow_nested_operators: bool = True,
         **kwargs,
     ):
-        from airflow.models.dag import DagContext
-        from airflow.utils.task_group import TaskGroupContext
+        from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
 
         self.__init_kwargs = {}
 
@@ -937,8 +934,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
             )
         validate_key(task_id)
 
-        dag = dag or DagContext.get_current_dag()
-        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+        dag = dag or DagContext.get_current()
+        task_group = task_group or TaskGroupContext.get_current(dag)
 
         self.task_id = task_group.child_id(task_id) if task_group else task_id
         if not self.__from_mapped and task_group:
@@ -1662,13 +1659,13 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""
         if not cls.__serialized_fields:
-            from airflow.models.dag import DagContext
+            from airflow.sdk.definitions.contextmanager import DagContext
 
             # make sure the following dummy task is not added to current active
             # dag in context, otherwise, it will result in
             # `RuntimeError: dictionary changed size during iteration`
             # Exception in SerializedDAG.serialize_dag() call.
-            DagContext.push_context_managed_dag(None)
+            DagContext.push(None)
             cls.__serialized_fields = frozenset(
                 vars(BaseOperator(task_id="test")).keys()
                 - {
@@ -1704,7 +1701,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
                     "start_from_trigger",
                 }
             )
-            DagContext.pop_context_managed_dag()
+            DagContext.pop()
 
         return cls.__serialized_fields
 
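
The hunks above are a mechanical rename onto the task-sdk context API. The
correspondence, using only names that appear in this diff:

    DagContext.push_context_managed_dag(dag)      ->  DagContext.push(dag)
    DagContext.pop_context_managed_dag()          ->  DagContext.pop()
    DagContext.get_current_dag()                  ->  DagContext.get_current()
    TaskGroupContext.get_current_task_group(dag)  ->  TaskGroupContext.get_current(dag)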
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index f5def92ea92..290f42ca387 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -20,15 +20,12 @@ from __future__ import annotations
 import asyncio
 import copy
 import functools
-import itertools
 import logging
-import os
 import pathlib
 import pickle
 import sys
 import time
 import traceback
-import weakref
 from collections import abc, defaultdict, deque
 from contextlib import ExitStack
 from datetime import datetime, timedelta
@@ -40,15 +37,12 @@ from typing import (
     Collection,
     Container,
     Iterable,
-    Iterator,
-    MutableSet,
     Pattern,
     Sequence,
     Union,
     cast,
     overload,
 )
-from urllib.parse import urlsplit
 
 import jinja2
 import pendulum
@@ -77,22 +71,19 @@ from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql import Select, expression
 
+import airflow.sdk.definitions.contextmanager
 import airflow.templates
 from airflow import settings, utils
 from airflow.api_internal.internal_api_call import internal_api_call
-from airflow.assets import Asset, AssetAlias, AssetAll, BaseAsset
+from airflow.assets import Asset, AssetAlias, BaseAsset
 from airflow.configuration import conf as airflow_conf, secrets_backend_list
 from airflow.exceptions import (
     AirflowException,
-    DuplicateTaskIdFound,
-    FailStopDagInvalidTriggerRule,
-    ParamValidationError,
     TaskDeferred,
-    TaskNotFound,
     UnknownExecutorException,
 )
 from airflow.executors.executor_loader import ExecutorLoader
-from airflow.models.abstractoperator import AbstractOperator, TaskStateChangeCallback
+from airflow.models.abstractoperator import TaskStateChangeCallback
 from airflow.models.asset import (
     AssetDagRunQueue,
     AssetModel,
@@ -102,7 +93,6 @@ from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagpickle import DagPickle
 from airflow.models.dagrun import RUN_ID_REGEX, DagRun
-from airflow.models.param import DagParam, ParamsDict
 from airflow.models.taskinstance import (
     Context,
     TaskInstance,
@@ -110,10 +100,10 @@ from airflow.models.taskinstance import (
     clear_task_instances,
 )
 from airflow.models.tasklog import LogTemplate
+from airflow.sdk import DAG as TaskSDKDag
 from airflow.secrets.local_filesystem import LocalFilesystemBackend
 from airflow.security import permissions
 from airflow.settings import json
-from airflow.stats import Stats
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
 from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
 from airflow.timetables.simple import (
@@ -126,13 +116,11 @@ from airflow.timetables.trigger import CronTriggerTimetable
 from airflow.utils import timezone
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.decorators import fixup_decorator_warning_stack
-from airflow.utils.helpers import exactly_one, validate_instance_args, validate_key
-from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType, EdgeInfoType
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 if TYPE_CHECKING:
     from types import ModuleType
@@ -141,13 +129,11 @@ if TYPE_CHECKING:
     from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session
 
-    from airflow.decorators import TaskDecoratorCollection
     from airflow.models.dagbag import DagBag
     from airflow.models.operator import Operator
     from airflow.serialization.pydantic.dag import DagModelPydantic
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
     from airflow.typing_compat import Literal
-    from airflow.utils.task_group import TaskGroup
 
 log = logging.getLogger(__name__)
 
@@ -359,7 +345,7 @@ DAG_ARGS_EXPECTED_TYPES = {
 
 
 @functools.total_ordering
-class DAG(LoggingMixin):
+class DAG(TaskSDKDag):
     """
     A dag (directed acyclic graph) is a collection of tasks with directional dependencies.
 
@@ -473,251 +459,6 @@ class DAG(LoggingMixin):
     :param dag_display_name: The display name of the DAG which appears on the UI.
     """
 
-    _comps = {
-        "dag_id",
-        "task_ids",
-        "start_date",
-        "end_date",
-        "fileloc",
-        "template_searchpath",
-        "last_loaded",
-    }
-
-    __serialized_fields: frozenset[str] | None = None
-
-    fileloc: str
-    """
-    File path that needs to be imported to load this DAG.
-
-    This may not be an actual file on disk in the case when this DAG is loaded
-    from a ZIP file or other DAG distribution format.
-    """
-
-    # NOTE: When updating arguments here, please also keep arguments in @dag()
-    # below in sync. (Search for 'def dag(' in this file.)
-    def __init__(
-        self,
-        dag_id: str,
-        description: str | None = None,
-        schedule: ScheduleArg = None,
-        start_date: datetime | None = None,
-        end_date: datetime | None = None,
-        template_searchpath: str | Iterable[str] | None = None,
-        template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
-        user_defined_macros: dict | None = None,
-        user_defined_filters: dict | None = None,
-        default_args: dict | None = None,
-        max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
-        max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
-        max_consecutive_failed_dag_runs: int = airflow_conf.getint(
-            "core", "max_consecutive_failed_dag_runs_per_dag"
-        ),
-        dagrun_timeout: timedelta | None = None,
-        sla_miss_callback: Any = None,
-        default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(),
-        orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"),
-        catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
-        on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
-        on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
-        doc_md: str | None = None,
-        params: abc.MutableMapping | None = None,
-        access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None,
-        is_paused_upon_creation: bool | None = None,
-        jinja_environment_kwargs: dict | None = None,
-        render_template_as_native_obj: bool = False,
-        tags: Collection[str] | None = None,
-        owner_links: dict[str, str] | None = None,
-        auto_register: bool = True,
-        fail_stop: bool = False,
-        dag_display_name: str | None = None,
-    ):
-        from airflow.utils.task_group import TaskGroup
-
-        if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
-            raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} 
characters")
-
-        self.owner_links = owner_links or {}
-        self.user_defined_macros = user_defined_macros
-        self.user_defined_filters = user_defined_filters
-        if default_args and not isinstance(default_args, dict):
-            raise TypeError("default_args must be a dict")
-        self.default_args = copy.deepcopy(default_args or {})
-        params = params or {}
-
-        # merging potentially conflicting default_args['params'] into params
-        if "params" in self.default_args:
-            params.update(self.default_args["params"])
-            del self.default_args["params"]
-
-        # check self.params and convert them into ParamsDict
-        self.params = ParamsDict(params)
-
-        validate_key(dag_id)
-
-        self._dag_id = dag_id
-        self._dag_display_property_value = dag_display_name
-
-        self._max_active_tasks = max_active_tasks
-        self._pickle_id: int | None = None
-
-        self._description = description
-        # set file location to caller source path
-        back = sys._getframe().f_back
-        self.fileloc = back.f_code.co_filename if back else ""
-        self.task_dict: dict[str, Operator] = {}
-
-        # set timezone from start_date
-        tz = None
-        if start_date and start_date.tzinfo:
-            tzinfo = None if start_date.tzinfo else settings.TIMEZONE
-            tz = pendulum.instance(start_date, tz=tzinfo).timezone
-        elif date := self.default_args.get("start_date"):
-            if not isinstance(date, datetime):
-                date = timezone.parse(date)
-                self.default_args["start_date"] = date
-                start_date = date
-
-            tzinfo = None if date.tzinfo else settings.TIMEZONE
-            tz = pendulum.instance(date, tz=tzinfo).timezone
-        self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE
-
-        # Apply the timezone we settled on to end_date if it wasn't supplied
-        if isinstance(_end_date := self.default_args.get("end_date"), str):
-            self.default_args["end_date"] = timezone.parse(_end_date, 
timezone=self.timezone)
-
-        self.start_date = timezone.convert_to_utc(start_date)
-        self.end_date = timezone.convert_to_utc(end_date)
-
-        # also convert tasks
-        if "start_date" in self.default_args:
-            self.default_args["start_date"] = 
timezone.convert_to_utc(self.default_args["start_date"])
-        if "end_date" in self.default_args:
-            self.default_args["end_date"] = 
timezone.convert_to_utc(self.default_args["end_date"])
-
-        if isinstance(schedule, Timetable):
-            self.timetable = schedule
-        elif isinstance(schedule, BaseAsset):
-            self.timetable = AssetTriggeredTimetable(schedule)
-        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
-            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
-                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
-            self.timetable = AssetTriggeredTimetable(AssetAll(*schedule))
-        else:
-            self.timetable = create_timetable(schedule, self.timezone)
-
-        requires_automatic_backfilling = self.timetable.can_be_scheduled and catchup
-        if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date):
-            raise ValueError("start_date is required when catchup=True")
-
-        if isinstance(template_searchpath, str):
-            template_searchpath = [template_searchpath]
-        self.template_searchpath = template_searchpath
-        self.template_undefined = template_undefined
-        self.last_loaded: datetime = timezone.utcnow()
-        self.safe_dag_id = dag_id.replace(".", "__dot__")
-        self.max_active_runs = max_active_runs
-        self.max_consecutive_failed_dag_runs = max_consecutive_failed_dag_runs
-        if self.max_consecutive_failed_dag_runs == 0:
-            self.max_consecutive_failed_dag_runs = airflow_conf.getint(
-                "core", "max_consecutive_failed_dag_runs_per_dag"
-            )
-        if self.max_consecutive_failed_dag_runs < 0:
-            raise AirflowException(
-                f"Invalid max_consecutive_failed_dag_runs: 
{self.max_consecutive_failed_dag_runs}."
-                f"Requires max_consecutive_failed_dag_runs >= 0"
-            )
-        if self.timetable.active_runs_limit is not None:
-            if self.timetable.active_runs_limit < self.max_active_runs:
-                raise AirflowException(
-                    f"Invalid max_active_runs: {type(self.timetable)} "
-                    f"requires max_active_runs <= 
{self.timetable.active_runs_limit}"
-                )
-        self.dagrun_timeout = dagrun_timeout
-        if sla_miss_callback:
-            log.warning(
-                "The SLA feature is removed in Airflow 3.0, to be replaced 
with a new implementation in 3.1"
-            )
-        if default_view in DEFAULT_VIEW_PRESETS:
-            self._default_view: str = default_view
-        else:
-            raise AirflowException(
-                f"Invalid values of dag.default_view: only support "
-                f"{DEFAULT_VIEW_PRESETS}, but get {default_view}"
-            )
-        if orientation in ORIENTATION_PRESETS:
-            self.orientation = orientation
-        else:
-            raise AirflowException(
-                f"Invalid values of dag.orientation: only support "
-                f"{ORIENTATION_PRESETS}, but get {orientation}"
-            )
-        self.catchup: bool = catchup
-
-        self.partial: bool = False
-        self.on_success_callback = on_success_callback
-        self.on_failure_callback = on_failure_callback
-
-        # Keeps track of any extra edge metadata (sparse; will not contain all
-        # edges, so do not iterate over it for that). Outer key is upstream
-        # task ID, inner key is downstream task ID.
-        self.edge_info: dict[str, dict[str, EdgeInfoType]] = {}
-
-        # To keep it in parity with Serialized DAGs
-        # and identify if DAG has on_*_callback without actually storing them in Serialized JSON
-        self.has_on_success_callback: bool = self.on_success_callback is not None
-        self.has_on_failure_callback: bool = self.on_failure_callback is not None
-
-        self._access_control = DAG._upgrade_outdated_dag_access_control(access_control)
-        self.is_paused_upon_creation = is_paused_upon_creation
-        self.auto_register = auto_register
-
-        self.fail_stop: bool = fail_stop
-
-        self.jinja_environment_kwargs = jinja_environment_kwargs
-        self.render_template_as_native_obj = render_template_as_native_obj
-
-        self.doc_md = self.get_doc_md(doc_md)
-
-        self.tags: MutableSet[str] = set(tags or [])
-        self._task_group = TaskGroup.create_root(self)
-        self.validate_schedule_and_params()
-        wrong_links = dict(self.iter_invalid_owner_links())
-        if wrong_links:
-            raise AirflowException(
-                "Wrong link format was used for the owner. Use a valid link \n"
-                f"Bad formatted links are: {wrong_links}"
-            )
-
-        # this will only be set at serialization time
-        # its only use is for determining the relative
-        # fileloc based only on the serialize dag
-        self._processor_dags_folder = None
-
-        validate_instance_args(self, DAG_ARGS_EXPECTED_TYPES)
-
-    def get_doc_md(self, doc_md: str | None) -> str | None:
-        if doc_md is None:
-            return doc_md
-
-        if doc_md.endswith(".md"):
-            try:
-                return open(doc_md).read()
-            except FileNotFoundError:
-                return doc_md
-
-        return doc_md
-
-    def validate(self):
-        """
-        Validate the DAG has a coherent setup.
-
-        This is called by the DAG bag before bagging the DAG.
-        """
-        self.validate_executor_field()
-        self.validate_schedule_and_params()
-        self.timetable.validate()
-        self.validate_setup_teardown()
-
     def validate_executor_field(self):
         for task in self.tasks:
             if task.executor:
@@ -730,63 +471,6 @@ class DAG(LoggingMixin):
                         "update the executor configuration for this task."
                     )
 
-    def validate_setup_teardown(self):
-        """
-        Validate that setup and teardown tasks are configured properly.
-
-        :meta private:
-        """
-        for task in self.tasks:
-            if task.is_setup:
-                for down_task in task.downstream_list:
-                    if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS:
-                        # todo: we can relax this to allow out-of-scope tasks to have other trigger rules
-                        # this is required to ensure consistent behavior of dag
-                        # when clearing an indirect setup
-                        raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.")
-            FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
-
-    def __repr__(self):
-        return f"<DAG: {self.dag_id}>"
-
-    def __eq__(self, other):
-        if type(self) is type(other):
-            # Use getattr() instead of __dict__ as __dict__ doesn't return
-            # correct values for properties.
-            return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
-        return False
-
-    def __ne__(self, other):
-        return not self == other
-
-    def __lt__(self, other):
-        return self.dag_id < other.dag_id
-
-    def __hash__(self):
-        hash_components = [type(self)]
-        for c in self._comps:
-            # task_ids returns a list and lists can't be hashed
-            if c == "task_ids":
-                val = tuple(self.task_dict)
-            else:
-                val = getattr(self, c, None)
-            try:
-                hash(val)
-                hash_components.append(val)
-            except TypeError:
-                hash_components.append(repr(val))
-        return hash(tuple(hash_components))
-
-    # Context Manager -----------------------------------------------
-    def __enter__(self):
-        DagContext.push_context_managed_dag(self)
-        return self
-
-    def __exit__(self, _type, _value, _tb):
-        DagContext.pop_context_managed_dag()
-
-    # /Context Manager ----------------------------------------------
-
     @staticmethod
     def _upgrade_outdated_dag_access_control(access_control=None):
         """Look for outdated dag level actions in DAG access_controls and 
replace them with updated actions."""
@@ -1076,14 +760,6 @@ class DAG(LoggingMixin):
     def access_control(self, value):
         self._access_control = DAG._upgrade_outdated_dag_access_control(value)
 
-    @property
-    def dag_display_name(self) -> str:
-        return self._dag_display_property_value or self._dag_id
-
-    @property
-    def description(self) -> str | None:
-        return self._description
-
     @property
     def default_view(self) -> str:
         return self._default_view
@@ -1096,41 +772,6 @@ class DAG(LoggingMixin):
     def pickle_id(self, value: int) -> None:
         self._pickle_id = value
 
-    def param(self, name: str, default: Any = NOTSET) -> DagParam:
-        """
-        Return a DagParam object for current dag.
-
-        :param name: dag parameter name.
-        :param default: fallback value for dag parameter.
-        :return: DagParam instance for specified name and current dag.
-        """
-        return DagParam(current_dag=self, name=name, default=default)
-
-    @property
-    def tasks(self) -> list[Operator]:
-        return list(self.task_dict.values())
-
-    @tasks.setter
-    def tasks(self, val):
-        raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.")
-
-    @property
-    def task_ids(self) -> list[str]:
-        return list(self.task_dict)
-
-    @property
-    def teardowns(self) -> list[Operator]:
-        return [task for task in self.tasks if getattr(task, "is_teardown", None)]
-
-    @property
-    def tasks_upstream_of_teardowns(self) -> list[Operator]:
-        upstream_tasks = [t.upstream_list for t in self.teardowns]
-        return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)]
-
-    @property
-    def task_group(self) -> TaskGroup:
-        return self._task_group
-
     @property
     def relative_fileloc(self) -> pathlib.Path:
         """File location of the importable dag 'file' relative to the 
configured DAGs folder."""
@@ -1145,24 +786,6 @@ class DAG(LoggingMixin):
             # Not relative to DAGS_FOLDER.
             return path
 
-    @property
-    def folder(self) -> str:
-        """Folder location of where the DAG object is instantiated."""
-        return os.path.dirname(self.fileloc)
-
-    @property
-    def owner(self) -> str:
-        """
-        Return list of all owners found in DAG tasks.
-
-        :return: Comma separated list of owners in DAG tasks
-        """
-        return ", ".join({t.owner for t in self.tasks})
-
-    @property
-    def allow_future_exec_dates(self) -> bool:
-        return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled
-
     @provide_session
     def get_concurrency_reached(self, session=NEW_SESSION) -> bool:
         """Return a boolean indicating whether the max_active_tasks limit for 
this DAG has been reached."""
@@ -1251,24 +874,6 @@ class DAG(LoggingMixin):
 
         DAG.execute_callback(callbacks, context, self.dag_id)
 
-    @classmethod
-    def execute_callback(cls, callbacks: list[Callable] | None, context: Context | None, dag_id: str):
-        """
-        Triggers the callbacks with the given context.
-
-        :param callbacks: List of callbacks to call
-        :param context: Context to pass to all callbacks
-        :param dag_id: The dag_id of the DAG to find.
-        """
-        if callbacks and context:
-            for callback in callbacks:
-                cls.logger().info("Executing dag callback function: %s", callback)
-                try:
-                    callback(context)
-                except Exception:
-                    cls.logger().exception("failed to invoke dag state update callback")
-                    Stats.incr("dag.callback_exceptions", tags={"dag_id": dag_id})
-
     def get_active_runs(self):
         """
         Return a list of dag run execution dates currently running.
@@ -1368,45 +973,6 @@ class DAG(LoggingMixin):
         """Return the latest date for which at least one dag run exists."""
         return session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == self.dag_id))
 
-    def resolve_template_files(self):
-        for t in self.tasks:
-            t.resolve_template_files()
-
-    def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment:
-        """Build a Jinja2 environment."""
-        # Collect directories to search for template files
-        searchpath = [self.folder]
-        if self.template_searchpath:
-            searchpath += self.template_searchpath
-
-        # Default values (for backward compatibility)
-        jinja_env_options = {
-            "loader": jinja2.FileSystemLoader(searchpath),
-            "undefined": self.template_undefined,
-            "extensions": ["jinja2.ext.do"],
-            "cache_size": 0,
-        }
-        if self.jinja_environment_kwargs:
-            jinja_env_options.update(self.jinja_environment_kwargs)
-        env: jinja2.Environment
-        if self.render_template_as_native_obj and not force_sandboxed:
-            env = airflow.templates.NativeEnvironment(**jinja_env_options)
-        else:
-            env = airflow.templates.SandboxedEnvironment(**jinja_env_options)
-
-        # Add any user defined items. Safe to edit globals as long as no templates are rendered yet.
-        # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals
-        if self.user_defined_macros:
-            env.globals.update(self.user_defined_macros)
-        if self.user_defined_filters:
-            env.filters.update(self.user_defined_filters)
-
-        return env
-
-    def set_dependency(self, upstream_task_id, downstream_task_id):
-        """Set dependency between two tasks that already have been added to 
the DAG using add_task()."""
-        
self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id))
-
     @provide_session
     def get_task_instances_before(
         self,
@@ -1871,33 +1437,6 @@ class DAG(LoggingMixin):
 
         return altered
 
-    @property
-    def roots(self) -> list[Operator]:
-        """Return nodes with no parents. These are first to execute and are 
called roots or root nodes."""
-        return [task for task in self.tasks if not task.upstream_list]
-
-    @property
-    def leaves(self) -> list[Operator]:
-        """Return nodes with no children. These are last to execute and are 
called leaves or leaf nodes."""
-        return [task for task in self.tasks if not task.downstream_list]
-
-    def topological_sort(self):
-        """
-        Sorts tasks in topological order, such that a task comes after any of its upstream dependencies.
-
-        Deprecated in place of ``task_group.topological_sort``
-        """
-        from airflow.utils.task_group import TaskGroup
-
-        def nested_topo(group):
-            for node in group.topological_sort():
-                if isinstance(node, TaskGroup):
-                    yield from nested_topo(node)
-                else:
-                    yield node
-
-        return tuple(nested_topo(self.task_group))
-
     @provide_session
     def clear(
         self,
@@ -2031,169 +1570,6 @@ class DAG(LoggingMixin):
             print("Cancelled, nothing was cleared.")
         return count
 
-    def __deepcopy__(self, memo):
-        # Switcharoo to go around deepcopying objects coming through the
-        # backdoor
-        cls = self.__class__
-        result = cls.__new__(cls)
-        memo[id(self)] = result
-        for k, v in self.__dict__.items():
-            if k not in ("user_defined_macros", "user_defined_filters", "_log"):
-                setattr(result, k, copy.deepcopy(v, memo))
-
-        result.user_defined_macros = self.user_defined_macros
-        result.user_defined_filters = self.user_defined_filters
-        if hasattr(self, "_log"):
-            result._log = self._log
-        return result
-
-    def partial_subset(
-        self,
-        task_ids_or_regex: str | Pattern | Iterable[str],
-        include_downstream=False,
-        include_upstream=True,
-        include_direct_upstream=False,
-    ):
-        """
-        Return a subset of the current dag based on regex matching one or more tasks.
-
-        Returns a subset of the current dag as a deep copy of the current dag
-        based on a regex that should match one or many tasks, and includes
-        upstream and downstream neighbours based on the flag passed.
-
-        :param task_ids_or_regex: Either a list of task_ids, or a regex to
-            match against task ids (as a string, or compiled regex pattern).
-        :param include_downstream: Include all downstream tasks of matched
-            tasks, in addition to matched tasks.
-        :param include_upstream: Include all upstream tasks of matched tasks,
-            in addition to matched tasks.
-        :param include_direct_upstream: Include all tasks directly upstream of matched
-            and downstream (if include_downstream = True) tasks
-        """
-        from airflow.models.baseoperator import BaseOperator
-        from airflow.models.mappedoperator import MappedOperator
-
-        # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all
-        # the tasks anyway, so we copy the tasks manually later
-        memo = {id(self.task_dict): None, id(self._task_group): None}
-        dag = copy.deepcopy(self, memo)  # type: ignore
-
-        if isinstance(task_ids_or_regex, (str, Pattern)):
-            matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)]
-        else:
-            matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]
-
-        also_include_ids: set[str] = set()
-        for t in matched_tasks:
-            if include_downstream:
-                for rel in t.get_flat_relatives(upstream=False):
-                    also_include_ids.add(rel.task_id)
-                    if rel not in matched_tasks:  # if it's in there, we're already processing it
-                        # need to include setups and teardowns for tasks that are in multiple
-                        # non-collinear setup/teardown paths
-                        if not rel.is_setup and not rel.is_teardown:
-                            also_include_ids.update(
-                                x.task_id for x in rel.get_upstreams_only_setups_and_teardowns()
-                            )
-            if include_upstream:
-                also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups())
-            else:
-                if not t.is_setup and not t.is_teardown:
-                    also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns())
-            if t.is_setup and not include_downstream:
-                also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown)
-
-        also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids]
-        direct_upstreams: list[Operator] = []
-        if include_direct_upstream:
-            for t in itertools.chain(matched_tasks, also_include):
-                upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
-                direct_upstreams.extend(upstream)
-
-        # Compiling the unique list of tasks that made the cut
-        # Make sure to not recursively deepcopy the dag or task_group while copying the task.
-        # task_group is reset later
-        def _deepcopy_task(t) -> Operator:
-            memo.setdefault(id(t.task_group), None)
-            return copy.deepcopy(t, memo)
-
-        dag.task_dict = {
-            t.task_id: _deepcopy_task(t)
-            for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
-        }
-
-        def filter_task_group(group, parent_group):
-            """Exclude tasks not included in the subdag from the given 
TaskGroup."""
-            # We want to deepcopy _most but not all_ attributes of the task 
group, so we create a shallow copy
-            # and then manually deep copy the instances. (memo argument to 
deepcopy only works for instances
-            # of classes, not "native" properties of an instance)
-            copied = copy.copy(group)
-
-            memo[id(group.children)] = {}
-            if parent_group:
-                memo[id(group.parent_group)] = parent_group
-            for attr, value in copied.__dict__.items():
-                if id(value) in memo:
-                    value = memo[id(value)]
-                else:
-                    value = copy.deepcopy(value, memo)
-                copied.__dict__[attr] = value
-
-            proxy = weakref.proxy(copied)
-
-            for child in group.children.values():
-                if isinstance(child, AbstractOperator):
-                    if child.task_id in dag.task_dict:
-                        task = copied.children[child.task_id] = dag.task_dict[child.task_id]
-                        task.task_group = proxy
-                    else:
-                        copied.used_group_ids.discard(child.task_id)
-                else:
-                    filtered_child = filter_task_group(child, proxy)
-
-                    # Only include this child TaskGroup if it is non-empty.
-                    if filtered_child.children:
-                        copied.children[child.group_id] = filtered_child
-
-            return copied
-
-        dag._task_group = filter_task_group(self.task_group, None)
-
-        # Removing upstream/downstream references to tasks and TaskGroups that did not make
-        # the cut.
-        subdag_task_groups = dag.task_group.get_task_group_dict()
-        for group in subdag_task_groups.values():
-            group.upstream_group_ids.intersection_update(subdag_task_groups)
-            group.downstream_group_ids.intersection_update(subdag_task_groups)
-            group.upstream_task_ids.intersection_update(dag.task_dict)
-            group.downstream_task_ids.intersection_update(dag.task_dict)
-
-        for t in dag.tasks:
-            # Removing upstream/downstream references to tasks that did not
-            # make the cut
-            t.upstream_task_ids.intersection_update(dag.task_dict)
-            t.downstream_task_ids.intersection_update(dag.task_dict)
-
-        if len(dag.tasks) < len(self.tasks):
-            dag.partial = True
-
-        return dag
-
-    def has_task(self, task_id: str):
-        return task_id in self.task_dict
-
-    def has_task_group(self, task_group_id: str) -> bool:
-        return task_group_id in self.task_group_dict
-
-    @functools.cached_property
-    def task_group_dict(self):
-        return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None}
-
-    def get_task(self, task_id: str) -> Operator:
-        if task_id in self.task_dict:
-            return self.task_dict[task_id]
-        raise TaskNotFound(f"Task {task_id} not found")
-
     def pickle_info(self):
         d = {}
         d["is_picklable"] = True
@@ -2223,76 +1599,6 @@ class DAG(LoggingMixin):
 
         return dp
 
-    @property
-    def task(self) -> TaskDecoratorCollection:
-        from airflow.decorators import task
-
-        return cast("TaskDecoratorCollection", functools.partial(task, 
dag=self))
-
-    def add_task(self, task: Operator) -> None:
-        """
-        Add a task to the DAG.
-
-        :param task: the task you want to add
-        """
-        FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
-
-        from airflow.utils.task_group import TaskGroupContext
-
-        # if the task has no start date, assign it the same as the DAG
-        if not task.start_date:
-            task.start_date = self.start_date
-        # otherwise, the task will start on the later of its own start date and
-        # the DAG's start date
-        elif self.start_date:
-            task.start_date = max(task.start_date, self.start_date)
-
-        # if the task has no end date, assign it the same as the dag
-        if not task.end_date:
-            task.end_date = self.end_date
-        # otherwise, the task will end on the earlier of its own end date and
-        # the DAG's end date
-        elif task.end_date and self.end_date:
-            task.end_date = min(task.end_date, self.end_date)
-
-        task_id = task.task_id
-        if not task.task_group:
-            task_group = TaskGroupContext.get_current_task_group(self)
-            if task_group:
-                task_id = task_group.child_id(task_id)
-                task_group.add(task)
-
-        if (
-            task_id in self.task_dict and self.task_dict[task_id] is not task
-        ) or task_id in self._task_group.used_group_ids:
-            raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
-        else:
-            self.task_dict[task_id] = task
-            task.dag = self
-            # Add task_id to used_group_ids to prevent group_id and task_id collisions.
-            self._task_group.used_group_ids.add(task_id)
-
-        self.task_count = len(self.task_dict)
-
-    def add_tasks(self, tasks: Iterable[Operator]) -> None:
-        """
-        Add a list of tasks to the DAG.
-
-        :param tasks: a list of tasks you want to add
-        """
-        for task in tasks:
-            self.add_task(task)
-
-    def _remove_task(self, task_id: str) -> None:
-        # This is "private" as removing could leave a hole in dependencies if 
done incorrectly, and this
-        # doesn't guard against that
-        task = self.task_dict.pop(task_id)
-        tg = getattr(task, "task_group", None)
-        if tg:
-            tg._remove(task)
-
-        self.task_count = len(self.task_dict)
-
     def cli(self):
         """Exposes a CLI specific to this DAG."""
         check_cycle(self)
@@ -2692,88 +1998,6 @@ class DAG(LoggingMixin):
                 qry = qry.where(TaskInstance.state.in_(states))
         return session.scalar(qry)
 
-    @classmethod
-    def get_serialized_fields(cls):
-        """Stringified DAGs and operators contain exactly these fields."""
-        if not cls.__serialized_fields:
-            exclusion_list = {
-                "schedule_dataset_references",
-                "schedule_dataset_alias_references",
-                "task_outlet_dataset_references",
-                "_old_context_manager_dags",
-                "safe_dag_id",
-                "last_loaded",
-                "user_defined_filters",
-                "user_defined_macros",
-                "partial",
-                "params",
-                "_pickle_id",
-                "_log",
-                "task_dict",
-                "template_searchpath",
-                "sla_miss_callback",
-                "on_success_callback",
-                "on_failure_callback",
-                "template_undefined",
-                "jinja_environment_kwargs",
-                # has_on_*_callback are only stored if the value is True, as the default is False
-                "has_on_success_callback",
-                "has_on_failure_callback",
-                "auto_register",
-                "fail_stop",
-            }
-            cls.__serialized_fields = frozenset(vars(DAG(dag_id="test", schedule=None))) - exclusion_list
-        return cls.__serialized_fields
-
-    def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
-        """Return edge information for the given pair of tasks or an empty edge if there is no information."""
-        # Note - older serialized DAGs may not have edge_info being a dict at all
-        empty = cast(EdgeInfoType, {})
-        if self.edge_info:
-            return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty)
-        else:
-            return empty
-
-    def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
-        """
-        Set the given edge information on the DAG.
-
-        Note that this will overwrite, rather than merge with, existing info.
-        """
-        self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info
-
-    def validate_schedule_and_params(self):
-        """
-        Validate Param values when the DAG has schedule defined.
-
-        Raise exception if there are any Params which can not be resolved by their schema definition.
-        """
-        if not self.timetable.can_be_scheduled:
-            return
-
-        try:
-            self.params.validate()
-        except ParamValidationError as pverr:
-            raise AirflowException(
-                "DAG is not allowed to define a Schedule, "
-                "if there are any required params without default values or 
default values are not valid."
-            ) from pverr
-
-    def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
-        """
-        Parse a given link, and verifies if it's a valid URL, or a 'mailto' link.
-
-        Returns an iterator of invalid (owner, link) pairs.
-        """
-        for owner, link in self.owner_links.items():
-            result = urlsplit(link)
-            if result.scheme == "mailto":
-                # netloc is not existing for 'mailto' link, so we are checking that the path is parsed
-                if not result.path:
-                    yield result.path, link
-            elif not result.scheme or not result.netloc:
-                yield owner, link
-
 
 class DagTag(Base):
     """A tag name per dag, to allow quick filtering in the DAG view."""
@@ -3342,25 +2566,9 @@ if STATICA_HACK:  # pragma: no cover
     """:sphinx-autoapi-skip:"""
 
 
-class DagContext:
+class DagContext(airflow.sdk.definitions.contextmanager.DagContext):
     """
-    DAG context is used to keep the current DAG when DAG is used as ContextManager.
-
-    You can use DAG as context:
-
-    .. code-block:: python
-
-        with DAG(
-            dag_id="example_dag",
-            default_args=default_args,
-            schedule="0 0 * * *",
-            dagrun_timeout=timedelta(minutes=60),
-        ) as dag:
-            ...
-
-    If you do this the context stores the DAG and whenever new task is created, it will use
-    such stored DAG as the parent DAG.
-
+    :meta private:
     """
 
     _context_managed_dags: deque[DAG] = deque()
@@ -3369,25 +2577,15 @@ class DagContext:
 
     @classmethod
     def push_context_managed_dag(cls, dag: DAG):
-        cls._context_managed_dags.appendleft(dag)
+        cls.push(dag)
 
     @classmethod
     def pop_context_managed_dag(cls) -> DAG | None:
-        dag = cls._context_managed_dags.popleft()
-
-        # In a few cases around serialization we explicitly push None in to the stack
-        if cls.current_autoregister_module_name is not None and dag and dag.auto_register:
-            mod = sys.modules[cls.current_autoregister_module_name]
-            cls.autoregistered_dags.add((dag, mod))
-
-        return dag
+        return cls.pop()
 
     @classmethod
     def get_current_dag(cls) -> DAG | None:
-        try:
-            return cls._context_managed_dags[0]
-        except IndexError:
-            return None
+        return cls.get_current()
 
 
 def _run_inline_trigger(trigger):
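
The task-sdk contextmanager module itself is not shown in this diff; one
plausible shape for the base class these shims inherit from, inferred purely
from the calls made here (push/pop/get_current) and therefore hypothetical:

    from collections import deque
    from typing import ClassVar, Generic, TypeVar

    T = TypeVar("T")

    class _ContextStack(Generic[T]):
        # Hypothetical sketch, not part of this diff; each subclass carries
        # its own stack of context-managed objects.
        _stack: ClassVar[deque]

        @classmethod
        def push(cls, obj: T) -> None:
            cls._stack.appendleft(obj)

        @classmethod
        def pop(cls) -> T:
            return cls._stack.popleft()

        @classmethod
        def get_current(cls) -> T | None:
            return cls._stack[0] if cls._stack else None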
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 69a5d015bd4..79b6329e44b 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -28,6 +28,7 @@ from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
 import methodtools
 import re2
 
+import airflow.sdk.definitions.contextmanager
 from airflow.exceptions import (
     AirflowDagCycleException,
     AirflowException,
@@ -650,44 +651,23 @@ class MappedTaskGroup(TaskGroup):
         super().__exit__(exc_type, exc_val, exc_tb)
 
 
-class TaskGroupContext:
+class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext):
     """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
 
-    active: bool = False
-    _context_managed_task_group: TaskGroup | None = None
-    _previous_context_managed_task_groups: list[TaskGroup] = []
-
     @classmethod
     def push_context_managed_task_group(cls, task_group: TaskGroup):
         """Push a TaskGroup into the list of managed TaskGroups."""
-        if cls._context_managed_task_group:
-            cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
-        cls._context_managed_task_group = task_group
-        cls.active = True
+        return cls.push(task_group)
 
     @classmethod
     def pop_context_managed_task_group(cls) -> TaskGroup | None:
         """Pops the last TaskGroup from the list of managed TaskGroups and 
update the current TaskGroup."""
-        old_task_group = cls._context_managed_task_group
-        if cls._previous_context_managed_task_groups:
-            cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop()
-        else:
-            cls._context_managed_task_group = None
-        cls.active = False
-        return old_task_group
+        return cls.pop()
 
     @classmethod
     def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
         """Get the current TaskGroup."""
-        from airflow.models.dag import DagContext
-
-        if not cls._context_managed_task_group:
-            dag = dag or DagContext.get_current_dag()
-            if dag:
-                # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
-                return dag.task_group
-
-        return cls._context_managed_task_group
+        return cls.get_current()
 
 
 def task_group_to_dict(task_item_or_group):
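
Old call sites keep working unchanged through the shim; illustratively:

    tg = TaskGroup(group_id="section_1", dag=dag)
    TaskGroupContext.push_context_managed_task_group(tg)       # -> cls.push(tg)
    assert TaskGroupContext.get_current_task_group(dag) is tg  # -> cls.get_current(dag)
    TaskGroupContext.pop_context_managed_task_group()          # -> cls.pop()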
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index b0a912d9f04..f83a4b7ec25 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -25,7 +25,6 @@ dependencies = [
     "attrs>=24.2.0",
     "google-re2>=1.1.20240702",
     "methodtools>=0.4.7",
-    "structlog>=24.4.0",
 ]
 
 [build-system]
diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py
index 2a3e01b64bc..baf7c85baa9 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -16,6 +16,30 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
 
-def hello() -> str:
-    return "Hello from task-sdk!"
+__all__ = ["DAG", "BaseOperator", "TaskGroup"]
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions.baseoperator import BaseOperator as BaseOperator
+    from airflow.sdk.definitions.dag import DAG as DAG
+    from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup
+
+__lazy_imports: dict[str, str] = {
+    "DAG": ".definitions.dag",
+    "BaseOperator": ".definitions.baseoperator",
+    "TaskGroup": ".definitions.taskgroup",
+}
+
+
+def __getattr__(name: str):
+    if module_path := __lazy_imports.get(name):
+        import importlib
+
+        mod = importlib.import_module(module_path, __name__)
+        val = getattr(mod, name)
+
+        # Store for next time
+        globals()[name] = val
+        return val
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
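
The `__getattr__` hook is the PEP 562 module-level lazy-import pattern:
importing `airflow.sdk` stays cheap, and each definitions submodule is only
imported on first attribute access. Illustrative use:

    import airflow.sdk

    dag_cls = airflow.sdk.DAG           # first access imports .definitions.dag
    assert dag_cls is airflow.sdk.DAG   # later accesses hit the cached global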
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index c466f158cc9..d52802db21f 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
     class ParamsDict: ...
 
     from airflow.sdk.definitions.dag import DAG
-    from airflow.sdk.definitions.task_group import TaskGroup
+    from airflow.sdk.definitions.taskgroup import TaskGroup
 
 
 # TODO: Task-SDK
diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
index fb160e06b35..d97339fc265 100644
--- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
@@ -22,7 +22,7 @@ from types import ModuleType
 from typing import Generic, TypeVar
 
 from airflow.sdk.definitions.dag import DAG
-from airflow.sdk.definitions.task_group import TaskGroup
+from airflow.sdk.definitions.taskgroup import TaskGroup
 
 T = TypeVar("T")
 
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 8a2799d0f51..f80d8f7f71b 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -58,9 +58,6 @@ from airflow.exceptions import (
     TaskNotFound,
 )
 from airflow.models.param import DagParam
-from airflow.models.taskinstance import (
-    Context,
-)
 from airflow.sdk.definitions.abstractoperator import AbstractOperator
 from airflow.sdk.definitions.baseoperator import BaseOperator
 from airflow.stats import Stats
@@ -73,7 +70,7 @@ from airflow.utils.types import NOTSET, EdgeInfoType
 if TYPE_CHECKING:
     from airflow.decorators import TaskDecoratorCollection
     from airflow.models.operator import Operator
-    from airflow.utils.task_group import TaskGroup
+    from airflow.utils.taskgroup import TaskGroup
 
 log = logging.getLogger(__name__)
 
@@ -82,6 +79,11 @@ ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"]
 
 TAG_MAX_LEN = 100
 
+
+# TODO: Task-SDK
+class Context: ...
+
+
 DagStateChangeCallback = Callable[[Context], None]
 ScheduleInterval = Union[None, str, timedelta, relativedelta]
 
@@ -93,13 +95,6 @@ ScheduleArg = Union[
 ]
 
 
-# Defined as a function so that we can avoid import cycles
-def _create_root_group(dag: DAG) -> TaskGroup:
-    from airflow.sdk.definitions.task_group import TaskGroup
-
-    return TaskGroup(group_id=None, dag=dag)
-
-
 _DAG_HASH_ATTRS = frozenset(
     {
         "dag_id",
@@ -299,19 +294,27 @@ class DAG:
     owner_links: dict[str, str] | None = None
     auto_register: bool = True
     fail_stop: bool = False
-    dag_display_name: str | None = None
+    dag_display_name: str = attrs.field()
 
     task_dict: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
 
-    task_group: TaskGroup = attrs.field(
-        default=attrs.Factory(_create_root_group, takes_self=True), on_setattr=attrs.setters.NO_OP
-    )
+    task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.NO_OP)
+
+    @dag_display_name.default
+    def _default_dag_display_name(self) -> str:
+        return self.dag_id
+
+    @task_group.default
+    def _default_task_group(self) -> TaskGroup:
+        from airflow.sdk.definitions.taskgroup import TaskGroup
+
+        return TaskGroup.create_root(dag=self)
 
     def __repr__(self):
         return f"<DAG: {self.dag_id}>"
 
     def __eq__(self, other: Self | Any):
-        if type(self) != type(other):
+        if not isinstance(other, type(self)):
             return NotImplemented
         return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS)
 
@@ -337,7 +340,7 @@ class DAG:
         return hash(tuple(hash_components))
 
     def __enter__(self) -> Self:
-        from .contextmanager import DagContext
+        from airflow.sdk.definitions.contextmanager import DagContext
 
         DagContext.push(self)
         return self
@@ -521,7 +524,7 @@ class DAG:
 
         Deprecated in place of ``task_group.topological_sort``
         """
-        from airflow.utils.task_group import TaskGroup
+        from airflow.utils.taskgroup import TaskGroup
 
         def nested_topo(group):
             for node in group.topological_sort():
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py
index 0abb0a83c16..1d980284677 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+import logging
+import re
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable, Sequence
 from datetime import datetime
@@ -25,20 +27,19 @@ from typing import TYPE_CHECKING
 import attrs
 import methodtools
 import re2
-import structlog
 
 from airflow.sdk.definitions.mixins import DependencyMixin
 
 if TYPE_CHECKING:
     from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.definitions.edges import EdgeModifier
-    from airflow.sdk.definitions.task_group import TaskGroup
+    from airflow.sdk.definitions.taskgroup import TaskGroup
     from airflow.sdk.types import Logger
 
 
 KEY_REGEX = re2.compile(r"^[\w.-]+$")
 GROUP_KEY_REGEX = re2.compile(r"^[\w-]+$")
-CAMELCASE_TO_SNAKE_CASE_REGEX = re2.compile(r"(?!^)([A-Z]+)")
+CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")
 
 
 def validate_key(k: str, max_length: int = 250):
@@ -98,7 +99,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
     def log(self) -> Logger:
         typ = type(self)
         name = f"{typ.__module__}.{typ.__qualname__}"
-        return structlog.get_logger(logger_name=name)
+        return logging.getLogger(name)
 
     def _set_relatives(
         self,
diff --git a/task_sdk/src/airflow/sdk/definitions/task_group.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
similarity index 100%
rename from task_sdk/src/airflow/sdk/definitions/task_group.py
rename to task_sdk/src/airflow/sdk/definitions/taskgroup.py
diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py
index 172ff733a14..505ee4cb191 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -41,9 +41,9 @@ NOTSET = ArgNotSet()
 
 
 if TYPE_CHECKING:
-    import structlog
+    import logging
 
-    Logger = structlog.typing.FilteringBoundLogger
+    Logger = logging.Logger
 else:
 
     class Logger: ...
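
With structlog removed from the SDK, both the `DAGNode.log` property (node.py
above) and this `Logger` alias resolve to stdlib logging, so task-sdk objects
log like any other Airflow component. A minimal sketch:

    import logging

    # Same logger-name shape the DAGNode.log property builds from the class.
    log = logging.getLogger("airflow.sdk.definitions.dag.DAG")
    log.info("plain stdlib logger; no structlog dependency")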
diff --git a/task_sdk/tests/test_hello.py b/task_sdk/tests/test_hello.py
deleted file mode 100644
index 62cfdc069ca..00000000000
--- a/task_sdk/tests/test_hello.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from airflow.sdk import hello
-
-
-def test_hello():
-    assert hello() == "Hello from task-sdk!"
diff --git a/uv.lock b/uv.lock
index d3508fac220..3a3c2db04b7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1369,7 +1369,6 @@ dependencies = [
     { name = "attrs" },
     { name = "google-re2" },
     { name = "methodtools" },
-    { name = "structlog" },
 ]
 
 [package.dev-dependencies]
@@ -1384,7 +1383,6 @@ requires-dist = [
     { name = "attrs", specifier = ">=24.2.0" },
     { name = "google-re2", specifier = ">=1.1.20240702" },
     { name = "methodtools", specifier = ">=0.4.7" },
-    { name = "structlog", specifier = ">=24.4.0" },
 ]
 
 [package.metadata.requires-dev]
@@ -3688,15 +3686,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b7/9c/93f7bc03ff03199074e81974cc148908ead60dcf189f68ba1761a0ee35cf/starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05", size = 71451 },
 ]
 
-[[package]]
-name = "structlog"
-version = "24.4.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/78/a3/e811a94ac3853826805253c906faa99219b79951c7d58605e89c79e65768/structlog-24.4.0.tar.gz", hash = "sha256:b27bfecede327a6d2da5fbc96bd859f114ecc398a6389d664f62085ee7ae6fc4", size = 1348634 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/bf/65/813fc133609ebcb1299be6a42e5aea99d6344afb35ccb43f67e7daaa3b92/structlog-24.4.0-py3-none-any.whl", hash = "sha256:597f61e80a91cc0749a9fd2a098ed76715a1c8a01f73e336b746504d1aad7610", size = 67180 },
-]
-
 [[package]]
 name = "tabulate"
 version = "0.9.0"
