This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/task-sdk-first-code by this push:
     new 50081fada27 fixup! Move over more of BaseOperator and DAG, along with their tests
50081fada27 is described below

commit 50081fada2795dce76aa0a6fb89ad62b97c9dd07
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Oct 16 12:56:20 2024 +0100

    fixup! Move over more of BaseOperator and DAG, along with their tests
---
 airflow/dag_processing/collection.py               |   5 +-
 airflow/decorators/base.py                         |   3 +-
 airflow/models/abstractoperator.py                 |  67 --
 airflow/models/baseoperator.py                     | 760 ++-------------------
 airflow/models/dag.py                              |  24 +-
 airflow/utils/edgemodifier.py                      |  10 +-
 airflow/utils/task_group.py                        |  11 +-
 pyproject.toml                                     |   1 +
 task_sdk/src/airflow/sdk/__init__.py               |   2 +
 .../airflow/sdk/definitions/abstractoperator.py    |  84 ++-
 .../src/airflow/sdk/definitions/baseoperator.py    |  61 +-
 .../src/airflow/sdk/definitions/contextmanager.py  |  10 +-
 task_sdk/src/airflow/sdk/definitions/dag.py        |  50 +-
 task_sdk/src/airflow/sdk/definitions/node.py       |  25 +-
 task_sdk/src/airflow/sdk/definitions/taskgroup.py  |  22 +-
 task_sdk/tests/defintions/test_baseoperator.py     |  36 +-
 tests/models/test_baseoperator.py                  |  41 +-
 17 files changed, 293 insertions(+), 919 deletions(-)

diff --git a/airflow/dag_processing/collection.py 
b/airflow/dag_processing/collection.py
index c8ce5dc873a..163ddf01c02 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -198,7 +198,10 @@ class DagModelOperation(NamedTuple):
             dm.has_import_errors = False
             dm.last_parsed_time = utcnow()
             dm.default_view = dag.default_view
-            dm._dag_display_property_value = dag._dag_display_property_value
+            if hasattr(dag, "_dag_display_property_value"):
+                dm._dag_display_property_value = 
dag._dag_display_property_value
+            else:
+                dm._dag_display_property_value = dag.dag_display_name
             dm.description = dag.description
             dm.max_active_tasks = dag.max_active_tasks
             dm.max_active_runs = dag.max_active_runs
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index bb9602d50c1..e7a21e2919b 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -59,6 +59,7 @@ from airflow.models.expandinput import (
 from airflow.models.mappedoperator import MappedOperator, 
ensure_xcomarg_return_value
 from airflow.models.pool import Pool
 from airflow.models.xcom_arg import XComArg
+from airflow.sdk.definitions.baseoperator import BaseOperator as 
TaskSDKBaseOperator
 from airflow.typing_compat import ParamSpec, Protocol
 from airflow.utils import timezone
 from airflow.utils.context import KNOWN_CONTEXT_KEYS
@@ -442,7 +443,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, 
FReturn, OperatorSubcla
             "is_teardown": self.is_teardown,
             "on_failure_fail_dagrun": self.on_failure_fail_dagrun,
         }
-        base_signature = inspect.signature(BaseOperator)
+        base_signature = inspect.signature(TaskSDKBaseOperator)
         ignore = {
             "default_args",  # This is target we are working on now.
             "kwargs",  # A common name for a keyword argument.
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index 5e5d13d5dc2..0c962307519 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -33,7 +33,6 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.template.templater import Templater
 from airflow.utils.context import Context
 from airflow.utils.db import exists_query
-from airflow.utils.log.secrets_masker import redact
 from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State, TaskInstanceState
@@ -719,72 +718,6 @@ class AbstractOperator(Templater, DAGNode):
         """
         raise NotImplementedError()
 
-    def _render(self, template, context, dag: DAG | None = None):
-        if dag is None:
-            dag = self.get_dag()
-        return super()._render(template, context, dag=dag)
-
-    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
-        """Get the template environment for rendering templates."""
-        if dag is None:
-            dag = self.get_dag()
-        return super().get_template_env(dag=dag)
-
-    def _do_render_template_fields(
-        self,
-        parent: Any,
-        template_fields: Iterable[str],
-        context: Context,
-        jinja_env: jinja2.Environment,
-        seen_oids: set[int],
-    ) -> None:
-        """Override the base to use custom error logging."""
-        for attr_name in template_fields:
-            try:
-                value = getattr(parent, attr_name)
-            except AttributeError:
-                raise AttributeError(
-                    f"{attr_name!r} is configured as a template field "
-                    f"but {parent.task_type} does not have this attribute."
-                )
-            try:
-                if not value:
-                    continue
-            except Exception:
-                # This may happen if the templated field points to a class 
which does not support `__bool__`,
-                # such as Pandas DataFrames:
-                # 
https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
-                self.log.info(
-                    "Unable to check if the value of type '%s' is False for 
task '%s', field '%s'.",
-                    type(value).__name__,
-                    self.task_id,
-                    attr_name,
-                )
-                # We may still want to render custom classes which do not 
support __bool__
-                pass
-
-            try:
-                if callable(value):
-                    rendered_content = value(context=context, 
jinja_env=jinja_env)
-                else:
-                    rendered_content = self.render_template(
-                        value,
-                        context,
-                        jinja_env,
-                        seen_oids,
-                    )
-            except Exception:
-                value_masked = redact(name=attr_name, value=value)
-                self.log.exception(
-                    "Exception rendering Jinja template for task '%s', field 
'%s'. Template: %r",
-                    self.task_id,
-                    attr_name,
-                    value_masked,
-                )
-                raise
-            else:
-                setattr(parent, attr_name, rendered_content)
-
     def __enter__(self):
         if not self.is_setup and not self.is_teardown:
             raise AirflowException("Only setup/teardown tasks can be used as 
context managers.")
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 39d56187371..2a9fed1503a 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -23,17 +23,14 @@ Base operator for all operators.
 
 from __future__ import annotations
 
-import abc
 import collections.abc
 import contextlib
 import copy
 import functools
-import inspect
 import logging
 import sys
-import warnings
 from datetime import datetime, timedelta
-from functools import total_ordering, wraps
+from functools import wraps
 from threading import local
 from types import FunctionType
 from typing import (
@@ -46,10 +43,8 @@ from typing import (
     Sequence,
     TypeVar,
     Union,
-    cast,
 )
 
-import attr
 import pendulum
 from sqlalchemy import select
 from sqlalchemy.orm.exc import NoResultFound
@@ -57,7 +52,6 @@ from sqlalchemy.orm.exc import NoResultFound
 from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
-    FailStopDagInvalidTriggerRule,
     TaskDeferralError,
     TaskDeferred,
 )
@@ -83,8 +77,15 @@ from airflow.models.param import ParamsDict
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance, clear_task_instances
 from airflow.models.taskmixin import DependencyMixin
+from airflow.sdk.definitions.baseoperator import BaseOperator as 
TaskSDKBaseOperator
+
+# Keeping this file at all is a temp thing as we migrate the repo to the task 
sdk as the base, but to keep
+# main working and useful for others to develop against we use the TaskSDK 
here but keep this file around
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier
+from airflow.sdk.definitions.mixins import DependencyMixin as 
TaskSDKDependencyMixin
 from airflow.serialization.enums import DagAttributeTypes
-from airflow.task.priority_strategy import PriorityWeightStrategy, 
validate_and_load_priority_weight_strategy
+from airflow.task.priority_strategy import PriorityWeightStrategy
 from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
 from airflow.ti_deps.deps.not_previously_skipped_dep import 
NotPreviouslySkippedDep
@@ -92,15 +93,11 @@ from airflow.ti_deps.deps.prev_dagrun_dep import 
PrevDagrunDep
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils import timezone
 from airflow.utils.context import Context, context_get_outlet_events
-from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.edgemodifier import EdgeModifier
-from airflow.utils.helpers import validate_instance_args, validate_key
 from airflow.utils.operator_helpers import ExecutionCallableRunner
 from airflow.utils.operator_resources import Resources
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.setup_teardown import SetupTeardownContext
-from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import NOTSET, AttributeRemoved, DagRunTriggeredByType
+from airflow.utils.types import NOTSET, DagRunTriggeredByType
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
 if TYPE_CHECKING:
@@ -113,7 +110,6 @@ if TYPE_CHECKING:
     from airflow.models.baseoperatorlink import BaseOperatorLink
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
-    from airflow.models.xcom_arg import XComArg
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
     from airflow.triggers.base import BaseTrigger, StartTriggerArgs
     from airflow.utils.task_group import TaskGroup
@@ -419,165 +415,16 @@ class ExecutorSafeguard:
         return wrapper
 
 
-class BaseOperatorMeta(abc.ABCMeta):
-    """Metaclass of BaseOperator."""
-
-    @classmethod
-    def _apply_defaults(cls, func: T) -> T:
-        """
-        Look for an argument named "default_args", and fill the unspecified 
arguments from it.
-
-        Since python2.* isn't clear about which arguments are missing when
-        calling a function, and that this can be quite confusing with 
multi-level
-        inheritance and argument defaults, this decorator also alerts with
-        specific information about the missing arguments.
-        """
-        # Cache inspect.signature for the wrapper closure to avoid calling it
-        # at every decorated invocation. This is separate sig_cache created
-        # per decoration, i.e. each function decorated using apply_defaults 
will
-        # have a different sig_cache.
-        sig_cache = inspect.signature(func)
-        non_variadic_params = {
-            name: param
-            for (name, param) in sig_cache.parameters.items()
-            if param.name != "self" and param.kind not in 
(param.VAR_POSITIONAL, param.VAR_KEYWORD)
-        }
-        non_optional_args = {
-            name
-            for name, param in non_variadic_params.items()
-            if param.default == param.empty and name != "task_id"
-        }
-
-        fixup_decorator_warning_stack(func)
-
-        @wraps(func)
-        def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> 
Any:
-            from airflow.sdk.definitions.contextmanager import DagContext, 
TaskGroupContext
-
-            if args:
-                raise AirflowException("Use keyword arguments when 
initializing operators")
-
-            instantiated_from_mapped = kwargs.pop(
-                "_airflow_from_mapped",
-                getattr(self, "_BaseOperator__from_mapped", False),
-            )
-
-            dag: DAG | None = kwargs.get("dag") or DagContext.get_current()
-            task_group: TaskGroup | None = kwargs.get("task_group")
-            if dag and not task_group:
-                task_group = TaskGroupContext.get_current(dag)
-
-            default_args, merged_params = get_merged_defaults(
-                dag=dag,
-                task_group=task_group,
-                task_params=kwargs.pop("params", None),
-                task_default_args=kwargs.pop("default_args", None),
-            )
-
-            for arg in sig_cache.parameters:
-                if arg not in kwargs and arg in default_args:
-                    kwargs[arg] = default_args[arg]
-
-            missing_args = non_optional_args.difference(kwargs)
-            if len(missing_args) == 1:
-                raise AirflowException(f"missing keyword argument 
{missing_args.pop()!r}")
-            elif missing_args:
-                display = ", ".join(repr(a) for a in sorted(missing_args))
-                raise AirflowException(f"missing keyword arguments {display}")
-
-            if merged_params:
-                kwargs["params"] = merged_params
-
-            hook = getattr(self, "_hook_apply_defaults", None)
-            if hook:
-                args, kwargs = hook(**kwargs, default_args=default_args)
-                default_args = kwargs.pop("default_args", {})
-
-            if not hasattr(self, "_BaseOperator__init_kwargs"):
-                self._BaseOperator__init_kwargs = {}
-            self._BaseOperator__from_mapped = instantiated_from_mapped
-
-            result = func(self, **kwargs, default_args=default_args)
-
-            # Store the args passed to init -- we need them to support 
task.map serialization!
-            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore
-
-            # Set upstream task defined by XComArgs passed to template fields 
of the operator.
-            # BUT: only do this _ONCE_, not once for each class in the 
hierarchy
-            if not instantiated_from_mapped and func == 
self.__init__.__wrapped__:  # type: ignore[misc]
-                self.set_xcomargs_dependencies()
-                # Mark instance as instantiated.
-                self._BaseOperator__instantiated = True
-
-            return result
-
-        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
-        apply_defaults.__param_names = set(non_variadic_params)  # type: ignore
-
-        return cast(T, apply_defaults)
-
+# TODO: Task-SDK - temporarily extend the metaclass to add in the 
ExecutorSafeguard.
+class BaseOperatorMeta(type(TaskSDKBaseOperator)):
     def __new__(cls, name, bases, namespace, **kwargs):
         execute_method = namespace.get("execute")
         if callable(execute_method) and not getattr(execute_method, 
"__isabstractmethod__", False):
             namespace["execute"] = 
ExecutorSafeguard().decorator(execute_method)
-        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
-        with contextlib.suppress(KeyError):
-            # Update the partial descriptor with the class method, so it calls 
the actual function
-            # (but let subclasses override it if they need to)
-            partial_desc = vars(new_cls)["partial"]
-            if isinstance(partial_desc, _PartialDescriptor):
-                partial_desc.class_method = classmethod(partial)
-
-        # We patch `__init__` only if the class defines it.
-        if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__:
-            new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
-
-        return new_cls
-
-
-# TODO: The following mapping is used to validate that the arguments passed to 
the BaseOperator are of the
-#  correct type. This is a temporary solution until we find a more 
sophisticated method for argument
-#  validation. One potential method is to use `get_type_hints` from the typing 
module. However, this is not
-#  fully compatible with future annotations for Python versions below 3.10. 
Once we require a minimum Python
-#  version that supports `get_type_hints` effectively or find a better 
approach, we can replace this
-#  manual type-checking method.
-BASEOPERATOR_ARGS_EXPECTED_TYPES = {
-    "task_id": str,
-    "email": (str, Iterable),
-    "email_on_retry": bool,
-    "email_on_failure": bool,
-    "retries": int,
-    "retry_exponential_backoff": bool,
-    "depends_on_past": bool,
-    "ignore_first_depends_on_past": bool,
-    "wait_for_past_depends_before_skipping": bool,
-    "wait_for_downstream": bool,
-    "priority_weight": int,
-    "queue": str,
-    "pool": str,
-    "pool_slots": int,
-    "trigger_rule": str,
-    "run_as_user": str,
-    "task_concurrency": int,
-    "map_index_template": str,
-    "max_active_tis_per_dag": int,
-    "max_active_tis_per_dagrun": int,
-    "executor": str,
-    "do_xcom_push": bool,
-    "multiple_outputs": bool,
-    "doc": str,
-    "doc_md": str,
-    "doc_json": str,
-    "doc_yaml": str,
-    "doc_rst": str,
-    "task_display_name": str,
-    "logger_name": str,
-    "allow_nested_operators": bool,
-}
+        return super().__new__(cls, name, bases, namespace, **kwargs)
 
 
-@total_ordering
-class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
+class BaseOperator(TaskSDKBaseOperator, AbstractOperator, 
metaclass=BaseOperatorMeta):
     r"""
     Abstract base class for all operators.
 
@@ -782,17 +629,24 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
                 hello_world_task.execute(context)
     """
 
-    # Implementing Operator.
-    template_fields: Sequence[str] = ()
-    template_ext: Sequence[str] = ()
+    start_trigger_args: StartTriggerArgs | None = None
+    start_from_trigger: bool = False
 
-    template_fields_renderers: dict[str, str] = {}
+    on_execute_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    on_skipped_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
 
-    # Defines the color in the UI
-    ui_color: str = "#fff"
-    ui_fgcolor: str = "#000"
+    def __init__(self, pre_execute=None, post_execute=None, **kwargs):
+        if start_date := kwargs.get("start_date", None):
+            kwargs["start_date"] = timezone.convert_to_utc(start_date)
 
-    pool: str = ""
+        if end_date := kwargs.get("end_date", None):
+            kwargs["end_date"] = timezone.convert_to_utc(end_date)
+        super().__init__(**kwargs)
+        self._pre_execute_hook = pre_execute
+        self._post_execute_hook = post_execute
 
     # base list which includes all the attrs that don't need deep copy.
     _base_operator_shallow_copy_attrs: tuple[str, ...] = (
@@ -807,368 +661,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     # Defines the operator level extra links
     operator_extra_links: Collection[BaseOperatorLink] = ()
 
-    # The _serialized_fields are lazily loaded when get_serialized_fields() 
method is called
-    __serialized_fields: frozenset[str] | None = None
-
     partial: Callable[..., OperatorPartial] = _PartialDescriptor()  # type: 
ignore
 
-    _comps = {
-        "task_id",
-        "dag_id",
-        "owner",
-        "email",
-        "email_on_retry",
-        "retry_delay",
-        "retry_exponential_backoff",
-        "max_retry_delay",
-        "start_date",
-        "end_date",
-        "depends_on_past",
-        "wait_for_downstream",
-        "priority_weight",
-        "sla",
-        "execution_timeout",
-        "on_execute_callback",
-        "on_failure_callback",
-        "on_success_callback",
-        "on_retry_callback",
-        "on_skipped_callback",
-        "do_xcom_push",
-        "multiple_outputs",
-        "allow_nested_operators",
-        "executor",
-    }
-
-    # Defines if the operator supports lineage without manual definitions
-    supports_lineage = False
-
-    # If True then the class constructor was called
-    __instantiated = False
-    # List of args as passed to `init()`, after apply_defaults() has been 
updated. Used to "recreate" the task
-    # when mapping
-    __init_kwargs: dict[str, Any]
-
-    # Set to True before calling execute method
-    _lock_for_execution = False
-
-    _dag: DAG | None = None
-    task_group: TaskGroup | None = None
-
-    start_date: pendulum.DateTime | None = None
-    end_date: pendulum.DateTime | None = None
-
-    # Set to True for an operator instantiated by a mapped operator.
-    __from_mapped = False
-
-    start_trigger_args: StartTriggerArgs | None = None
-    start_from_trigger: bool = False
-
-    def __init__(
-        self,
-        task_id: str,
-        owner: str = DEFAULT_OWNER,
-        email: str | Iterable[str] | None = None,
-        email_on_retry: bool = conf.getboolean("email", 
"default_email_on_retry", fallback=True),
-        email_on_failure: bool = conf.getboolean("email", 
"default_email_on_failure", fallback=True),
-        retries: int | None = DEFAULT_RETRIES,
-        retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
-        retry_exponential_backoff: bool = False,
-        max_retry_delay: timedelta | float | None = None,
-        start_date: datetime | None = None,
-        end_date: datetime | None = None,
-        depends_on_past: bool = False,
-        ignore_first_depends_on_past: bool = 
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
-        wait_for_past_depends_before_skipping: bool = 
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
-        wait_for_downstream: bool = False,
-        dag: DAG | None = None,
-        params: collections.abc.MutableMapping | None = None,
-        default_args: dict | None = None,
-        priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
-        weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
-        queue: str = DEFAULT_QUEUE,
-        pool: str | None = None,
-        pool_slots: int = DEFAULT_POOL_SLOTS,
-        sla: timedelta | None = None,
-        execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
-        on_execute_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-        on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-        on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-        on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-        on_skipped_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-        pre_execute: TaskPreExecuteHook | None = None,
-        post_execute: TaskPostExecuteHook | None = None,
-        trigger_rule: str = DEFAULT_TRIGGER_RULE,
-        resources: dict[str, Any] | None = None,
-        run_as_user: str | None = None,
-        map_index_template: str | None = None,
-        max_active_tis_per_dag: int | None = None,
-        max_active_tis_per_dagrun: int | None = None,
-        executor: str | None = None,
-        executor_config: dict | None = None,
-        do_xcom_push: bool = True,
-        multiple_outputs: bool = False,
-        inlets: Any | None = None,
-        outlets: Any | None = None,
-        task_group: TaskGroup | None = None,
-        doc: str | None = None,
-        doc_md: str | None = None,
-        doc_json: str | None = None,
-        doc_yaml: str | None = None,
-        doc_rst: str | None = None,
-        task_display_name: str | None = None,
-        logger_name: str | None = None,
-        allow_nested_operators: bool = True,
-        **kwargs,
-    ):
-        from airflow.sdk.definitions.contextmanager import DagContext, 
TaskGroupContext
-
-        self.__init_kwargs = {}
-
-        super().__init__()
-
-        kwargs.pop("_airflow_mapped_validation_only", None)
-        if kwargs:
-            raise AirflowException(
-                f"Invalid arguments were passed to {self.__class__.__name__} 
(task_id: {task_id}). "
-                f"Invalid arguments were:\n**kwargs: {kwargs}",
-            )
-        validate_key(task_id)
-
-        dag = dag or DagContext.get_current()
-        task_group = task_group or TaskGroupContext.get_current(dag)
-
-        self.task_id = task_group.child_id(task_id) if task_group else task_id
-        if not self.__from_mapped and task_group:
-            task_group.add(self)
-
-        self.owner = owner
-        self.email = email
-        self.email_on_retry = email_on_retry
-        self.email_on_failure = email_on_failure
-
-        if execution_timeout is not None and not isinstance(execution_timeout, 
timedelta):
-            raise ValueError(
-                f"execution_timeout must be timedelta object but passed as 
type: {type(execution_timeout)}"
-            )
-        self.execution_timeout = execution_timeout
-
-        self.on_execute_callback = on_execute_callback
-        self.on_failure_callback = on_failure_callback
-        self.on_success_callback = on_success_callback
-        self.on_retry_callback = on_retry_callback
-        self.on_skipped_callback = on_skipped_callback
-        self._pre_execute_hook = pre_execute
-        self._post_execute_hook = post_execute
-
-        if start_date and not isinstance(start_date, datetime):
-            self.log.warning("start_date for %s isn't datetime.datetime", self)
-        elif start_date:
-            self.start_date = timezone.convert_to_utc(start_date)
-
-        if end_date:
-            self.end_date = timezone.convert_to_utc(end_date)
-
-        self.executor = executor
-        self.executor_config = executor_config or {}
-        self.run_as_user = run_as_user
-        self.retries = parse_retries(retries)
-        self.queue = queue
-        self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
-        self.pool_slots = pool_slots
-        if self.pool_slots < 1:
-            dag_str = f" in dag {dag.dag_id}" if dag else ""
-            raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot 
be less than 1")
-
-        if sla:
-            self.log.warning(
-                "The SLA feature is removed in Airflow 3.0, to be replaced 
with a new implementation in 3.1"
-            )
-
-        if not TriggerRule.is_valid(trigger_rule):
-            raise AirflowException(
-                f"The trigger_rule must be one of 
{TriggerRule.all_triggers()},"
-                f"'{dag.dag_id if dag else ''}.{task_id}'; received 
'{trigger_rule}'."
-            )
-
-        self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
-        FailStopDagInvalidTriggerRule.check(dag=dag, 
trigger_rule=self.trigger_rule)
-
-        self.depends_on_past: bool = depends_on_past
-        self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
-        self.wait_for_past_depends_before_skipping: bool = 
wait_for_past_depends_before_skipping
-        self.wait_for_downstream: bool = wait_for_downstream
-        if wait_for_downstream:
-            self.depends_on_past = True
-
-        self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
-        self.retry_exponential_backoff = retry_exponential_backoff
-        self.max_retry_delay = (
-            max_retry_delay
-            if max_retry_delay is None
-            else coerce_timedelta(max_retry_delay, key="max_retry_delay")
-        )
-
-        # At execution_time this becomes a normal dict
-        self.params: ParamsDict | dict = ParamsDict(params)
-        if priority_weight is not None and not isinstance(priority_weight, 
int):
-            raise AirflowException(
-                f"`priority_weight` for task '{self.task_id}' only accepts 
integers, "
-                f"received '{type(priority_weight)}'."
-            )
-        self.priority_weight = priority_weight
-        self.weight_rule = 
validate_and_load_priority_weight_strategy(weight_rule)
-        self.resources = coerce_resources(resources)
-        self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
-        self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
-        self.do_xcom_push: bool = do_xcom_push
-        self.map_index_template: str | None = map_index_template
-        self.multiple_outputs: bool = multiple_outputs
-
-        self.doc_md = doc_md
-        self.doc_json = doc_json
-        self.doc_yaml = doc_yaml
-        self.doc_rst = doc_rst
-        self.doc = doc
-        # Populate the display field only if provided and different from task 
id
-        self._task_display_property_value = (
-            task_display_name if task_display_name and task_display_name != 
task_id else None
-        )
-
-        self.upstream_task_ids: set[str] = set()
-        self.downstream_task_ids: set[str] = set()
-
-        if dag:
-            self.dag = dag
-
-        self._log_config_logger_name = "airflow.task.operators"
-        self._logger_name = logger_name
-        self.allow_nested_operators: bool = allow_nested_operators
-
-        # Lineage
-        self.inlets: list = []
-        self.outlets: list = []
-
-        if inlets:
-            self.inlets = (
-                inlets
-                if isinstance(inlets, list)
-                else [
-                    inlets,
-                ]
-            )
-
-        if outlets:
-            self.outlets = (
-                outlets
-                if isinstance(outlets, list)
-                else [
-                    outlets,
-                ]
-            )
-
-        if isinstance(self.template_fields, str):
-            warnings.warn(
-                f"The `template_fields` value for {self.task_type} is a string 
"
-                "but should be a list or tuple of string. Wrapping it in a 
list for execution. "
-                f"Please update {self.task_type} accordingly.",
-                UserWarning,
-                stacklevel=2,
-            )
-            self.template_fields = [self.template_fields]
-
-        self._is_setup = False
-        self._is_teardown = False
-        if SetupTeardownContext.active:
-            SetupTeardownContext.update_context_map(self)
-
-        validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
-
-    def __eq__(self, other):
-        if type(self) is type(other):
-            # Use getattr() instead of __dict__ as __dict__ doesn't return
-            # correct values for properties.
-            return all(getattr(self, c, None) == getattr(other, c, None) for c 
in self._comps)
-        return False
-
-    def __ne__(self, other):
-        return not self == other
-
-    def __hash__(self):
-        hash_components = [type(self)]
-        for component in self._comps:
-            val = getattr(self, component, None)
-            try:
-                hash(val)
-                hash_components.append(val)
-            except TypeError:
-                hash_components.append(repr(val))
-        return hash(tuple(hash_components))
-
-    # including lineage information
-    def __or__(self, other):
-        """
-        Return [This Operator] | [Operator].
-
-        The inlets of other will be set to pick up the outlets from this 
operator.
-        Other will be set as a downstream task of this operator.
-        """
-        if isinstance(other, BaseOperator):
-            if not self.outlets and not self.supports_lineage:
-                raise ValueError("No outlets defined for this operator")
-            other.add_inlets([self.task_id])
-            self.set_downstream(other)
-        else:
-            raise TypeError(f"Right hand side ({other}) is not an Operator")
-
-        return self
-
-    # /Composing Operators ---------------------------------------------
-
-    def __gt__(self, other):
-        """
-        Return [Operator] > [Outlet].
-
-        If other is an attr annotated object it is set as an outlet of this 
Operator.
-        """
-        if not isinstance(other, Iterable):
-            other = [other]
-
-        for obj in other:
-            if not attr.has(obj):
-                raise TypeError(f"Left hand side ({obj}) is not an outlet")
-        self.add_outlets(other)
-
-        return self
-
-    def __lt__(self, other):
-        """
-        Return [Inlet] > [Operator] or [Operator] < [Inlet].
-
-        If other is an attr annotated object it is set as an inlet to this 
operator.
-        """
-        if not isinstance(other, Iterable):
-            other = [other]
-
-        for obj in other:
-            if not attr.has(obj):
-                raise TypeError(f"{obj} cannot be an inlet")
-        self.add_inlets(other)
-
-        return self
-
-    def __setattr__(self, key, value):
-        super().__setattr__(key, value)
-        if self.__from_mapped or self._lock_for_execution:
-            return  # Skip any custom behavior for validation and during 
execute.
-        if key in self.__init_kwargs:
-            self.__init_kwargs[key] = value
-        if self.__instantiated and key in self.template_fields:
-            # Resolve upstreams set by assigning an XComArg after initializing
-            # an operator, example:
-            #   op = BashOperator()
-            #   op.bash_command = "sleep 1"
-            self.set_xcomargs_dependencies()
-
     def add_inlets(self, inlets: Iterable[Any]):
         """Set inlets to this operator."""
         self.inlets.extend(inlets)
@@ -1193,55 +687,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         """
         return self.outlets
 
-    def get_dag(self) -> DAG | None:
-        return self._dag
-
-    @property  # type: ignore[override]
-    def dag(self) -> DAG:  # type: ignore[override]
-        """Returns the Operator's DAG if set, otherwise raises an error."""
-        if self._dag:
-            return self._dag
-        else:
-            raise AirflowException(f"Operator {self} has not been assigned to 
a DAG yet")
-
-    @dag.setter
-    def dag(self, dag: DAG | None):
-        """Operators can be assigned to one DAG, one time. Repeat assignments 
to that same DAG are ok."""
-        if dag is None:
-            self._dag = None
-            return
-
-        # if set to removed, then just set and exit
-        if self._dag.__class__ is AttributeRemoved:
-            self._dag = dag
-            return
-        # if setting to removed, then just set and exit
-        if dag.__class__ is AttributeRemoved:
-            self._dag = AttributeRemoved("_dag")  # type: ignore[assignment]
-            return
-
-        from airflow.models.dag import DAG
-
-        if not isinstance(dag, DAG):
-            raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
-        elif self.has_dag() and self.dag is not dag:
-            raise AirflowException(f"The DAG assigned to {self} can not be 
changed.")
-
-        if self.__from_mapped:
-            pass  # Don't add to DAG -- the mapped task takes the place.
-        elif dag.task_dict.get(self.task_id) is not self:
-            dag.add_task(self)
-
-        self._dag = dag
-
-    @property
-    def task_display_name(self) -> str:
-        return self._task_display_property_value or self.task_id
-
-    def has_dag(self):
-        """Return True if the Operator has been assigned to a DAG."""
-        return self._dag is not None
-
     deps: frozenset[BaseTIDep] = frozenset(
         {
             NotInRetryPeriodDep(),
@@ -1263,33 +708,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         other._lock_for_execution = True
         return other
 
-    def set_xcomargs_dependencies(self) -> None:
-        """
-        Resolve upstream dependencies of a task.
-
-        In this way passing an ``XComArg`` as value for a template field
-        will result in creating upstream relation between two tasks.
-
-        **Example**: ::
-
-            with DAG(...):
-                generate_content = 
GenerateContentOperator(task_id="generate_content")
-                send_email = EmailOperator(..., 
html_content=generate_content.output)
-
-            # This is equivalent to
-            with DAG(...):
-                generate_content = 
GenerateContentOperator(task_id="generate_content")
-                send_email = EmailOperator(..., html_content="{{ 
task_instance.xcom_pull('generate_content') }}")
-                generate_content >> send_email
-
-        """
-        from airflow.models.xcom_arg import XComArg
-
-        for field in self.template_fields:
-            if hasattr(self, field):
-                arg = getattr(self, field)
-                XComArg.apply_upstream_relationship(self, arg)
-
     @prepare_lineage
     def pre_execute(self, context: Any):
         """Execute right before self.execute() is called."""
@@ -1326,14 +744,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             logger=self.log,
         ).run(context, result)
 
-    def on_kill(self) -> None:
-        """
-        Override this method to clean up subprocesses when a task instance 
gets killed.
-
-        Any use of the threading, subprocess or multiprocessing module within 
an
-        operator needs to be cleaned up, or it will leave ghost processes 
behind.
-        """
-
     def __deepcopy__(self, memo):
         # Hack sorting double chained task lists by task_id to avoid hitting
         # max_depth on deepcopy operations.
@@ -1518,43 +928,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         else:
             return self.downstream_list
 
-    def __repr__(self):
-        return f"<Task({self.task_type}): {self.task_id}>"
-
-    @property
-    def operator_class(self) -> type[BaseOperator]:  # type: ignore[override]
-        return self.__class__
-
-    @property
-    def task_type(self) -> str:
-        """@property: type of the task."""
-        return self.__class__.__name__
-
-    @property
-    def operator_name(self) -> str:
-        """@property: use a more friendly display name for the operator, if 
set."""
-        try:
-            return self.custom_operator_name  # type: ignore
-        except AttributeError:
-            return self.task_type
-
-    @property
-    def roots(self) -> list[BaseOperator]:
-        """Required by DAGNode."""
-        return [self]
-
-    @property
-    def leaves(self) -> list[BaseOperator]:
-        """Required by DAGNode."""
-        return [self]
-
-    @property
-    def output(self) -> XComArg:
-        """Returns reference to XCom pushed by current operator."""
-        from airflow.models.xcom_arg import XComArg
-
-        return XComArg(operator=self)
-
     @property
     def is_setup(self) -> bool:
         """
@@ -1655,68 +1028,10 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             session=session,
         )
 
-    @classmethod
-    def get_serialized_fields(cls):
-        """Stringified DAGs and operators contain exactly these fields."""
-        if not cls.__serialized_fields:
-            from airflow.sdk.definitions.contextmanager import DagContext
-
-            # make sure the following dummy task is not added to current active
-            # dag in context, otherwise, it will result in
-            # `RuntimeError: dictionary changed size during iteration`
-            # Exception in SerializedDAG.serialize_dag() call.
-            DagContext.push(None)
-            cls.__serialized_fields = frozenset(
-                vars(BaseOperator(task_id="test")).keys()
-                - {
-                    "upstream_task_ids",
-                    "default_args",
-                    "dag",
-                    "_dag",
-                    "label",
-                    "_BaseOperator__instantiated",
-                    "_BaseOperator__init_kwargs",
-                    "_BaseOperator__from_mapped",
-                    "_is_setup",
-                    "_is_teardown",
-                    "_on_failure_fail_dagrun",
-                }
-                | {  # Class level defaults need to be added to this list
-                    "start_date",
-                    "end_date",
-                    "_task_type",
-                    "_operator_name",
-                    "ui_color",
-                    "ui_fgcolor",
-                    "template_ext",
-                    "template_fields",
-                    "template_fields_renderers",
-                    "params",
-                    "is_setup",
-                    "is_teardown",
-                    "on_failure_fail_dagrun",
-                    "map_index_template",
-                    "start_trigger_args",
-                    "_needs_expansion",
-                    "start_from_trigger",
-                }
-            )
-            DagContext.pop()
-
-        return cls.__serialized_fields
-
     def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
         """Serialize; required by DAGNode."""
         return DagAttributeTypes.OP, self.task_id
 
-    @property
-    def inherits_from_empty_operator(self):
-        """Used to determine if an Operator is inherited from EmptyOperator."""
-        # This looks like `isinstance(self, EmptyOperator) would work, but 
this also
-        # needs to cope when `self` is a Serialized instance of a 
EmptyOperator or one
-        # of its subclasses (which don't inherit from anything but 
BaseOperator).
-        return getattr(self, "_is_empty", False)
-
     def defer(
         self,
         *,
@@ -1786,11 +1101,12 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         return self.start_trigger_args
 
 
-# TODO: Deprecate for Airflow 3.0
-Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
+# TODO: Task-SDK: remove before Airflow 3.0
+# Temp for migration to Task-SDK only
+_HasDependency = Union[DependencyMixin | TaskSDKDependencyMixin]
 
 
-def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
+def chain(*tasks: _HasDependency | Sequence[DependencyMixin]) -> None:
     r"""
     Given a number of tasks, builds a dependency chain.
 
@@ -1899,10 +1215,10 @@ def chain(*tasks: DependencyMixin | 
Sequence[DependencyMixin]) -> None:
     :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or 
TaskGroups to set dependencies
     """
     for up_task, down_task in zip(tasks, tasks[1:]):
-        if isinstance(up_task, DependencyMixin):
+        if isinstance(up_task, (DependencyMixin, TaskSDKDependencyMixin)):
             up_task.set_downstream(down_task)
             continue
-        if isinstance(down_task, DependencyMixin):
+        if isinstance(down_task, (DependencyMixin, TaskSDKDependencyMixin)):
             down_task.set_upstream(up_task)
             continue
         if not isinstance(up_task, Sequence) or not isinstance(down_task, 
Sequence):
@@ -2040,13 +1356,15 @@ def chain_linear(*elements: DependencyMixin | 
Sequence[DependencyMixin]):
     prev_elem = None
     deps_set = False
     for curr_elem in elements:
-        if isinstance(curr_elem, EdgeModifier):
+        if isinstance(curr_elem, (EdgeModifier, TaskSDKEdgeModifier)):
             raise ValueError("Labels are not supported by chain_linear")
         if prev_elem is not None:
             for task in prev_elem:
                 task >> curr_elem
                 if not deps_set:
                     deps_set = True
-        prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else 
curr_elem
+        prev_elem = (
+            [curr_elem] if isinstance(curr_elem, (DependencyMixin, 
TaskSDKDependencyMixin)) else curr_elem
+        )
     if not deps_set:
         raise ValueError("No dependencies were set. Did you forget to expand 
with `*`?")
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 290f42ca387..8cb0e07d137 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -26,7 +26,7 @@ import pickle
 import sys
 import time
 import traceback
-from collections import abc, defaultdict, deque
+from collections import abc, defaultdict
 from contextlib import ExitStack
 from datetime import datetime, timedelta
 from inspect import signature
@@ -44,6 +44,7 @@ from typing import (
     overload,
 )
 
+import attrs
 import jinja2
 import pendulum
 import re2
@@ -117,14 +118,13 @@ from airflow.utils import timezone
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.helpers import exactly_one
+from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, 
tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 if TYPE_CHECKING:
-    from types import ModuleType
-
     from pendulum.tz.timezone import FixedTimezone, Timezone
     from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session
@@ -345,7 +345,8 @@ DAG_ARGS_EXPECTED_TYPES = {
 
 
 @functools.total_ordering
-class DAG(TaskSDKDag):
[email protected](hash=False, repr=False)
+class DAG(TaskSDKDag, LoggingMixin):
     """
     A dag (directed acyclic graph) is a collection of tasks with directional 
dependencies.
 
@@ -459,6 +460,13 @@ class DAG(TaskSDKDag):
     :param dag_display_name: The display name of the DAG which appears on the 
UI.
     """
 
+    partial: bool = False
+    last_loaded: datetime | None = None
+    on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
+    on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
+
+    __hash__ = TaskSDKDag.__hash__
+
     def validate_executor_field(self):
         for task in self.tasks:
             if task.executor:
@@ -762,7 +770,7 @@ class DAG(TaskSDKDag):
 
     @property
     def default_view(self) -> str:
-        return self._default_view
+        return "grid"
 
     @property
     def pickle_id(self) -> int | None:
@@ -2566,15 +2574,11 @@ if STATICA_HACK:  # pragma: no cover
     """:sphinx-autoapi-skip:"""
 
 
-class DagContext(airflow.sdk.definitions.contextmanager.DagContext):
+class DagContext(airflow.sdk.definitions.contextmanager.DagContext, 
share_parent_context=True):
     """
     :meta private:
     """
 
-    _context_managed_dags: deque[DAG] = deque()
-    autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
-    current_autoregister_module_name: str | None = None
-
     @classmethod
     def push_context_managed_dag(cls, dag: DAG):
         cls.push(dag)
diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py
index a78e6c64999..e4345cb471b 100644
--- a/airflow/utils/edgemodifier.py
+++ b/airflow/utils/edgemodifier.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 from typing import Sequence
 
 from airflow.models.taskmixin import DAGNode, DependencyMixin
+from airflow.sdk.definitions.node import DAGNode as TaskSDKDagNode
+from airflow.sdk.definitions.taskgroup import TaskGroup as TaskSDKTaskGroup
 from airflow.utils.task_group import TaskGroup
 
 
@@ -68,7 +70,7 @@ class EdgeModifier(DependencyMixin):
         from airflow.models.xcom_arg import XComArg
 
         for node in self._make_list(nodes):
-            if isinstance(node, (TaskGroup, XComArg, DAGNode)):
+            if isinstance(node, (XComArg, DAGNode, TaskSDKDagNode)):
                 stream.append(node)
             else:
                 raise TypeError(
@@ -95,10 +97,10 @@ class EdgeModifier(DependencyMixin):
                     group_ids.add("root")
                 else:
                     group_ids.add(node.task_group.group_id)
-            elif isinstance(node, TaskGroup):
+            elif isinstance(node, (TaskGroup, TaskSDKTaskGroup)):
                 group_ids.add(node.group_id)
             elif isinstance(node, XComArg):
-                if isinstance(node.operator, DAGNode) and 
node.operator.task_group:
+                if isinstance(node.operator, (DAGNode, TaskSDKDagNode)) and 
node.operator.task_group:
                     if node.operator.task_group.is_root:
                         group_ids.add("root")
                     else:
@@ -112,7 +114,7 @@ class EdgeModifier(DependencyMixin):
     def _convert_stream_to_task_groups(self, stream: 
Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
         return [
             node.task_group
-            if isinstance(node, DAGNode) and node.task_group and not 
node.task_group.is_root
+            if isinstance(node, (DAGNode, TaskSDKDagNode)) and node.task_group 
and not node.task_group.is_root
             else node
             for node in stream
         ]
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 79b6329e44b..acbd8183052 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -36,6 +36,7 @@ from airflow.exceptions import (
     TaskAlreadyInTaskGroup,
 )
 from airflow.models.taskmixin import DAGNode
+from airflow.sdk.definitions.node import DAGNode as TaskSDKDagNode
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.utils.helpers import validate_group_key, validate_instance_args
 
@@ -275,7 +276,7 @@ class TaskGroup(DAGNode):
     @property
     def group_id(self) -> str | None:
         """group_id of this TaskGroup."""
-        if self.task_group and self.task_group.prefix_group_id and 
self.task_group._group_id:
+        if self.task_group and self.task_group.prefix_group_id and 
self.task_group.node_id:
             # defer to parent whether it adds a prefix
             return self.task_group.child_id(self._group_id)
 
@@ -311,7 +312,7 @@ class TaskGroup(DAGNode):
         else:
             # Handles setting relationship between a TaskGroup and a task
             for task in other.roots:
-                if not isinstance(task, DAGNode):
+                if not isinstance(task, (DAGNode, TaskSDKDagNode)):
                     raise AirflowException(
                         "Relationships can only be set between TaskGroup "
                         f"or operators; received {task.__class__.__name__}"
@@ -651,13 +652,13 @@ class MappedTaskGroup(TaskGroup):
         super().__exit__(exc_type, exc_val, exc_tb)
 
 
-class 
TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext):
+class 
TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext, 
share_parent_context=True):
     """TaskGroup context is used to keep the current TaskGroup when TaskGroup 
is used as ContextManager."""
 
     @classmethod
     def push_context_managed_task_group(cls, task_group: TaskGroup):
         """Push a TaskGroup into the list of managed TaskGroups."""
-        return cls.pusg(task_group)
+        return cls.push(task_group)
 
     @classmethod
     def pop_context_managed_task_group(cls) -> TaskGroup | None:
@@ -667,7 +668,7 @@ class 
TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext):
     @classmethod
     def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
         """Get the current TaskGroup."""
-        return cls.get_current()
+        return cls.get_current(dag)
 
 
 def task_group_to_dict(task_item_or_group):
diff --git a/pyproject.toml b/pyproject.toml
index 3a3bde8e23a..63b2958dad2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -468,6 +468,7 @@ fixture-parentheses = false
 ## pytest settings ##
 [tool.pytest.ini_options]
 addopts = [
+    "--tb=short",
     "-rasl",
     "--verbosity=2",
     # Disable `flaky` plugin for pytest. This plugin conflicts with 
`rerunfailures` because provide the same marker.
diff --git a/task_sdk/src/airflow/sdk/__init__.py 
b/task_sdk/src/airflow/sdk/__init__.py
index baf7c85baa9..5fea295d981 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -23,12 +23,14 @@ __all__ = ["DAG", "BaseOperator", "TaskGroup"]
 if TYPE_CHECKING:
     from airflow.sdk.definitions.baseoperator import BaseOperator as 
BaseOperator
     from airflow.sdk.definitions.dag import DAG as DAG
+    from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier
     from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup
 
 __lazy_imports: dict[str, str] = {
     "DAG": ".definitions.dag",
     "BaseOperator": ".definitions.baseoperator",
     "TaskGroup": ".definitions.taskgroup",
+    "EdgeModifier": ".definitions.edges",
 }
 
 
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py 
b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index 2dd0d488040..56baa258618 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -21,6 +21,7 @@ import datetime
 from abc import abstractmethod
 from collections.abc import (
     Collection,
+    Iterable,
 )
 from typing import (
     TYPE_CHECKING,
@@ -28,16 +29,22 @@ from typing import (
     ClassVar,
 )
 
-from .node import DAGNode
+from airflow.sdk.definitions.node import DAGNode
+from airflow.utils.log.secrets_masker import redact
 
 # TaskStateChangeCallback = Callable[[Context], None]
 
 if TYPE_CHECKING:
-    from airflow.models.baseoperator import BaseOperator
+    import jinja2  # Slow import.
+
     from airflow.models.baseoperatorlink import BaseOperatorLink
+    from airflow.sdk.definitions.baseoperator import BaseOperator
+    from airflow.sdk.definitions.dag import DAG
     from airflow.task.priority_strategy import PriorityWeightStrategy
 
-    from .dag import DAG
+    # TODO: Task-SDK
+    Context = dict[str, Any]
+
 
 DEFAULT_OWNER: str = "airflow"
 DEFAULT_POOL_SLOTS: int = 1
@@ -47,11 +54,12 @@ DEFAULT_QUEUE: str = "default"
 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
 DEFAULT_RETRIES: int = 0
-DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(300)
+DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
 MAX_RETRY_DELAY: int = 24 * 60 * 60
 
+# TODO: Task-SDK
 # DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
-DEFAULT_TRIGGER_RULE = "ALL_SUCCESS"  # TriggerRule.ALL_SUCCESS
+DEFAULT_TRIGGER_RULE = "all_success"
 DEFAULT_WEIGHT_RULE = "downstream"
 DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None
 
@@ -155,3 +163,69 @@ class AbstractOperator(DAGNode):
             # "task_group_id.task_id" -> "task_id"
             return self.task_id[len(tg.node_id) + 1 :]
         return self.task_id
+
+    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
+        """Get the template environment for rendering templates."""
+        if dag is None:
+            dag = self.get_dag()
+        return super().get_template_env(dag=dag)
+
+    def _render(self, template, context, dag: DAG | None = None):
+        if dag is None:
+            dag = self.get_dag()
+        return super()._render(template, context, dag=dag)
+
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set[int],
+    ) -> None:
+        """Override the base to use custom error logging."""
+        for attr_name in template_fields:
+            try:
+                value = getattr(parent, attr_name)
+            except AttributeError:
+                raise AttributeError(
+                    f"{attr_name!r} is configured as a template field "
+                    f"but {parent.task_type} does not have this attribute."
+                )
+            try:
+                if not value:
+                    continue
+            except Exception:
+                # This may happen if the templated field points to a class 
which does not support `__bool__`,
+                # such as Pandas DataFrames:
+                # 
https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
+                self.log.info(
+                    "Unable to check if the value of type '%s' is False for 
task '%s', field '%s'.",
+                    type(value).__name__,
+                    self.task_id,
+                    attr_name,
+                )
+                # We may still want to render custom classes which do not 
support __bool__
+                pass
+
+            try:
+                if callable(value):
+                    rendered_content = value(context=context, 
jinja_env=jinja_env)
+                else:
+                    rendered_content = self.render_template(
+                        value,
+                        context,
+                        jinja_env,
+                        seen_oids,
+                    )
+            except Exception:
+                value_masked = redact(name=attr_name, value=value)
+                self.log.exception(
+                    "Exception rendering Jinja template for task '%s', field 
'%s'. Template: %r",
+                    self.task_id,
+                    attr_name,
+                    value_masked,
+                )
+                raise
+            else:
+                setattr(parent, attr_name, rendered_content)
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py 
b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 6480a897807..3a70f910305 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -22,6 +22,7 @@ import collections.abc
 import contextlib
 import copy
 import inspect
+import warnings
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from datetime import datetime, timedelta
@@ -29,6 +30,9 @@ from functools import total_ordering, wraps
 from types import FunctionType
 from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
 
+import attrs
+
+from airflow.exceptions import FailStopDagInvalidTriggerRule
 from airflow.sdk.definitions.abstractoperator import (
     DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
     DEFAULT_OWNER,
@@ -47,6 +51,7 @@ from airflow.sdk.definitions.decorators import 
fixup_decorator_warning_stack
 from airflow.sdk.definitions.node import validate_key
 from airflow.sdk.types import NOTSET, validate_instance_args
 from airflow.task.priority_strategy import PriorityWeightStrategy, 
validate_and_load_priority_weight_strategy
+from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import AttributeRemoved
 
 T = TypeVar("T", bound=FunctionType)
@@ -58,8 +63,9 @@ if TYPE_CHECKING:
     class ParamsDict: ...
 
     from airflow.sdk.definitions.dag import DAG
-    from airflow.sdk.definitions.taskgroup import TaskGroup
+    from airflow.utils.operator_resources import Resources
 
+from airflow.sdk.definitions.taskgroup import TaskGroup
 
 # TODO: Task-SDK
 AirflowException = RuntimeError
@@ -489,6 +495,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     task_id: str
     owner: str = DEFAULT_OWNER
     email: str | Sequence[str] | None = None
+    email_on_retry: bool = True
+    email_on_failure: bool = True
     retries: int | None = DEFAULT_RETRIES
     retry_delay: timedelta | float = DEFAULT_RETRY_DELAY
     retry_exponential_backoff: bool = False
@@ -500,7 +508,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     wait_for_past_depends_before_skipping: bool = 
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
     wait_for_downstream: bool = False
     dag: DAG | None = None
-    params: MutableMapping | None = None
+    params: collections.abc.MutableMapping | None = None
     default_args: dict | None = None
     priority_weight: int = DEFAULT_PRIORITY_WEIGHT
     # TODO:
@@ -618,6 +626,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         task_id: str,
         owner: str = DEFAULT_OWNER,
         email: str | Sequence[str] | None = None,
+        email_on_retry: bool = True,
+        email_on_failure: bool = True,
         retries: int | None = DEFAULT_RETRIES,
         retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
         retry_exponential_backoff: bool = False,
@@ -670,14 +680,16 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     ):
         from airflow.sdk.definitions.contextmanager import DagContext, 
TaskGroupContext
 
+        dag = dag or DagContext.get_current()
+        task_group = task_group or TaskGroupContext.get_current(dag)
+
         self.task_id = task_group.child_id(task_id) if task_group else task_id
         if not self.__from_mapped and task_group:
             task_group.add(self)
 
-        dag = dag or DagContext.get_current()
-        task_group = task_group or TaskGroupContext.get_current(dag)
-
-        super().__init__(dag=dag, task_group=task_group)
+        super().__init__()
+        self.dag = dag
+        self.task_group = task_group
 
         kwargs.pop("_airflow_mapped_validation_only", None)
         if kwargs:
@@ -689,6 +701,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
 
         self.owner = owner
         self.email = email
+        self.email_on_retry = email_on_retry
+        self.email_on_failure = email_on_failure
 
         if execution_timeout is not None and not isinstance(execution_timeout, 
timedelta):
             raise ValueError(
@@ -706,10 +720,14 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         # self._post_execute_hook = post_execute
 
         if start_date:
-            self.start_date = timezone.convert_to_utc(start_date)
+            # TODO: Task-SDK
+            # self.start_date = timezone.convert_to_utc(start_date)
+            self.start_date = start_date
 
         if end_date:
-            self.end_date = timezone.convert_to_utc(end_date)
+            # TODO: Task-SDK
+            # self.end_date = timezone.convert_to_utc(end_date)
+            self.end_date = end_date
 
         if executor:
             warnings.warn(
@@ -731,16 +749,14 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot 
be less than 1")
         self.sla = sla
 
-        """
-        # if not TriggerRule.is_valid(trigger_rule):
-        #     raise AirflowException(
-        #         f"The trigger_rule must be one of 
{TriggerRule.all_triggers()},"
-        #         f"'{dag.dag_id if dag else ''}.{task_id}'; received 
'{trigger_rule}'."
-        #     )
-        #
-        # self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
-        # FailStopDagInvalidTriggerRule.check(dag=dag, 
trigger_rule=self.trigger_rule)
-        """
+        if not TriggerRule.is_valid(trigger_rule):
+            raise ValueError(
+                f"The trigger_rule must be one of 
{TriggerRule.all_triggers()},"
+                f"'{dag.dag_id if dag else ''}.{task_id}'; received 
'{trigger_rule}'."
+            )
+
+        self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
+        FailStopDagInvalidTriggerRule.check(dag=dag, 
trigger_rule=self.trigger_rule)
 
         self.depends_on_past: bool = depends_on_past
         self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
@@ -892,7 +908,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             other = [other]
 
         for obj in other:
-            if not attr.has(obj):
+            if not attrs.has(obj):
                 raise TypeError(f"{obj} cannot be an inlet")
         self.add_inlets(other)
 
@@ -947,8 +963,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         return parsed_retries
 
     @staticmethod
-    def _convert_timedelta(value: float | timedelta) -> timedelta:
-        if isinstance(value, timedelta):
+    def _convert_timedelta(value: float | timedelta | None) -> timedelta | 
None:
+        if value is None or isinstance(value, timedelta):
             return value
         return timedelta(seconds=value)
 
@@ -959,6 +975,9 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     def _convert_resources(resources: dict[str, Any] | None) -> Resources | 
None:
         if resources is None:
             return None
+
+        from airflow.utils.operator_resources import Resources
+
         return Resources(**resources)
 
     @property
diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py 
b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
index d97339fc265..47edde0b5fe 100644
--- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
@@ -28,17 +28,16 @@ T = TypeVar("T")
 
 
 class ContextStack(Generic[T]):
-    active: bool = False
     _context: deque[T]
 
-    def __init_subclass__(cls) -> None:
-        cls._context = deque()
+    def __init_subclass__(cls, /, **kwargs) -> None:
+        if not kwargs.get("share_parent_context", False):
+            cls._context = deque()
         return super().__init_subclass__()
 
     @classmethod
     def push(cls, obj: T):
         cls._context.appendleft(obj)
-        cls.active = True
 
     @classmethod
     def pop(cls) -> T | None:
@@ -52,7 +51,8 @@ class ContextStack(Generic[T]):
             return None
 
     @classmethod
-    def is_active(cls) -> bool:
+    @property
+    def active(cls) -> bool:
         """The active property says if any object is currently in scope."""
         try:
             cls._context[0]
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 5cf76da4882..6e62b4524f7 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -59,6 +59,7 @@ from airflow.exceptions import (
 from airflow.models.param import DagParam
 from airflow.sdk.definitions.abstractoperator import AbstractOperator
 from airflow.sdk.definitions.baseoperator import BaseOperator
+from airflow.sdk.types import NOTSET
 from airflow.stats import Stats
 from airflow.timetables.base import Timetable
 from airflow.timetables.interval import CronDataIntervalTimetable, 
DeltaDataIntervalTimetable
@@ -71,7 +72,7 @@ from airflow.timetables.simple import (
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import NOTSET, EdgeInfoType
+from airflow.utils.types import EdgeInfoType
 
 if TYPE_CHECKING:
     from airflow.decorators import TaskDecoratorCollection
@@ -130,8 +131,6 @@ DAG_ARGS_EXPECTED_TYPES = {
     "max_active_runs": int,
     "max_consecutive_failed_dag_runs": int,
     "dagrun_timeout": timedelta,
-    "default_view": str,
-    "orientation": str,
     "catchup": bool,
     "doc_md": str,
     "is_paused_upon_creation": bool,
@@ -161,7 +160,18 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime
     raise ValueError(f"{interval!r} is not a valid schedule.")
 
 
[email protected](kw_only=True)
+def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]):
+    i = iter(fields)
+    f = next(i)
+    if f.name != "dag_id":
+        raise RuntimeError("dag_id was not the first field")
+    yield f
+
+    for f in i:
+        yield f.evolve(kw_only=True)
+
+
[email protected](repr=False, field_transformer=_all_after_dag_id_to_kw_only)
 class DAG:
     """
     A dag (directed acyclic graph) is a collection of tasks with directional 
dependencies.
@@ -231,9 +241,6 @@ class DAG:
     :param dagrun_timeout: Specify the duration a DagRun should be allowed to 
run before it times out or
         fails. Task instances that are running when a DagRun is timed out will 
be marked as skipped.
     :param sla_miss_callback: DEPRECATED - The SLA feature is removed in 
Airflow 3.0, to be replaced with a new implementation in 3.1
-    :param default_view: Specify DAG default view (grid, graph, duration,
-                                                   gantt, landing_times), 
default grid
-    :param orientation: Specify DAG orientation in graph view (LR, TB, RL, 
BT), default LR
     :param catchup: Perform scheduler catchup (or only run latest)? Defaults 
to True
     :param on_failure_callback: A function or list of functions to be called 
when a DagRun of this dag fails.
         A context dictionary is passed as a single parameter to this function.
@@ -290,7 +297,7 @@ class DAG:
 
     # NOTE: When updating arguments here, please also keep arguments in @dag()
     # below in sync. (Search for 'def dag(' in this file.)
-    dag_id: str
+    dag_id: str = attrs.field(kw_only=False)
     description: str | None = None
     start_date: datetime | None = None
     end_date: datetime | None = None
@@ -299,7 +306,10 @@ class DAG:
     timetable: Timetable = attrs.field(init=False)
     full_filepath: str | None = None
     template_searchpath: str | Iterable[str] | None = None
-    # template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
+    # TODO: Task-SDK: Work out how to not import jinj2 until we need it! It's 
expensive
+    template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
+    user_defined_macros: dict | None = None
+    user_defined_filters: dict | None = None
     default_args: dict | None = attrs.field(factory=dict, converter=copy.copy)
     concurrency: int | None = None
     max_active_tasks: int = 16
@@ -326,6 +336,16 @@ class DAG:
 
     task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.NO_OP)
 
+    fileloc: str = attrs.field(init=False)
+
+    edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, 
factory=dict)
+
+    @fileloc.default
+    def _default_fileloc(self) -> str:
+        # Skip over this frame, and the 'attrs generated init'
+        back = sys._getframe().f_back.f_back
+        return back.f_code.co_filename if back else ""
+
     @dag_display_name.default
     def _default_dag_display_name(self) -> str:
         return self.dag_id
@@ -443,7 +463,6 @@ class DAG:
 
         This is called by the DAG bag before bagging the DAG.
         """
-        self.validate_schedule_and_params()
         self.timetable.validate()
         self.validate_setup_teardown()
 
@@ -850,13 +869,6 @@ class DAG:
         args = parser.parse_args()
         args.func(args, self)
 
-    def get_default_view(self):
-        """Allow backward compatible jinja2 templates."""
-        if self.default_view is None:
-            return airflow_conf.get("webserver", "dag_default_view").lower()
-        else:
-            return self.default_view
-
     @classmethod
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""
@@ -943,8 +955,6 @@ def dag(
     ),
     dagrun_timeout: timedelta | None = None,
     sla_miss_callback: Any = None,
-    default_view: str = airflow_conf.get_mandatory_value("webserver", 
"dag_default_view").lower(),
-    orientation: str = airflow_conf.get_mandatory_value("webserver", 
"dag_orientation"),
     catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
     on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
     on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
@@ -995,8 +1005,6 @@ def dag(
                 
max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs,
                 dagrun_timeout=dagrun_timeout,
                 sla_miss_callback=sla_miss_callback,
-                default_view=default_view,
-                orientation=orientation,
                 catchup=catchup,
                 on_success_callback=on_success_callback,
                 on_failure_callback=on_failure_callback,
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py
index 40236cac8bd..92527531eb9 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -24,7 +24,6 @@ from collections.abc import Iterable, Sequence
 from datetime import datetime
 from typing import TYPE_CHECKING
 
-import attrs
 import methodtools
 import re2
 
@@ -55,7 +54,6 @@ def validate_key(k: str, max_length: int = 250):
         )
 
 
[email protected](hash=False, repr=False)
 class DAGNode(DependencyMixin, metaclass=ABCMeta):
     """
     A base class for a node in the graph of a workflow.
@@ -63,13 +61,17 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
     A node may be an Operator or a Task Group, either mapped or unmapped.
     """
 
-    dag: DAG | None = attrs.field(kw_only=True, default=None)
-    task_group: TaskGroup | None = attrs.field(kw_only=True, default=None)
+    dag: DAG | None
+    task_group: TaskGroup | None
     """The task_group that contains this node"""
-    start_date: datetime | None = attrs.field(kw_only=True, default=None)
-    end_date: datetime | None = attrs.field(kw_only=True, default=None)
-    upstream_task_ids: set[str] = attrs.field(init=False, factory=set)
-    downstream_task_ids: set[str] = attrs.field(init=False, factory=set)
+    start_date: datetime | None
+    end_date: datetime | None
+    upstream_task_ids: set[str]
+    downstream_task_ids: set[str]
+
+    def __init__(self):
+        self.upstream_task_ids = set()
+        self.downstream_task_ids = set()
 
     @property
     @abstractmethod
@@ -188,6 +190,13 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
             raise RuntimeError(f"Operator {self} has not been assigned to a 
DAG yet")
         return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
 
+    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
+        """Get set of the direct relative ids to the current task, upstream or 
downstream."""
+        if upstream:
+            return self.upstream_task_ids
+        else:
+            return self.downstream_task_ids
+
     # def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
     #     """Serialize a task group's content; used by 
TaskGroupSerialization."""
     #     raise NotImplementedError()
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 6bcda41d2d2..85139efdf0c 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -153,8 +153,12 @@ class TaskGroup(DAGNode):
         return not self.group_id
 
     @property
-    def parent_group(self) -> TaskGroup | None:
-        return self.task_group
+    def task_group(self) -> TaskGroup | None:
+        return self.parent_group
+
+    @task_group.setter
+    def _set_task_group(self, tg: TaskGroup):
+        self.parent_group = tg
 
     def __iter__(self):
         for child in self.children.values():
@@ -173,9 +177,9 @@ class TaskGroup(DAGNode):
         from airflow.sdk.definitions.contextmanager import TaskGroupContext
 
         if TaskGroupContext.active:
-            if task.task_group and task.task_group != self:
-                task.task_group.children.pop(task.node_id, None)
-                task.task_group = self
+            if task.parent_group and task.parent_group != self:
+                task.parent_group.children.pop(task.node_id, None)
+                task.parent_group = self
         existing_tg = task.task_group
         if isinstance(task, AbstractOperator) and existing_tg is not None and 
existing_tg != self:
             raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, 
self.node_id)
@@ -213,9 +217,9 @@ class TaskGroup(DAGNode):
     @property
     def group_id(self) -> str | None:
         """group_id of this TaskGroup."""
-        if self.task_group and self.task_group.prefix_group_id and 
self.task_group._group_id:
+        if self.parent_group and self.parent_group.prefix_group_id and 
self.parent_group.group_id:
             # defer to parent whether it adds a prefix
-            return self.task_group.child_id(self._group_id)
+            return self.parent_group.child_id(self.group_id)
 
         return self._group_id
 
@@ -471,7 +475,7 @@ class TaskGroup(DAGNode):
                     while tg:
                         if tg.node_id in graph_unsorted:
                             break
-                        tg = tg.task_group
+                        tg = tg.parent_group
 
                     if tg:
                         # We are already going to visit that TG
@@ -499,7 +503,7 @@ class TaskGroup(DAGNode):
         while group is not None:
             if isinstance(group, MappedTaskGroup):
                 yield group
-            group = group.task_group
+            group = group.parent_group
 
     def iter_tasks(self) -> Iterator[AbstractOperator]:
         """Return an iterator of the child tasks."""
diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py
index 0222de3ee32..f90891e1e86 100644
--- a/task_sdk/tests/defintions/test_baseoperator.py
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -124,6 +124,27 @@ class TestBaseOperator:
         ):
             BaseOperator(task_id="test", execution_timeout=1)
 
+    def test_default_resources(self):
+        task = BaseOperator(task_id="default-resources")
+        assert task.resources is None
+
+    def test_custom_resources(self):
+        task = BaseOperator(task_id="custom-resources", resources={"cpus": 1, 
"ram": 1024})
+        assert task.resources.cpus.qty == 1
+        assert task.resources.ram.qty == 1024
+
+    def test_default_email_on_actions(self):
+        test_task = BaseOperator(task_id="test_default_email_on_actions")
+        assert test_task.email_on_retry is True
+        assert test_task.email_on_failure is True
+
+    def test_email_on_actions(self):
+        test_task = BaseOperator(
+            task_id="test_default_email_on_actions", email_on_retry=False, 
email_on_failure=True
+        )
+        assert test_task.email_on_retry is False
+        assert test_task.email_on_failure is True
+
     def test_incorrect_default_args(self):
         default_args = {"test_param": True, "extra_param": True}
         op = FakeOperator(default_args=default_args)
@@ -228,6 +249,13 @@ class TestBaseOperator:
         with pytest.raises(ValueError, match="can not be changed"):
             op1.dag = DAG(dag_id="dag2")
 
+    def test_invalid_trigger_rule(self):
+        with pytest.raises(
+            ValueError,
+            match=(r"The trigger_rule must be one of .*,'\.op1'; received 
'some_rule'\."),
+        ):
+            BaseOperator(task_id="op1", trigger_rule="some_rule")
+
 
 def test_init_subclass_args():
     class InitSubclassOp(BaseOperator):
@@ -291,7 +319,13 @@ def test_operator_retries_conversion(retries, expected):
     assert op.retries == expected
 
 
-def test_dag_level_retry_delay(dag_maker):
+def test_default_retry_delay():
+    task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
+
+    assert task1.retry_delay == timedelta(seconds=300)
+
+
+def test_dag_level_retry_delay():
     with DAG(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": 
timedelta(seconds=100)}):
         task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
 
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index eea5dae2693..8227a6b7c8f 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -21,7 +21,7 @@ import copy
 import logging
 import uuid
 from collections import defaultdict
-from datetime import date, datetime, timedelta
+from datetime import date, datetime
 from typing import NamedTuple
 from unittest import mock
 
@@ -316,27 +316,6 @@ class TestBaseOperator:
         task.render_template_fields(context={"foo": "whatever", "bar": 
"whatever"})
         assert mock_jinja_env.call_count == 1
 
-    def test_default_resources(self):
-        task = BaseOperator(task_id="default-resources")
-        assert task.resources is None
-
-    def test_custom_resources(self):
-        task = BaseOperator(task_id="custom-resources", resources={"cpus": 1, 
"ram": 1024})
-        assert task.resources.cpus.qty == 1
-        assert task.resources.ram.qty == 1024
-
-    def test_default_email_on_actions(self):
-        test_task = BaseOperator(task_id="test_default_email_on_actions")
-        assert test_task.email_on_retry is True
-        assert test_task.email_on_failure is True
-
-    def test_email_on_actions(self):
-        test_task = BaseOperator(
-            task_id="test_default_email_on_actions", email_on_retry=False, 
email_on_failure=True
-        )
-        assert test_task.email_on_retry is False
-        assert test_task.email_on_failure is True
-
     def test_cross_downstream(self):
         """Test if all dependencies between tasks are all set correctly."""
         dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now())
@@ -598,16 +577,6 @@ class TestBaseOperator:
         assert op_no_dag.start_date.tzinfo
         assert op_no_dag.end_date.tzinfo
 
-    def test_invalid_trigger_rule(self):
-        with pytest.raises(
-            AirflowException,
-            match=(
-                f"The trigger_rule must be one of 
{TriggerRule.all_triggers()},"
-                "'.op1'; received 'some_rule'."
-            ),
-        ):
-            BaseOperator(task_id="op1", trigger_rule="some_rule")
-
     # ensure the default logging config is used for this test, no matter what 
ran before
     @pytest.mark.usefixtures("reset_logging_config")
     def test_logging_propogated_by_default(self, caplog):
@@ -619,14 +588,6 @@ class TestBaseOperator:
         assert caplog.messages == ["test"]
 
 
[email protected]_test
-def test_default_retry_delay(dag_maker):
-    with dag_maker(dag_id="test_default_retry_delay"):
-        task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
-
-        assert task1.retry_delay == timedelta(seconds=300)
-
-
 def test_deepcopy():
     # Test bug when copying an operator attached to a DAG
     with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag:

Reply via email to