This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0f92b1461d9c2c8e464ccd402e5e40aee1d9be22
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Mon Oct 21 18:31:51 2024 +0100

    [skip ci]
---
 .pre-commit-config.yaml                            |   1 +
 airflow/decorators/base.py                         |   2 +-
 airflow/decorators/sensor.py                       |   4 +-
 airflow/models/abstractoperator.py                 |  11 +-
 airflow/models/baseoperator.py                     | 315 +++++++++++----------
 airflow/models/dag.py                              |   6 +-
 airflow/models/mappedoperator.py                   |   4 +-
 airflow/models/skipmixin.py                        |   2 +-
 airflow/models/xcom_arg.py                         |   1 -
 airflow/operators/python.py                        |   4 +-
 airflow/serialization/serialized_objects.py        |   2 +-
 airflow/utils/task_group.py                        |   5 +-
 airflow/utils/types.py                             |   2 +-
 .../pre_commit/base_operator_partial_arguments.py  |  80 ++++--
 task_sdk/pyproject.toml                            |   2 +
 .../airflow/sdk/definitions/abstractoperator.py    |   6 +-
 .../src/airflow/sdk/definitions/baseoperator.py    |   7 +-
 task_sdk/src/airflow/sdk/definitions/dag.py        |  77 +++--
 task_sdk/src/airflow/sdk/definitions/edges.py      |   4 +-
 task_sdk/src/airflow/sdk/definitions/mixins.py     |   6 +-
 task_sdk/src/airflow/sdk/definitions/node.py       |   4 +-
 task_sdk/src/airflow/sdk/types.py                  |   2 +-
 task_sdk/tests/defintions/test_baseoperator.py     |   4 +-
 task_sdk/tests/defintions/test_dag.py              |  68 +++++
 tests/models/test_baseoperator.py                  |   5 +-
 tests/models/test_dag.py                           |  65 +----
 tests/utils/test_task_group.py                     |   1 -
 tests_common/test_utils/mock_operators.py          |   3 +-
 28 files changed, 365 insertions(+), 328 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7e94e4191e8..7d9cd6b2292 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1189,6 +1189,7 @@ repos:
           ^airflow/utils/helpers.py$ |
           ^providers/src/airflow/providers/ |
           ^(providers/)?tests/ |
+          task_sdk/src/airflow/sdk/definitions/dag.py$ |
           ^dev/.*\.py$ |
           ^scripts/.*\.py$ |
           ^docker_tests/.*$ |
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index e7a21e2919b..3290f5342ce 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -187,7 +187,7 @@ class DecoratedOperator(BaseOperator):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs: Sequence[str] = ("python_callable",)
+    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)
 
     def __init__(
         self,
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
index c332a78f95c..c37cd08d6b4 100644
--- a/airflow/decorators/sensor.py
+++ b/airflow/decorators/sensor.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable, Sequence
+from typing import TYPE_CHECKING, Callable, ClassVar, Sequence
 
 from airflow.decorators.base import get_unique_task_id, task_decorator_factory
 from airflow.sensors.python import PythonSensor
@@ -48,7 +48,7 @@ class DecoratedSensorOperator(PythonSensor):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs: Sequence[str] = ("python_callable",)
+    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)
 
     def __init__(
         self,
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index a27d7e26fd1..93c0fd2d93a 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import datetime
 import inspect
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence
 
 import methodtools
 from sqlalchemy import select
@@ -28,7 +28,6 @@ from sqlalchemy import select
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models.expandinput import NotFullyPopulated
-from airflow.models.taskmixin import DependencyMixin
 from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
 from airflow.template.templater import Templater
 from airflow.utils.context import Context
@@ -39,7 +38,6 @@ from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule
-from airflow.sdk.types import NOTSET, ArgNotSet
 from airflow.utils.weight_rule import WeightRule
 
 TaskStateChangeCallback = Callable[[Context], None]
@@ -53,10 +51,9 @@ if TYPE_CHECKING:
     from airflow.models.baseoperatorlink import BaseOperatorLink
     from airflow.models.dag import DAG as SchedulerDAG
     from airflow.models.mappedoperator import MappedOperator
-    from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
-    from airflow.sdk import BaseOperator, DAG
-    from airflow.sdk.defintions.node import DAGNode
+    from airflow.sdk import DAG, BaseOperator
+    from airflow.sdk.definitions.node import DAGNode
     from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.triggers.base import StartTriggerArgs
     from airflow.utils.task_group import TaskGroup
@@ -262,7 +259,7 @@ class AbstractOperator(Templater, TaskSDKAbstractOperator):
         """
         if (group := self.task_group) is None:
             return
-        # TODO: Task-SDK: this type ignore shouldn't be necssary, revisit once mapping support is fully in the
+        # TODO: Task-SDK: this type ignore shouldn't be necessary, revisit once mapping support is fully in the
         # SDK
         yield from group.iter_mapped_task_groups()  # type: ignore[misc]
 
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index a9628873e5e..f0d1ee6f965 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -71,7 +71,6 @@ from airflow.models.abstractoperator import (
 )
 from airflow.models.base import _sentinel
 from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs
-from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance, clear_task_instances
 from airflow.models.taskmixin import DependencyMixin
 
@@ -83,7 +82,6 @@ from airflow.sdk.definitions.baseoperator import (
     get_merged_defaults,
 )
 from airflow.serialization.enums import DagAttributeTypes
-from airflow.task.priority_strategy import PriorityWeightStrategy
 from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
 from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
@@ -106,14 +104,15 @@ if TYPE_CHECKING:
 
     from airflow.models.abstractoperator import TaskStateChangeCallback
     from airflow.models.baseoperatorlink import BaseOperatorLink
-    from airflow.models.dag import DAG, DAG as SchedulerDAG
+    from airflow.models.dag import DAG as SchedulerDAG
     from airflow.models.operator import Operator
+    from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
     from airflow.triggers.base import BaseTrigger, StartTriggerArgs
     from airflow.utils.types import ArgNotSet
 
 
-# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic failes to resolve
+# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve
 # types for serialization
 from airflow.utils.task_group import TaskGroup  # noqa: TCH001
 
@@ -192,160 +191,148 @@ _PARTIAL_DEFAULTS: dict[str, Any] = {
 
 
 # This is what handles the actual mapping.
-def partial(
-    operator_class: type[BaseOperator],
-    *,
-    task_id: str,
-    dag: DAG | None = None,
-    task_group: TaskGroup | None = None,
-    start_date: datetime | ArgNotSet = NOTSET,
-    end_date: datetime | ArgNotSet = NOTSET,
-    owner: str | ArgNotSet = NOTSET,
-    email: None | str | Iterable[str] | ArgNotSet = NOTSET,
-    params: collections.abc.MutableMapping | None = None,
-    resources: dict[str, Any] | None | ArgNotSet = NOTSET,
-    trigger_rule: str | ArgNotSet = NOTSET,
-    depends_on_past: bool | ArgNotSet = NOTSET,
-    ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
-    wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
-    wait_for_downstream: bool | ArgNotSet = NOTSET,
-    retries: int | None | ArgNotSet = NOTSET,
-    queue: str | ArgNotSet = NOTSET,
-    pool: str | ArgNotSet = NOTSET,
-    pool_slots: int | ArgNotSet = NOTSET,
-    execution_timeout: timedelta | None | ArgNotSet = NOTSET,
-    max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
-    retry_delay: timedelta | float | ArgNotSet = NOTSET,
-    retry_exponential_backoff: bool | ArgNotSet = NOTSET,
-    priority_weight: int | ArgNotSet = NOTSET,
-    weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET,
-    sla: timedelta | None | ArgNotSet = NOTSET,
-    map_index_template: str | None | ArgNotSet = NOTSET,
-    max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
-    max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
-    on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
-    on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
-    on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
-    on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
-    on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
-    run_as_user: str | None | ArgNotSet = NOTSET,
-    executor: str | None | ArgNotSet = NOTSET,
-    executor_config: dict | None | ArgNotSet = NOTSET,
-    inlets: Any | None | ArgNotSet = NOTSET,
-    outlets: Any | None | ArgNotSet = NOTSET,
-    doc: str | None | ArgNotSet = NOTSET,
-    doc_md: str | None | ArgNotSet = NOTSET,
-    doc_json: str | None | ArgNotSet = NOTSET,
-    doc_yaml: str | None | ArgNotSet = NOTSET,
-    doc_rst: str | None | ArgNotSet = NOTSET,
-    task_display_name: str | None | ArgNotSet = NOTSET,
-    logger_name: str | None | ArgNotSet = NOTSET,
-    allow_nested_operators: bool = True,
-    **kwargs,
-) -> OperatorPartial:
-    from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
-
-    validate_mapping_kwargs(operator_class, "partial", kwargs)
-
-    dag = dag or DagContext.get_current()
-    if dag:
-        task_group = task_group or TaskGroupContext.get_current(dag)
-    if task_group:
-        task_id = task_group.child_id(task_id)
-
-    # Merge DAG and task group level defaults into user-supplied values.
-    dag_default_args, partial_params = get_merged_defaults(
-        dag=dag,
-        task_group=task_group,
-        task_params=params,
-        task_default_args=kwargs.pop("default_args", None),
-    )
 
-    # Create partial_kwargs from args and kwargs
-    partial_kwargs: dict[str, Any] = {
+if TYPE_CHECKING:
+
+    def partial(
+        operator_class: type[BaseOperator],
+        *,
+        task_id: str,
+        dag: DAG | None = None,
+        task_group: TaskGroup | None = None,
+        start_date: datetime | ArgNotSet = NOTSET,
+        end_date: datetime | ArgNotSet = NOTSET,
+        owner: str | ArgNotSet = NOTSET,
+        email: None | str | Iterable[str] | ArgNotSet = NOTSET,
+        params: collections.abc.MutableMapping | None = None,
+        resources: dict[str, Any] | None | ArgNotSet = NOTSET,
+        trigger_rule: str | ArgNotSet = NOTSET,
+        depends_on_past: bool | ArgNotSet = NOTSET,
+        ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
+        wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
+        wait_for_downstream: bool | ArgNotSet = NOTSET,
+        retries: int | None | ArgNotSet = NOTSET,
+        queue: str | ArgNotSet = NOTSET,
+        pool: str | ArgNotSet = NOTSET,
+        pool_slots: int | ArgNotSet = NOTSET,
+        execution_timeout: timedelta | None | ArgNotSet = NOTSET,
+        max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
+        retry_delay: timedelta | float | ArgNotSet = NOTSET,
+        retry_exponential_backoff: bool | ArgNotSet = NOTSET,
+        priority_weight: int | ArgNotSet = NOTSET,
+        weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET,
+        sla: timedelta | None | ArgNotSet = NOTSET,
+        map_index_template: str | None | ArgNotSet = NOTSET,
+        max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
+        max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
+        on_execute_callback: None
+        | TaskStateChangeCallback
+        | list[TaskStateChangeCallback]
+        | ArgNotSet = NOTSET,
+        on_failure_callback: None
+        | TaskStateChangeCallback
+        | list[TaskStateChangeCallback]
+        | ArgNotSet = NOTSET,
+        on_success_callback: None
+        | TaskStateChangeCallback
+        | list[TaskStateChangeCallback]
+        | ArgNotSet = NOTSET,
+        on_retry_callback: None
+        | TaskStateChangeCallback
+        | list[TaskStateChangeCallback]
+        | ArgNotSet = NOTSET,
+        on_skipped_callback: None
+        | TaskStateChangeCallback
+        | list[TaskStateChangeCallback]
+        | ArgNotSet = NOTSET,
+        run_as_user: str | None | ArgNotSet = NOTSET,
+        executor: str | None | ArgNotSet = NOTSET,
+        executor_config: dict | None | ArgNotSet = NOTSET,
+        inlets: Any | None | ArgNotSet = NOTSET,
+        outlets: Any | None | ArgNotSet = NOTSET,
+        doc: str | None | ArgNotSet = NOTSET,
+        doc_md: str | None | ArgNotSet = NOTSET,
+        doc_json: str | None | ArgNotSet = NOTSET,
+        doc_yaml: str | None | ArgNotSet = NOTSET,
+        doc_rst: str | None | ArgNotSet = NOTSET,
+        task_display_name: str | None | ArgNotSet = NOTSET,
+        logger_name: str | None | ArgNotSet = NOTSET,
+        allow_nested_operators: bool = True,
+        **kwargs,
+    ) -> OperatorPartial: ...
+else:
+
+    def partial(
+        operator_class: type[BaseOperator],
+        *,
+        task_id: str,
+        dag: DAG | None = None,
+        task_group: TaskGroup | None = None,
+        params: collections.abc.MutableMapping | None = None,
         **kwargs,
-        "dag": dag,
-        "task_group": task_group,
-        "task_id": task_id,
-        "map_index_template": map_index_template,
-        "start_date": start_date,
-        "end_date": end_date,
-        "owner": owner,
-        "email": email,
-        "trigger_rule": trigger_rule,
-        "depends_on_past": depends_on_past,
-        "ignore_first_depends_on_past": ignore_first_depends_on_past,
-        "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping,
-        "wait_for_downstream": wait_for_downstream,
-        "retries": retries,
-        "queue": queue,
-        "pool": pool,
-        "pool_slots": pool_slots,
-        "execution_timeout": execution_timeout,
-        "max_retry_delay": max_retry_delay,
-        "retry_delay": retry_delay,
-        "retry_exponential_backoff": retry_exponential_backoff,
-        "priority_weight": priority_weight,
-        "weight_rule": weight_rule,
-        "sla": sla,
-        "max_active_tis_per_dag": max_active_tis_per_dag,
-        "max_active_tis_per_dagrun": max_active_tis_per_dagrun,
-        "on_execute_callback": on_execute_callback,
-        "on_failure_callback": on_failure_callback,
-        "on_retry_callback": on_retry_callback,
-        "on_success_callback": on_success_callback,
-        "on_skipped_callback": on_skipped_callback,
-        "run_as_user": run_as_user,
-        "executor": executor,
-        "executor_config": executor_config,
-        "inlets": inlets,
-        "outlets": outlets,
-        "resources": resources,
-        "doc": doc,
-        "doc_json": doc_json,
-        "doc_md": doc_md,
-        "doc_rst": doc_rst,
-        "doc_yaml": doc_yaml,
-        "task_display_name": task_display_name,
-        "logger_name": logger_name,
-        "allow_nested_operators": allow_nested_operators,
-    }
-
-    # Inject DAG-level default args into args provided to this function.
-    partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET)
-
-    # Fill fields not provided by the user with default values.
-    partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()}
-
-    # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
-    if "task_concurrency" in kwargs:  # Reject deprecated option.
-        raise TypeError("unexpected argument: task_concurrency")
-    if partial_kwargs["wait_for_downstream"]:
-        partial_kwargs["depends_on_past"] = True
-    partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"])
-    partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])
-    if partial_kwargs["pool"] is None:
-        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
-    if partial_kwargs["pool_slots"] < 1:
-        dag_str = ""
+    ):
+        from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
+
+        validate_mapping_kwargs(operator_class, "partial", kwargs)
+
+        dag = dag or DagContext.get_current()
         if dag:
-            dag_str = f" in dag {dag.dag_id}"
-        raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
-    partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
-    partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
-    if partial_kwargs["max_retry_delay"] is not None:
-        partial_kwargs["max_retry_delay"] = coerce_timedelta(
-            partial_kwargs["max_retry_delay"],
-            key="max_retry_delay",
+            task_group = task_group or TaskGroupContext.get_current(dag)
+        if task_group:
+            task_id = task_group.child_id(task_id)
+
+        # Merge DAG and task group level defaults into user-supplied values.
+        dag_default_args, partial_params = get_merged_defaults(
+            dag=dag,
+            task_group=task_group,
+            task_params=params,
+            task_default_args=kwargs.pop("default_args", None),
         )
-    partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
-    partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])
 
-    return OperatorPartial(
-        operator_class=operator_class,
-        kwargs=partial_kwargs,
-        params=partial_params,
-    )
+        # Create partial_kwargs from args and kwargs
+        partial_kwargs: dict[str, Any] = {
+            "task_id": task_id,
+            "dag": dag,
+            "task_group": task_group,
+            **kwargs,
+        }
+
+        # Inject DAG-level default args into args provided to this function.
+        partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET)
+
+        # Fill fields not provided by the user with default values.
+        for k, v in _PARTIAL_DEFAULTS.items():
+            partial_kwargs.setdefault(k, v)
+
+        # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
+        if "task_concurrency" in kwargs:  # Reject deprecated option.
+            raise TypeError("unexpected argument: task_concurrency")
+        if wait := partial_kwargs.get("wait_for_downstream", False):
+            partial_kwargs["depends_on_past"] = wait
+        if start_date := partial_kwargs.get("start_date", None):
+            partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
+        if end_date := partial_kwargs.get("end_date", None):
+            partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
+        if partial_kwargs["pool_slots"] < 1:
+            dag_str = ""
+            if dag:
+                dag_str = f" in dag {dag.dag_id}"
+            raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
+        if retries := partial_kwargs.get("retries"):
+            partial_kwargs["retries"] = parse_retries(retries)
+        partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
+        if partial_kwargs.get("max_retry_delay", None) is not None:
+            partial_kwargs["max_retry_delay"] = coerce_timedelta(
+                partial_kwargs["max_retry_delay"],
+                key="max_retry_delay",
+            )
+        partial_kwargs.setdefault("executor_config", {})
+
+        return OperatorPartial(
+            operator_class=operator_class,
+            kwargs=partial_kwargs,
+            params=partial_params,
+        )
 
 
 class ExecutorSafeguard:
@@ -387,6 +374,8 @@ class ExecutorSafeguard:
 
 # TODO: Task-SDK - temporarily extend the metaclass to add in the ExecutorSafeguard.
 class BaseOperatorMeta(TaskSDKBaseOperatorMeta):
+    """:meta private:"""  # noqa: D400
+
     def __new__(cls, name, bases, namespace, **kwargs):
         execute_method = namespace.get("execute")
         if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
@@ -617,7 +606,17 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
     _is_setup: bool = False
     _is_teardown: bool = False
 
-    def __init__(self, pre_execute=None, post_execute=None, **kwargs):
+    def __init__(
+        self,
+        pre_execute=None,
+        post_execute=None,
+        on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
+        on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
+        on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
+        on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
+        on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
+        **kwargs,
+    ):
         if start_date := kwargs.get("start_date", None):
             kwargs["start_date"] = timezone.convert_to_utc(start_date)
 
@@ -626,6 +625,10 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
         super().__init__(**kwargs)
         self._pre_execute_hook = pre_execute
         self._post_execute_hook = post_execute
+        self.on_execute_callback = on_execute_callback
+        self.on_failure_callback = on_failure_callback
+        self.on_success_callback = on_success_callback
+        self.on_skipped_callback = on_skipped_callback
 
     # Defines the operator level extra links
     operator_extra_links: Collection[BaseOperatorLink] = ()
@@ -764,7 +767,7 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
 
         if TYPE_CHECKING:
             # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and
-            # defintion code
+            # definition code
             assert isinstance(self.dag, SchedulerDAG)
 
         clear_task_instances(results, session, dag=self.dag)
@@ -815,7 +818,7 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
             assert self.start_date
 
             # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and
-            # defintion code
+            # definition code
             assert isinstance(self.dag, SchedulerDAG)
 
         start_date = pendulum.instance(start_date or self.start_date)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 642b4c2f7e1..e26a696ee7a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -83,7 +83,6 @@ from airflow.exceptions import (
     UnknownExecutorException,
 )
 from airflow.executors.executor_loader import ExecutorLoader
-from airflow.models.abstractoperator import TaskStateChangeCallback
 from airflow.models.asset import (
     AssetDagRunQueue,
     AssetModel,
@@ -125,6 +124,7 @@ if TYPE_CHECKING:
     from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session
 
+    from airflow.models.abstractoperator import TaskStateChangeCallback
     from airflow.models.dagbag import DagBag
     from airflow.models.operator import Operator
     from airflow.serialization.pydantic.dag import DagModelPydantic
@@ -2433,9 +2433,7 @@ if STATICA_HACK:  # pragma: no cover
 
 
 class DagContext(airflow.sdk.definitions.contextmanager.DagContext, share_parent_context=True):
-    """
-    :meta private:
-    """
+    """:meta private:"""  # noqa: D400
 
     @classmethod
     def push_context_managed_dag(cls, dag: DAG):
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 8a9e790ea7f..52a08bce027 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -201,8 +201,8 @@ class OperatorPartial:
         task_id = partial_kwargs.pop("task_id")
         dag = partial_kwargs.pop("dag")
         task_group = partial_kwargs.pop("task_group")
-        start_date = partial_kwargs.pop("start_date")
-        end_date = partial_kwargs.pop("end_date")
+        start_date = partial_kwargs.pop("start_date", None)
+        end_date = partial_kwargs.pop("end_date", None)
 
         try:
             operator_name = self.operator_class.custom_operator_name  # type: ignore
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 00e81791246..a67c7cf310b 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
-    from airflow.sdk.defintions.node import DAGNode
+    from airflow.sdk.definitions.node import DAGNode
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
     from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 07ca0190285..940a7f1a066 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -43,7 +43,6 @@ if TYPE_CHECKING:
 
     # from airflow.models.dag import DAG
     from airflow.models.operator import Operator
-    from airflow.models.taskmixin import DAGNode
     from airflow.sdk import DAG, BaseOperator
     from airflow.utils.context import Context
     from airflow.utils.edgemodifier import EdgeModifier
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index b032b45ed3e..dc2e772af0e 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -33,7 +33,7 @@ from collections.abc import Container
 from functools import cache
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Mapping, NamedTuple, Sequence
 
 import lazy_object_proxy
 
@@ -197,7 +197,7 @@ class PythonOperator(BaseOperator):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects(e.g protobuf).
-    shallow_copy_attrs: Sequence[str] = (
+    shallow_copy_attrs: ClassVar[Sequence[str]] = (
         "python_callable",
         "op_kwargs",
     )
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 9aafaf1f54a..e2e21a6686d 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -101,7 +101,7 @@ if TYPE_CHECKING:
     from airflow.models.baseoperatorlink import BaseOperatorLink
     from airflow.models.expandinput import ExpandInput
     from airflow.models.operator import Operator
-    from airflow.sdk.defintions.node import DAGNode
+    from airflow.sdk.definitions.node import DAGNode
     from airflow.serialization.json_schema import Validator
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
     from airflow.timetables.base import Timetable
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 1cdbe4b745d..6b760a112af 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -21,9 +21,7 @@ from __future__ import annotations
 
 import functools
 import operator
-from typing import TYPE_CHECKING, Any, Iterator
-
-import methodtools
+from typing import TYPE_CHECKING, Iterator
 
 import airflow.sdk.definitions.contextmanager
 import airflow.sdk.definitions.taskgroup
@@ -32,7 +30,6 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.dag import DAG
-    from airflow.models.expandinput import ExpandInput
     from airflow.models.operator import Operator
     from airflow.typing_compat import TypeAlias
 
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 6aba5c711f7..7dd1ce02b60 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -20,7 +20,7 @@ import enum
 from typing import TYPE_CHECKING
 
 import airflow.sdk.types
-from airflow.typing_compat import TypedDict, TypeAlias
+from airflow.typing_compat import TypeAlias, TypedDict
 
 if TYPE_CHECKING:
     from datetime import datetime
diff --git a/scripts/ci/pre_commit/base_operator_partial_arguments.py b/scripts/ci/pre_commit/base_operator_partial_arguments.py
index 14999e034ed..b5070533170 100755
--- a/scripts/ci/pre_commit/base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/base_operator_partial_arguments.py
@@ -27,6 +27,7 @@ import typing
 ROOT_DIR = pathlib.Path(__file__).resolve().parents[3]
 
 BASEOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "baseoperator.py")
+SDK_BASEOPERATOR_PY = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "baseoperator.py")
 MAPPEDOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "mappedoperator.py")
 
 IGNORED = {
@@ -51,12 +52,31 @@ IGNORED = {
     # Only on MappedOperator.
     "expand_input",
     "partial_kwargs",
+    "operator_class",
+    # Task-SDK migration ones.
+    "deps",
+    "downstream_task_ids",
+    "on_execute_callback",
+    "on_failure_callback",
+    "on_retry_callback",
+    "on_skipped_callback",
+    "on_success_callback",
+    "operator_extra_links",
+    "start_from_trigger",
+    "start_trigger_args",
+    "upstream_task_ids",
+    "logger_name",
+    "sla",
 }
 
 
 BO_MOD = ast.parse(BASEOPERATOR_PY.read_text("utf-8"), str(BASEOPERATOR_PY))
+SDK_BO_MOD = ast.parse(SDK_BASEOPERATOR_PY.read_text("utf-8"), str(SDK_BASEOPERATOR_PY))
 MO_MOD = ast.parse(MAPPEDOPERATOR_PY.read_text("utf-8"), str(MAPPEDOPERATOR_PY))
 
+# TODO: Task-SDK: Look at the BaseOperator init functions in both airflow.models.baseoperator and combine
+# them, until we fully remove BaseOperator class from core.
+
 BO_CLS = next(
     node
     for node in ast.iter_child_nodes(BO_MOD)
@@ -67,9 +87,27 @@ BO_INIT = next(
     for node in ast.iter_child_nodes(BO_CLS)
     if isinstance(node, ast.FunctionDef) and node.name == "__init__"
 )
-BO_PARTIAL = next(
+
+SDK_BO_CLS = next(
+    node
+    for node in ast.iter_child_nodes(SDK_BO_MOD)
+    if isinstance(node, ast.ClassDef) and node.name == "BaseOperator"
+)
+SDK_BO_INIT = next(
+    node
+    for node in ast.iter_child_nodes(SDK_BO_CLS)
+    if isinstance(node, ast.FunctionDef) and node.name == "__init__"
+)
+
+# We now define the signature in a type checking block, the runtime impl uses 
**kwargs
+BO_TYPE_CHECKING_BLOCKS = (
     node
     for node in ast.iter_child_nodes(BO_MOD)
+    if isinstance(node, ast.If) and isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING"  # non-Name tests (e.g. comparisons) have no .id
+)
+BO_PARTIAL = next(
+    node
+    for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, 
BO_TYPE_CHECKING_BLOCKS))
     if isinstance(node, ast.FunctionDef) and node.name == "partial"
 )
 MO_CLS = next(
@@ -79,23 +117,27 @@ MO_CLS = next(
 )
 
 
-def _compare(a: set[str], b: set[str], *, excludes: set[str]) -> 
tuple[set[str], set[str]]:
-    only_in_a = {n for n in a if n not in b and n not in excludes and n[0] != 
"_"}
-    only_in_b = {n for n in b if n not in a and n not in excludes and n[0] != 
"_"}
+def _compare(a: set[str], b: set[str]) -> tuple[set[str], set[str]]:
+    only_in_a = a - b - IGNORED
+    only_in_b = b - a - IGNORED
     return only_in_a, only_in_b
 
 
-def _iter_arg_names(func: ast.FunctionDef) -> typing.Iterator[str]:
-    func_args = func.args
-    for arg in itertools.chain(func_args.args, getattr(func_args, 
"posonlyargs", ()), func_args.kwonlyargs):
-        yield arg.arg
+def _iter_arg_names(*funcs: ast.FunctionDef) -> typing.Iterator[str]:
+    for func in funcs:
+        func_args = func.args
+        for arg in itertools.chain(
+            func_args.args, getattr(func_args, "posonlyargs", ()), 
func_args.kwonlyargs
+        ):
+            if arg.arg == "self" or arg.arg.startswith("_"):
+                continue
+            yield arg.arg
 
 
 def check_baseoperator_partial_arguments() -> bool:
     only_in_init, only_in_partial = _compare(
-        set(itertools.islice(_iter_arg_names(BO_INIT), 1, None)),
-        set(itertools.islice(_iter_arg_names(BO_PARTIAL), 1, None)),
-        excludes=IGNORED,
+        set(_iter_arg_names(SDK_BO_INIT, BO_INIT)),
+        set(_iter_arg_names(BO_PARTIAL)),
     )
     if only_in_init:
         print("Arguments in BaseOperator missing from partial():", ", 
".join(sorted(only_in_init)))
@@ -109,6 +151,8 @@ def check_baseoperator_partial_arguments() -> bool:
 def _iter_assignment_to_self_attributes(targets: typing.Iterable[ast.expr]) -> 
typing.Iterator[str]:
     for t in targets:
         if isinstance(t, ast.Attribute) and isinstance(t.value, ast.Name) and 
t.value.id == "self":
+            if t.attr.startswith("_"):
+                continue
             yield t.attr  # Something like "self.foo = ...".
         else:
             # Recursively visit nodes in unpacking assignments like "a, b = 
...".
@@ -132,20 +176,24 @@ def _is_property(f: ast.FunctionDef) -> bool:
 
 def _iter_member_names(klass: ast.ClassDef) -> typing.Iterator[str]:
     for node in ast.iter_child_nodes(klass):
+        name = ""
         if isinstance(node, ast.AnnAssign) and isinstance(node.target, 
ast.Name):
-            yield node.target.id
+            name = node.target.id
         elif isinstance(node, ast.FunctionDef) and _is_property(node):
-            yield node.name
+            name = node.name
         elif isinstance(node, ast.Assign):
             if len(node.targets) == 1 and isinstance(target := 
node.targets[0], ast.Name):
-                yield target.id
+                name = target.id
+        else:
+            continue
+        if name and not name.startswith("_"):  # name stays "" for multi-target/non-Name assignments; skip those
+            yield name
 
 
 def check_operator_member_parity() -> bool:
     only_in_base, only_in_mapped = _compare(
-        set(itertools.chain(_iter_assignment_targets(BO_INIT), 
_iter_member_names(BO_CLS))),
+        set(itertools.chain(_iter_assignment_targets(SDK_BO_INIT), 
_iter_member_names(SDK_BO_CLS))),
         set(_iter_member_names(MO_CLS)),
-        excludes=IGNORED,
     )
     if only_in_base:
         print("Members on BaseOperator missing from MappedOperator:", ", 
".join(sorted(only_in_base)))
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index 7c8a4e52472..37ea2d300ab 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -44,6 +44,8 @@ namespace-packages = ["src/airflow"]
 # Ignore Doc rules et al for anything outside of tests
 "!src/*" = ["D", "TID253", "S101", "TRY002"]
 
+"src/airflow/sdk/__init__.py" = ["TCH004"]
+
 [tool.uv]
 dev-dependencies = [
     "kgb>=7.1.1",
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py 
b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index 6f90ae7f118..bb5ddf88e23 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -22,7 +22,6 @@ from abc import abstractmethod
 from collections.abc import (
     Collection,
     Iterable,
-    Mapping,
 )
 from typing import (
     TYPE_CHECKING,
@@ -30,17 +29,14 @@ from typing import (
     ClassVar,
 )
 
-from airflow.sdk.definitions.node import DAGNode
 from airflow.sdk.definitions.mixins import DependencyMixin
-from airflow.utils.log.secrets_masker import redact
+from airflow.sdk.definitions.node import DAGNode
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
 # TaskStateChangeCallback = Callable[[Context], None]
 
 if TYPE_CHECKING:
-    import jinja2  # Slow import.
-
     from airflow.models.baseoperatorlink import BaseOperatorLink
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.baseoperator import BaseOperator
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py 
b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 57a85988987..8349876cb8f 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -719,7 +719,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         if kwargs:
             raise TypeError(
                 f"Invalid arguments were passed to {self.__class__.__name__} 
(task_id: {task_id}). "
-                + f"Invalid arguments were:\n**kwargs: {kwargs}",
+                f"Invalid arguments were:\n**kwargs: {kwargs}",
             )
         validate_key(task_id)
 
@@ -762,6 +762,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         # self.retries = parse_retries(retries)
         self.retries = retries
         self.queue = queue
+        # TODO: Task-SDK: pull this default name from Pool constant?
         self.pool = "default" if pool is None else pool
         self.pool_slots = pool_slots
         if self.pool_slots < 1:
@@ -1070,8 +1071,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     def _set_xcomargs_dependencies(self) -> None:
         from airflow.models.xcom_arg import XComArg
 
-        for field in self.template_fields:
-            arg = getattr(self, field, NOTSET)
+        for f in self.template_fields:
+            arg = getattr(self, f, NOTSET)
             if arg is not NOTSET:
                 XComArg.apply_upstream_relationship(self, arg)
 
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py 
b/task_sdk/src/airflow/sdk/definitions/dag.py
index 21579e87356..ba82f19339e 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -46,7 +46,6 @@ from dateutil.relativedelta import relativedelta
 
 from airflow import settings
 from airflow.assets import Asset, AssetAlias, BaseAsset
-from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import (
     DuplicateTaskIdFound,
     FailStopDagInvalidTriggerRule,
@@ -119,29 +118,6 @@ _DAG_HASH_ATTRS = frozenset(
     }
 )
 
-# TODO: The following mapping is used to validate that the arguments passed to 
the DAG are of the correct
-#  type. This is a temporary solution until we find a more sophisticated 
method for argument validation.
-#  One potential method is to use `get_type_hints` from the typing module. 
However, this is not fully
-#  compatible with future annotations for Python versions below 3.10. Once we 
require a minimum Python
-#  version that supports `get_type_hints` effectively or find a better 
approach, we can replace this
-#  manual type-checking method.
-DAG_ARGS_EXPECTED_TYPES = {
-    "dag_id": str,
-    "description": str,
-    "max_active_tasks": int,
-    "max_active_runs": int,
-    "max_consecutive_failed_dag_runs": int,
-    "dagrun_timeout": timedelta,
-    "catchup": bool,
-    "doc_md": str,
-    "is_paused_upon_creation": bool,
-    "render_template_as_native_obj": bool,
-    "tags": Collection,
-    "auto_register": bool,
-    "fail_stop": bool,
-    "dag_display_name": str,
-}
-
 
 def _create_timetable(interval: ScheduleInterval, timezone: Timezone | 
FixedTimezone) -> Timetable:
     """Create a Timetable instance from a plain ``schedule`` value."""
@@ -167,7 +143,7 @@ def _create_timetable(interval: ScheduleInterval, timezone: 
Timezone | FixedTime
 
 def _convert_params(val: abc.MutableMapping | None, self_: DAG) -> ParamsDict:
     """
-    Convert the plain dict into a ParamsDict
+    Convert the plain dict into a ParamsDict.
 
     This will also merge in params from default_args
     """
@@ -336,8 +312,11 @@ class DAG:
 
     # NOTE: When updating arguments here, please also keep arguments in @dag()
     # below in sync. (Search for 'def dag(' in this file.)
-    dag_id: str = attrs.field(kw_only=False)
-    description: str | None = None
+    dag_id: str = attrs.field(kw_only=False, 
validator=attrs.validators.instance_of(str))
+    description: str | None = attrs.field(
+        default=None,
+        validator=attrs.validators.optional(attrs.validators.instance_of(str)),
+    )
     default_args: dict[str, Any] = attrs.field(
         factory=dict, validator=attrs.validators.instance_of(dict), 
converter=dict_copy
     )
@@ -355,12 +334,17 @@ class DAG:
     user_defined_macros: dict | None = None
     user_defined_filters: dict | None = None
     concurrency: int | None = None
-    max_active_tasks: int = 16
-    max_active_runs: int = 16
-    max_consecutive_failed_dag_runs: int = -1
-    dagrun_timeout: timedelta | None = None
+    max_active_tasks: int = attrs.field(default=16, 
validator=attrs.validators.instance_of(int))
+    max_active_runs: int = attrs.field(default=16, 
validator=attrs.validators.instance_of(int))
+    max_consecutive_failed_dag_runs: int = attrs.field(
+        default=-1, validator=attrs.validators.instance_of(int)
+    )
+    dagrun_timeout: timedelta | None = attrs.field(
+        default=None,
+        
validator=attrs.validators.optional(attrs.validators.instance_of(timedelta)),
+    )
     # sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None
-    catchup: bool = attrs.field(default=True)
+    catchup: bool = attrs.field(default=True, converter=bool)
     # on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
     # on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
     doc_md: str | None = None
@@ -372,12 +356,12 @@ class DAG:
     access_control: dict | None = None
     is_paused_upon_creation: bool | None = None
     jinja_environment_kwargs: dict | None = None
-    render_template_as_native_obj: bool = False
+    render_template_as_native_obj: bool = attrs.field(default=False, 
converter=bool)
     tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags)
     owner_links: dict[str, str] = attrs.field(factory=dict)
-    auto_register: bool = True
-    fail_stop: bool = False
-    dag_display_name: str = attrs.field()
+    auto_register: bool = attrs.field(default=True, converter=bool)
+    fail_stop: bool = attrs.field(default=False, converter=bool)  # default stays False, as before the attrs conversion
+    dag_display_name: str = 
attrs.field(validator=attrs.validators.instance_of(str))
 
     task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
 
@@ -453,7 +437,7 @@ class DAG:
 
         from airflow.utils import timezone
 
-        # TODO: Task-SDK: get default dag tz from settins
+        # TODO: Task-SDK: get default dag tz from settings
         tz = timezone.utc
         if self.start_date and (tzinfo := self.start_date.tzinfo):
             tzinfo = None if tzinfo else tz
@@ -484,6 +468,11 @@ class DAG:
         if requires_automatic_backfilling and not ("start_date" in 
self.default_args or self.start_date):
             raise ValueError("start_date is required when catchup=True")
 
+    @tags.validator
+    def _validate_tags(self, _, tags: Collection[str]):
+        if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
+            raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} 
characters")
+
     def __repr__(self):
         return f"<DAG: {self.dag_id}>"
 
@@ -521,7 +510,7 @@ class DAG:
         return self
 
     def __exit__(self, _type, _value, _tb):
-        from .contextmanager import DagContext
+        from airflow.sdk.definitions.contextmanager import DagContext
 
         _ = DagContext.pop()
 
@@ -692,7 +681,7 @@ class DAG:
         result.user_defined_macros = self.user_defined_macros
         result.user_defined_filters = self.user_defined_filters
         if hasattr(self, "_log"):
-            result._log = self._log
+            result._log = self._log  # type: ignore[attr-defined]
         return result
 
     def partial_subset(
@@ -1000,14 +989,12 @@ if TYPE_CHECKING:
         user_defined_macros: dict | None = None,
         user_defined_filters: dict | None = None,
         default_args: dict | None = None,
-        max_active_tasks: int = airflow_conf.getint("core", 
"max_active_tasks_per_dag"),
-        max_active_runs: int = airflow_conf.getint("core", 
"max_active_runs_per_dag"),
-        max_consecutive_failed_dag_runs: int = airflow_conf.getint(
-            "core", "max_consecutive_failed_dag_runs_per_dag"
-        ),
+        max_active_tasks: int = ...,
+        max_active_runs: int = ...,
+        max_consecutive_failed_dag_runs: int = ...,
         dagrun_timeout: timedelta | None = None,
         sla_miss_callback: Any = None,
-        catchup: bool = airflow_conf.getboolean("scheduler", 
"catchup_by_default"),
+        catchup: bool = ...,
         on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
         on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
         doc_md: str | None = None,
diff --git a/task_sdk/src/airflow/sdk/definitions/edges.py 
b/task_sdk/src/airflow/sdk/definitions/edges.py
index 7c62c1bda80..7e50431b497 100644
--- a/task_sdk/src/airflow/sdk/definitions/edges.py
+++ b/task_sdk/src/airflow/sdk/definitions/edges.py
@@ -19,10 +19,10 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from .mixins import DependencyMixin
+from airflow.sdk.definitions.mixins import DependencyMixin
 
 if TYPE_CHECKING:
-    from .dag import DAG
+    from airflow.sdk.definitions.dag import DAG
 
 
 class EdgeModifier(DependencyMixin):
diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py 
b/task_sdk/src/airflow/sdk/definitions/mixins.py
index 73d455d6c92..e9d6e162927 100644
--- a/task_sdk/src/airflow/sdk/definitions/mixins.py
+++ b/task_sdk/src/airflow/sdk/definitions/mixins.py
@@ -21,11 +21,11 @@ from abc import abstractmethod
 from collections.abc import Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
-from ..types import NOTSET, ArgNotSet
+from airflow.sdk.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
-    from .baseoperator import BaseOperator
-    from .edges import EdgeModifier
+    from airflow.sdk.definitions.baseoperator import BaseOperator
+    from airflow.sdk.definitions.edges import EdgeModifier
 
 # TODO: Should this all just live on DAGNode?
 
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py 
b/task_sdk/src/airflow/sdk/definitions/node.py
index ddc560b8bcd..7b877bbaf48 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -53,7 +53,7 @@ def validate_key(k: str, max_length: int = 250):
     if not KEY_REGEX.match(k):
         raise ValueError(
             f"The key {k!r} has to be made of alphanumeric characters, dashes, 
"
-            + "dots and underscores exclusively"
+            "dots and underscores exclusively"
         )
 
 
@@ -152,7 +152,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
         else:
             raise ValueError(
                 "Tried to create relationships between tasks that don't have 
DAGs yet. "
-                + f"Set the DAG for at least one task and try again: {[self, 
*task_list]}"
+                f"Set the DAG for at least one task and try again: {[self, 
*task_list]}"
             )
 
         if not self.has_dag():
diff --git a/task_sdk/src/airflow/sdk/types.py 
b/task_sdk/src/airflow/sdk/types.py
index ffde2170b17..232d08e27f9 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     Logger = logging.Logger
 else:
 
-    class Logger: ...
+    class Logger: ...  # noqa: D101
 
 
 def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, 
Any]) -> None:
diff --git a/task_sdk/tests/defintions/test_baseoperator.py 
b/task_sdk/tests/defintions/test_baseoperator.py
index 23ede320c87..427d1ee0e3e 100644
--- a/task_sdk/tests/defintions/test_baseoperator.py
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -222,7 +222,7 @@ class TestBaseOperator:
 
         spy_agency.spy_on(XComArg.apply_upstream_relationship, 
call_original=False)
         op_copy.arg1 = "b"
-        assert XComArg.apply_upstream_relationship.called == False
+        assert XComArg.apply_upstream_relationship.called is False
 
     def test_upstream_is_set_when_template_field_is_xcomarg(self):
         with DAG("xcomargs_test", schedule=None):
@@ -273,7 +273,7 @@ class TestBaseOperator:
 
 def test_init_subclass_args():
     class InitSubclassOp(BaseOperator):
-        class_arg: Any
+        class_arg = None
 
         def __init_subclass__(cls, class_arg=None, **kwargs) -> None:
             cls.class_arg = class_arg
diff --git a/task_sdk/tests/defintions/test_dag.py 
b/task_sdk/tests/defintions/test_dag.py
index 2300e97f07e..c3b3bbdce4d 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import weakref
 from datetime import datetime, timedelta, timezone
+from typing import Any
 
 import pytest
 
@@ -248,6 +249,73 @@ class TestDag:
         assert task.task_group.upstream_group_ids is not 
copied_task.task_group.upstream_group_ids
 
 
+# Test some of the arg validation. This is not all the validations we perform, 
just some of them.
[email protected](
+    ["attr", "value"],
+    [
+        pytest.param("max_consecutive_failed_dag_runs", "not_an_int", 
id="max_consecutive_failed_dag_runs"),
+        pytest.param("dagrun_timeout", "not_an_int", id="dagrun_timeout"),
+    ],
+)
+def test_invalid_type_for_args(attr: str, value: Any):
+    with pytest.raises(TypeError):
+        DAG("invalid-default-args", **{attr: value})
+
+
[email protected](
+    "tags, should_pass",
+    [
+        pytest.param([], True, id="empty tags"),
+        pytest.param(["a normal tag"], True, id="one tag"),
+        pytest.param(["a normal tag", "another normal tag"], True, id="two 
tags"),
+        pytest.param(["a" * 100], True, id="a tag that's of just length 100"),
+        pytest.param(["a normal tag", "a" * 101], False, id="two tags and one 
of them is of length > 100"),
+    ],
+)
+def test__tags_length(tags: list[str], should_pass: bool):
+    if should_pass:
+        DAG("test-dag", schedule=None, tags=tags)
+    else:
+        with pytest.raises(ValueError):
+            DAG("test-dag", schedule=None, tags=tags)
+
+
[email protected](
+    "input_tags, expected_result",
+    [
+        pytest.param([], set(), id="empty tags"),
+        pytest.param(
+            ["a normal tag"],
+            {"a normal tag"},
+            id="one tag",
+        ),
+        pytest.param(
+            ["a normal tag", "another normal tag"],
+            {"a normal tag", "another normal tag"},
+            id="two different tags",
+        ),
+        pytest.param(
+            ["a", "a"],
+            {"a"},
+            id="two same tags",
+        ),
+    ],
+)
+def test__tags_duplicates(input_tags: list[str], expected_result: set[str]):
+    result = DAG("test-dag", tags=input_tags)
+    assert result.tags == expected_result
+
+
+def test__tags_mutable():
+    expected_tags = {"6", "7"}
+    test_dag = DAG("test-dag")
+    test_dag.tags.add("6")
+    test_dag.tags.add("7")
+    test_dag.tags.add("8")
+    test_dag.tags.remove("8")
+    assert test_dag.tags == expected_tags
+
+
 class TestDagDecorator:
     DEFAULT_ARGS = {
         "owner": "test",
diff --git a/tests/models/test_baseoperator.py 
b/tests/models/test_baseoperator.py
index 0af6f69df01..1d8e4f4d32a 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -29,7 +29,7 @@ import jinja2
 import pytest
 
 from airflow.decorators import task as task_decorator
-from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule
+from airflow.exceptions import AirflowException
 from airflow.lineage.entities import File
 from airflow.models.baseoperator import (
     BaseOperator,
@@ -47,9 +47,8 @@ from airflow.utils.template import literal
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 
-from tests_common.test_utils.mock_operators import MockOperator
-
 from tests.models import DEFAULT_DATE
+from tests_common.test_utils.mock_operators import MockOperator
 
 
 class ClassWithCustomAttributes:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index be8dde52331..43029fbab98 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1496,10 +1496,10 @@ class TestDag:
         fail_stop_dag.add_task(task_with_default_trigger_rule)
 
         # a fail stop dag should not allow a non-default trigger rule
+        task_with_non_default_trigger_rule = EmptyOperator(
+            task_id="task_with_non_default_trigger_rule", 
trigger_rule=TriggerRule.ALWAYS
+        )
         with pytest.raises(FailStopDagInvalidTriggerRule):
-            task_with_non_default_trigger_rule = EmptyOperator(
-                task_id="task_with_non_default_trigger_rule", 
trigger_rule=TriggerRule.ALWAYS
-            )
             fail_stop_dag.add_task(task_with_non_default_trigger_rule)
 
     def test_dag_add_task_sets_default_task_group(self):
@@ -3033,60 +3033,6 @@ def test__time_restriction(dag_maker, dag_date, 
tasks_date, restrict):
     assert dag._time_restriction == restrict
 
 
[email protected](
-    "tags, should_pass",
-    [
-        pytest.param([], True, id="empty tags"),
-        pytest.param(["a normal tag"], True, id="one tag"),
-        pytest.param(["a normal tag", "another normal tag"], True, id="two 
tags"),
-        pytest.param(["a" * 100], True, id="a tag that's of just length 100"),
-        pytest.param(["a normal tag", "a" * 101], False, id="two tags and one 
of them is of length > 100"),
-    ],
-)
-def test__tags_length(tags: list[str], should_pass: bool):
-    if should_pass:
-        DAG("test-dag", schedule=None, tags=tags)
-    else:
-        with pytest.raises(AirflowException):
-            DAG("test-dag", schedule=None, tags=tags)
-
-
[email protected](
-    "input_tags, expected_result",
-    [
-        pytest.param([], set(), id="empty tags"),
-        pytest.param(
-            ["a normal tag"],
-            {"a normal tag"},
-            id="one tag",
-        ),
-        pytest.param(
-            ["a normal tag", "another normal tag"],
-            {"a normal tag", "another normal tag"},
-            id="two different tags",
-        ),
-        pytest.param(
-            ["a", "a"],
-            {"a"},
-            id="two same tags",
-        ),
-    ],
-)
-def test__tags_duplicates(input_tags: list[str], expected_result: set[str]):
-    result = DAG("test-dag", tags=input_tags)
-    assert result.tags == expected_result
-
-
-def test__tags_mutable():
-    expected_tags = {"6", "7"}
-    test_dag = DAG("test-dag")
-    test_dag.tags.add("6")
-    test_dag.tags.add("7")
-    test_dag.tags.add("8")
-    test_dag.tags.remove("8")
-    assert test_dag.tags == expected_tags
-
-
 @pytest.mark.need_serialized_dag
 def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
     asset1 = Asset(uri="ds1")
@@ -3195,11 +3141,6 @@ def 
test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR
     )
 
 
-def test_invalid_type_for_args():
-    with pytest.raises(TypeError):
-        DAG("invalid-default-args", schedule=None, 
max_consecutive_failed_dag_runs="not_an_int")
-
-
 class TestTaskClearingSetupTeardownBehavior:
     """
     Task clearing behavior is mainly controlled by dag.partial_subset.
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 43abcf9012e..85c225cc9b6 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 from datetime import timedelta
-from unittest import mock
 
 import pendulum
 import pytest
diff --git a/tests_common/test_utils/mock_operators.py 
b/tests_common/test_utils/mock_operators.py
index ef941abb13f..6a88f41c22a 100644
--- a/tests_common/test_utils/mock_operators.py
+++ b/tests_common/test_utils/mock_operators.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any, Sequence
 
 import attr
@@ -66,7 +67,7 @@ class MockOperatorWithNestedFields(BaseOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Context,
+        context: Mapping[str, Any],
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:

Reply via email to