This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c160ab70a00 Introduce serialized task groups; use them in core (#55169)
c160ab70a00 is described below

commit c160ab70a00faa41ef5996c4ba4a6580a64194ad
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Sep 4 09:42:51 2025 +0800

    Introduce serialized task groups; use them in core (#55169)
---
 .pre-commit-config.yaml                            |   2 +-
 .../api_fastapi/core_api/services/ui/grid.py       |   8 +-
 .../api_fastapi/core_api/services/ui/task_group.py |  25 +-
 .../airflow/example_dags/example_setup_teardown.py |   3 +-
 .../src/airflow/example_dags/example_task_group.py |   3 +-
 airflow-core/src/airflow/models/dagrun.py          |   4 +-
 airflow-core/src/airflow/models/mappedoperator.py  |  31 ++-
 airflow-core/src/airflow/models/taskinstance.py    |  22 +-
 .../airflow/serialization/definitions/__init__.py  |  17 ++
 .../airflow/serialization/definitions/taskgroup.py | 284 +++++++++++++++++++++
 .../airflow/serialization/serialized_objects.py    | 200 ++++++++++++---
 .../ti_deps/deps/mapped_task_upstream_dep.py       |   6 +-
 .../src/airflow/ti_deps/deps/trigger_rule_dep.py   |  21 +-
 airflow-core/src/airflow/utils/dot_renderer.py     |  15 +-
 airflow-core/tests/unit/models/test_dagrun.py      | 161 ++++--------
 .../tests/unit/models/test_taskinstance.py         |  57 ++---
 .../unit/serialization/test_dag_serialization.py   |   4 +-
 .../unit/ti_deps/deps/test_trigger_rule_dep.py     |  17 +-
 airflow-core/tests/unit/utils/test_task_group.py   |  53 ++--
 devel-common/src/tests_common/pytest_plugin.py     |  28 +-
 task-sdk/src/airflow/sdk/definitions/dag.py        |  14 +-
 task-sdk/src/airflow/sdk/definitions/taskgroup.py  |  38 ++-
 22 files changed, 692 insertions(+), 321 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3c454e6fb8c..623ab537ab9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1662,7 +1662,7 @@ repos:
           ^airflow-core/src/airflow/operators/subdag\.py$|
           ^airflow-core/src/airflow/plugins_manager\.py$|
           ^airflow-core/src/airflow/providers_manager\.py$|
-          ^airflow-core/src/airflow/serialization/dag\.py$|
+          ^airflow-core/src/airflow/serialization/definitions/[_a-z]+\.py$|
           ^airflow-core/src/airflow/serialization/enums\.py$|
           ^airflow-core/src/airflow/serialization/helpers\.py$|
           ^airflow-core/src/airflow/serialization/serialized_objects\.py$|
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py 
b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 124f526cd05..1f64ffcefa8 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -26,7 +26,7 @@ from airflow.api_fastapi.common.parameters import 
state_priority
 from airflow.api_fastapi.core_api.services.ui.task_group import 
get_task_group_children_getter
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskmap import TaskMap
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 
 log = structlog.get_logger(logger_name=__name__)
@@ -78,8 +78,8 @@ def _get_aggs_for_node(detail):
 
 
 def _find_aggregates(
-    node: TaskGroup | MappedTaskGroup | SerializedBaseOperator | TaskMap,
-    parent_node: TaskGroup | MappedTaskGroup | SerializedBaseOperator | 
TaskMap | None,
+    node: SerializedTaskGroup | SerializedBaseOperator | TaskMap,
+    parent_node: SerializedTaskGroup | SerializedBaseOperator | TaskMap | None,
     ti_details: dict[str, list],
 ) -> Iterable[dict]:
     """Recursively fill the Task Group Map."""
@@ -98,7 +98,7 @@ def _find_aggregates(
         }
 
         return
-    if isinstance(node, TaskGroup):
+    if isinstance(node, SerializedTaskGroup):
         children = []
         for child in get_task_group_children_getter()(node):
             for child_node in _find_aggregates(node=child, parent_node=node, 
ti_details=ti_details):
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py 
b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
index f88dca353c4..ed9a96718e9 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
@@ -24,8 +24,7 @@ from functools import cache
 from operator import methodcaller
 
 from airflow.configuration import conf
-from airflow.models.mappedoperator import MappedOperator
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+from airflow.models.mappedoperator import MappedOperator, is_mapped
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 
 
@@ -51,14 +50,14 @@ def task_group_to_dict(task_item_or_group, 
parent_group_is_mapped=False):
             node_operator["setup_teardown_type"] = "setup"
         elif task.is_teardown:
             node_operator["setup_teardown_type"] = "teardown"
-        if isinstance(task, MappedOperator) or parent_group_is_mapped:
+        if is_mapped(task) or parent_group_is_mapped:
             node_operator["is_mapped"] = True
         return node_operator
 
     task_group = task_item_or_group
-    is_mapped = isinstance(task_group, MappedTaskGroup)
+    mapped = is_mapped(task_group)
     children = [
-        task_group_to_dict(child, 
parent_group_is_mapped=parent_group_is_mapped or is_mapped)
+        task_group_to_dict(child, 
parent_group_is_mapped=parent_group_is_mapped or mapped)
         for child in get_task_group_children_getter()(task_group)
     ]
 
@@ -74,7 +73,7 @@ def task_group_to_dict(task_item_or_group, 
parent_group_is_mapped=False):
         "id": task_group.group_id,
         "label": task_group.label,
         "tooltip": task_group.tooltip,
-        "is_mapped": is_mapped,
+        "is_mapped": mapped,
         "children": children,
         "type": "task",
     }
@@ -83,9 +82,9 @@ def task_group_to_dict(task_item_or_group, 
parent_group_is_mapped=False):
 def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
     """Create a nested dict representation of this TaskGroup and its children 
used to construct the Grid."""
     if isinstance(task := task_item_or_group, (MappedOperator, 
SerializedBaseOperator)):
-        is_mapped = None
-        if task.is_mapped or parent_group_is_mapped:
-            is_mapped = True
+        mapped = None
+        if parent_group_is_mapped or is_mapped(task):
+            mapped = True
         setup_teardown_type = None
         if task.is_setup is True:
             setup_teardown_type = "setup"
@@ -94,22 +93,22 @@ def task_group_to_dict_grid(task_item_or_group, 
parent_group_is_mapped=False):
         return {
             "id": task.task_id,
             "label": task.label,
-            "is_mapped": is_mapped,
+            "is_mapped": mapped,
             "children": None,
             "setup_teardown_type": setup_teardown_type,
         }
 
     task_group = task_item_or_group
     task_group_sort = get_task_group_children_getter()
-    is_mapped_group = isinstance(task_group, MappedTaskGroup)
+    mapped = is_mapped(task_group)
     children = [
-        task_group_to_dict_grid(x, 
parent_group_is_mapped=parent_group_is_mapped or is_mapped_group)
+        task_group_to_dict_grid(x, 
parent_group_is_mapped=parent_group_is_mapped or mapped)
         for x in task_group_sort(task_group)
     ]
 
     return {
         "id": task_group.group_id,
         "label": task_group.label,
-        "is_mapped": is_mapped_group or None,
+        "is_mapped": mapped or None,
         "children": children or None,
     }
diff --git a/airflow-core/src/airflow/example_dags/example_setup_teardown.py 
b/airflow-core/src/airflow/example_dags/example_setup_teardown.py
index 052377736ea..cefa3b31463 100644
--- a/airflow-core/src/airflow/example_dags/example_setup_teardown.py
+++ b/airflow-core/src/airflow/example_dags/example_setup_teardown.py
@@ -22,8 +22,7 @@ from __future__ import annotations
 import pendulum
 
 from airflow.providers.standard.operators.bash import BashOperator
-from airflow.sdk import DAG
-from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk import DAG, TaskGroup
 
 with DAG(
     dag_id="example_setup_teardown",
diff --git a/airflow-core/src/airflow/example_dags/example_task_group.py 
b/airflow-core/src/airflow/example_dags/example_task_group.py
index c882c269c47..39010441d86 100644
--- a/airflow-core/src/airflow/example_dags/example_task_group.py
+++ b/airflow-core/src/airflow/example_dags/example_task_group.py
@@ -23,8 +23,7 @@ import pendulum
 
 from airflow.providers.standard.operators.bash import BashOperator
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import DAG
-from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk import DAG, TaskGroup
 
 # [START howto_task_group]
 with DAG(
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index c8749685824..26d1d42e524 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1676,9 +1676,7 @@ class DagRun(Base, LoggingMixin):
 
         # Create the missing tasks, including mapped tasks
         tis_to_create = self._create_tasks(
-            # TODO (GH-52141): task_dict in scheduler should contain scheduler
-            # types instead, but currently it inherits SDK's DAG.
-            (task for task in cast("Iterable[Operator]", 
dag.task_dict.values()) if task_filter(task)),
+            (task for task in dag.task_dict.values() if task_filter(task)),
             task_creator,
             session=session,
         )
diff --git a/airflow-core/src/airflow/models/mappedoperator.py 
b/airflow-core/src/airflow/models/mappedoperator.py
index 9a4a66a9fe8..310573985da 100644
--- a/airflow-core/src/airflow/models/mappedoperator.py
+++ b/airflow-core/src/airflow/models/mappedoperator.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import functools
 import operator
-from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard
+from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, overload
 
 import attrs
 import methodtools
@@ -31,7 +31,7 @@ from airflow.exceptions import AirflowException, NotMapped
 from airflow.sdk import BaseOperator as TaskSDKBaseOperator
 from airflow.sdk.definitions._internal.node import DAGNode
 from airflow.sdk.definitions.mappedoperator import MappedOperator as 
TaskSDKMappedOperator
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+from airflow.serialization.definitions.taskgroup import 
SerializedMappedTaskGroup, SerializedTaskGroup
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, 
SerializedBaseOperator
 from airflow.task.priority_strategy import PriorityWeightStrategy, 
validate_and_load_priority_weight_strategy
@@ -57,8 +57,16 @@ if TYPE_CHECKING:
 log = structlog.get_logger(__name__)
 
 
-def is_mapped(task: Operator) -> TypeGuard[MappedOperator]:
-    return task.is_mapped
+@overload
+def is_mapped(obj: Operator) -> TypeGuard[MappedOperator]: ...
+
+
+@overload
+def is_mapped(obj: SerializedTaskGroup) -> 
TypeGuard[SerializedMappedTaskGroup]: ...
+
+
+def is_mapped(obj: Operator | SerializedTaskGroup) -> TypeGuard[MappedOperator 
| SerializedMappedTaskGroup]:
+    return obj.is_mapped
 
 
 @attrs.define(
@@ -100,8 +108,11 @@ class MappedOperator(DAGNode):
     start_from_trigger: bool = False
     _needs_expansion: bool = True
 
-    dag: SerializedDAG = attrs.field(init=False)
-    task_group: TaskGroup = attrs.field(init=False)
+    # TODO (GH-52141): These should contain serialized containers, but 
currently
+    # this class inherits from an SDK one.
+    dag: SerializedDAG = attrs.field(init=False)  # type: ignore[assignment]
+    task_group: SerializedTaskGroup = attrs.field(init=False)  # type: 
ignore[assignment]
+
     start_date: pendulum.DateTime | None = attrs.field(init=False, 
default=None)
     end_date: pendulum.DateTime | None = attrs.field(init=False, default=None)
     upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
@@ -388,7 +399,7 @@ class MappedOperator(DAGNode):
         return getattr(self, self._expand_input_attr)
 
     # TODO (GH-52141): Copied from sdk. Find a better place for this to live 
in.
-    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
+    def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
         """
         Return mapped task groups this task belongs to.
 
@@ -401,7 +412,7 @@ class MappedOperator(DAGNode):
         yield from group.iter_mapped_task_groups()
 
     # TODO (GH-52141): Copied from sdk. Find a better place for this to live 
in.
-    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
+    def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup | 
None:
         """
         Get the mapped task group "closest" to this task in the DAG.
 
@@ -504,7 +515,7 @@ def _(task: MappedOperator | TaskSDKMappedOperator, run_id: 
str, *, session: Ses
 
 
 @get_mapped_ti_count.register
-def _(group: TaskGroup, run_id: str, *, session: Session) -> int:
+def _(group: SerializedTaskGroup, run_id: str, *, session: Session) -> int:
     """
     Return the number of instances a task in this group should be mapped to at 
run time.
 
@@ -523,7 +534,7 @@ def _(group: TaskGroup, run_id: str, *, session: Session) 
-> int:
 
     def iter_mapped_task_group_lengths(group) -> Iterator[int]:
         while group is not None:
-            if isinstance(group, MappedTaskGroup):
+            if isinstance(group, SerializedMappedTaskGroup):
                 exp_input = group._expand_input
                 # TODO (GH-52141): 'group' here should be scheduler-bound and 
returns scheduler expand input.
                 if not hasattr(exp_input, "get_total_map_length"):
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index 682787141d2..05cd9ce357e 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -124,8 +124,8 @@ if TYPE_CHECKING:
     from airflow.sdk import DAG
     from airflow.sdk.api.datamodels._generated import AssetProfile
     from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, 
AssetUriRef
-    from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
     from airflow.sdk.types import RuntimeTaskInstanceProtocol
+    from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
     from airflow.serialization.serialized_objects import SerializedBaseOperator
     from airflow.utils.context import Context
 
@@ -1534,12 +1534,9 @@ class TaskInstance(Base, LoggingMixin):
             assert original_task is not None
             assert original_task.dag is not None
 
-        serialized_task = SerializedDAG.deserialize_dag(
-            SerializedDAG.serialize_dag(original_task.dag)
-        ).task_dict[original_task.task_id]
-        # TODO (GH-52141): task_dict in scheduler should contain scheduler
-        # types instead, but currently it inherits SDK's DAG.
-        self.task = cast("Operator", serialized_task)
+        self.task = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag)).task_dict[
+            original_task.task_id
+        ]
         res = self.check_and_change_state_before_execution(
             verbose=verbose,
             ignore_all_deps=ignore_all_deps,
@@ -2286,7 +2283,7 @@ class TaskInstance(Base, LoggingMixin):
         )
 
 
-def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> 
MappedTaskGroup | None:
+def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> 
SerializedTaskGroup | None:
     """Given two operators, find their innermost common mapped task group."""
     if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
         return None
@@ -2295,16 +2292,15 @@ def _find_common_ancestor_mapped_group(node1: Operator, 
node2: Operator) -> Mapp
     return next(common_groups, None)
 
 
-def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> 
bool:
+def _is_further_mapped_inside(operator: Operator, container: 
SerializedTaskGroup) -> bool:
     """Whether given operator is *further* mapped inside a task group."""
-    from airflow.models.mappedoperator import MappedOperator
-    from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+    from airflow.models.mappedoperator import is_mapped
 
-    if isinstance(operator, MappedOperator):
+    if is_mapped(operator):
         return True
     task_group = operator.task_group
     while task_group is not None and task_group.group_id != container.group_id:
-        if isinstance(task_group, MappedTaskGroup):
+        if is_mapped(task_group):
             return True
         task_group = task_group.parent_group
     return False
diff --git a/airflow-core/src/airflow/serialization/definitions/__init__.py 
b/airflow-core/src/airflow/serialization/definitions/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/definitions/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py 
b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
new file mode 100644
index 00000000000..e26c6cfb4ae
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -0,0 +1,284 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import copy
+import functools
+import operator
+import weakref
+from typing import TYPE_CHECKING
+
+import attrs
+import methodtools
+
+from airflow.sdk.definitions._internal.node import DAGNode
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterator
+    from typing import Any, ClassVar
+
+    from airflow.models.expandinput import SchedulerExpandInput
+    from airflow.serialization.serialized_objects import SerializedDAG, 
SerializedOperator
+
+
[email protected](kw_only=True, repr=False)
+class SerializedTaskGroup(DAGNode):
+    """Serialized representation of a TaskGroup used in protected processes."""
+
+    _group_id: str | None = attrs.field(alias="group_id")
+    group_display_name: str | None = attrs.field()
+    prefix_group_id: bool = attrs.field()
+    parent_group: SerializedTaskGroup | None = attrs.field()
+    dag: SerializedDAG = attrs.field()
+    tooltip: str = attrs.field()
+    default_args: dict[str, Any] = attrs.field(factory=dict)
+
+    # TODO: Are these actually useful?
+    ui_color: str = attrs.field(default="CornflowerBlue")
+    ui_fgcolor: str = attrs.field(default="#000")
+
+    children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
+    upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
+    downstream_group_ids: set[str | None] = attrs.field(factory=set, 
init=False)
+    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
+    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)
+
+    is_mapped: ClassVar[bool] = False
+
+    @staticmethod
+    def _iter_child(child):
+        """Iterate over the children of this TaskGroup."""
+        if isinstance(child, SerializedTaskGroup):
+            yield from child
+        else:
+            yield child
+
+    def __iter__(self):
+        for child in self.children.values():
+            yield from self._iter_child(child)
+
+    @property
+    def group_id(self) -> str | None:
+        if (
+            self._group_id
+            and self.parent_group
+            and self.parent_group.prefix_group_id
+            and self.parent_group._group_id
+        ):
+            return self.parent_group.child_id(self._group_id)
+        return self._group_id
+
+    @property
+    def label(self) -> str:
+        """group_id excluding parent's group_id used as the node label in 
UI."""
+        return self.group_display_name or self._group_id or ""
+
+    @property
+    def node_id(self) -> str:
+        return self.group_id or ""
+
+    @property
+    def is_root(self) -> bool:
+        return not self._group_id
+
+    # TODO (GH-52141): This shouldn't need to be writable after serialization,
+    # but DAGNode defines the property as writable.
+    @property
+    def task_group(self) -> SerializedTaskGroup | None:  # type: 
ignore[override]
+        return self.parent_group
+
+    def child_id(self, label: str) -> str:
+        if self.prefix_group_id and (group_id := self.group_id):
+            return f"{group_id}.{label}"
+        return label
+
+    @property
+    def upstream_join_id(self) -> str:
+        return f"{self.group_id}.upstream_join_id"
+
+    @property
+    def downstream_join_id(self) -> str:
+        return f"{self.group_id}.downstream_join_id"
+
+    @property
+    def roots(self) -> list[DAGNode]:
+        return list(self.get_roots())
+
+    @property
+    def leaves(self) -> list[DAGNode]:
+        return list(self.get_leaves())
+
+    def get_roots(self) -> Generator[SerializedOperator, None, None]:
+        """Return a generator of tasks with no upstream dependencies within 
the TaskGroup."""
+        tasks = list(self)
+        ids = {x.task_id for x in tasks}
+        for task in tasks:
+            if task.upstream_task_ids.isdisjoint(ids):
+                yield task
+
+    def get_leaves(self) -> Generator[SerializedOperator, None, None]:
+        """Return a generator of tasks with no downstream dependencies within 
the TaskGroup."""
+        tasks = list(self)
+        ids = {x.task_id for x in tasks}
+
+        def has_non_teardown_downstream(task, exclude: str):
+            for down_task in task.downstream_list:
+                if down_task.task_id == exclude:
+                    continue
+                if down_task.task_id not in ids:
+                    continue
+                if not down_task.is_teardown:
+                    return True
+            return False
+
+        def recurse_for_first_non_teardown(task):
+            for upstream_task in task.upstream_list:
+                if upstream_task.task_id not in ids:
+                    # upstream task is not in task group
+                    continue
+                elif upstream_task.is_teardown:
+                    yield from recurse_for_first_non_teardown(upstream_task)
+                elif task.is_teardown and upstream_task.is_setup:
+                    # don't go through the teardown-to-setup path
+                    continue
+                # return unless upstream task already has non-teardown 
downstream in group
+                elif not has_non_teardown_downstream(upstream_task, 
exclude=task.task_id):
+                    yield upstream_task
+
+        for task in tasks:
+            if task.downstream_task_ids.isdisjoint(ids):
+                if not task.is_teardown:
+                    yield task
+                else:
+                    yield from recurse_for_first_non_teardown(task)
+
+    def get_task_group_dict(self) -> dict[str | None, SerializedTaskGroup]:
+        """Create a flat dict of group_id: TaskGroup."""
+
+        def build_map(node: DAGNode) -> Generator[tuple[str | None, 
SerializedTaskGroup]]:
+            if not isinstance(node, SerializedTaskGroup):
+                return
+            yield node.group_id, node
+            for child in node.children.values():
+                yield from build_map(child)
+
+        return dict(build_map(self))
+
+    def iter_tasks(self) -> Iterator[SerializedOperator]:
+        """Return an iterator of the child tasks."""
+        from airflow.models.mappedoperator import MappedOperator
+        from airflow.serialization.serialized_objects import 
SerializedBaseOperator
+
+        groups_to_visit = [self]
+        while groups_to_visit:
+            for child in groups_to_visit.pop(0).children.values():
+                if isinstance(child, (MappedOperator, SerializedBaseOperator)):
+                    yield child
+                elif isinstance(child, SerializedTaskGroup):
+                    groups_to_visit.append(child)
+                else:
+                    raise ValueError(
+                        f"Encountered a DAGNode that is not a task or task "
+                        f"group: {type(child).__module__}.{type(child)}"
+                    )
+
+    def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
+        """
+        Find mapped task groups in the hierarchy.
+
+        Groups are returned from the closest to the outmost. If *self* is a
+        mapped task group, it is returned first.
+        """
+        group: SerializedTaskGroup | None = self
+        while group is not None:
+            if isinstance(group, SerializedMappedTaskGroup):
+                yield group
+            group = group.parent_group
+
+    def topological_sort(self) -> list[DAGNode]:
+        """
+        Sorts children in topographical order.
+
+        A task in the result would come after any of its upstream dependencies.
+        """
+        # This uses a modified version of Kahn's Topological Sort algorithm to
+        # not have to pre-compute the "in-degree" of the nodes.
+        graph_unsorted = copy.copy(self.children)
+        graph_sorted: list[DAGNode] = []
+        if not self.children:
+            return graph_sorted
+        while graph_unsorted:
+            for node in list(graph_unsorted.values()):
+                for edge in node.upstream_list:
+                    if edge.node_id in graph_unsorted:
+                        break
+                    # Check for task's group is a child (or grand child) of 
this TG,
+                    tg = edge.task_group
+                    while tg:
+                        if tg.node_id in graph_unsorted:
+                            break
+                        tg = tg.parent_group
+                else:
+                    del graph_unsorted[node.node_id]
+                    graph_sorted.append(node)
+        return graph_sorted
+
+    def add(self, node: DAGNode) -> DAGNode:
+        # Set the TG first, as setting it might change the return value of 
node_id!
+        node.task_group = weakref.proxy(self)
+        if isinstance(node, SerializedTaskGroup):
+            if self.dag:
+                node.dag = self.dag
+        self.children[node.node_id] = node
+        return node
+
+
[email protected](kw_only=True, repr=False)
+class SerializedMappedTaskGroup(SerializedTaskGroup):
+    """Serialized representation of a MappedTaskGroup used in protected 
processes."""
+
+    _expand_input: SchedulerExpandInput = attrs.field(alias="expand_input")
+
+    is_mapped: ClassVar[bool] = True
+
+    @methodtools.lru_cache(maxsize=None)
+    def get_parse_time_mapped_ti_count(self) -> int:
+        """
+        Return the number of instances a task in this group should be mapped 
to.
+
+        This only considers literal mapped arguments, and would return *None*
+        when any non-literal values are used for mapping.
+
+        If this group is inside mapped task groups, all the nested counts are
+        multiplied and accounted.
+
+        :raise NotFullyPopulated: If any non-literal mapped arguments are 
encountered.
+        :return: The total number of mapped instances each task should have.
+        """
+        return functools.reduce(
+            operator.mul,
+            (g._expand_input.get_parse_time_mapped_ti_count() for g in 
self.iter_mapped_task_groups()),
+        )
+
+    def iter_mapped_dependencies(self) -> Iterator[SerializedOperator]:
+        """Upstream dependencies that provide XComs used by this mapped task 
group."""
+        from airflow.models.xcom_arg import SchedulerXComArg
+
+        for op, _ in SchedulerXComArg.iter_xcom_references(self._expand_input):
+            yield op
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py 
b/airflow-core/src/airflow/serialization/serialized_objects.py
index 24ab2e106f1..090f5c21432 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -33,7 +33,18 @@ from collections.abc import Collection, Iterable, Iterator, 
Mapping, Sequence
 from functools import cached_property, lru_cache
 from inspect import signature
 from textwrap import dedent
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, 
TypeAlias, TypeVar, cast, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    ClassVar,
+    Literal,
+    NamedTuple,
+    TypeAlias,
+    TypeGuard,
+    TypeVar,
+    cast,
+    overload,
+)
 
 import attrs
 import lazy_object_proxy
@@ -76,6 +87,7 @@ from airflow.sdk.definitions.taskgroup import 
MappedTaskGroup, TaskGroup
 from airflow.sdk.definitions.xcom_arg import serialize_xcom_arg
 from airflow.sdk.execution_time.context import OutletEventAccessor, 
OutletEventAccessors
 from airflow.serialization.dag_dependency import DagDependency
+from airflow.serialization.definitions.taskgroup import 
SerializedMappedTaskGroup, SerializedTaskGroup
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.helpers import serialize_template_field
 from airflow.serialization.json_schema import load_dag_schema
@@ -1235,8 +1247,10 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
     _task_display_name: str | None
     _weight_rule: str | PriorityWeightStrategy = "downstream"
 
-    dag: SerializedDAG | None = None
-    task_group: TaskGroup | None = None
+    # TODO (GH-52141): These should contain serialized containers, but 
currently
+    # this class inherits from an SDK one.
+    dag: SerializedDAG | None = None  # type: ignore[assignment]
+    task_group: SerializedTaskGroup | None = None  # type: ignore[assignment]
 
     allow_nested_operators: bool = True
     depends_on_past: bool = False
@@ -1664,7 +1678,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
         setattr(op, "start_from_trigger", 
bool(encoded_op.get("start_from_trigger", False)))
 
     @staticmethod
-    def set_task_dag_references(task: SerializedOperator, dag: SerializedDAG) -> None:
+    def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: SerializedDAG) -> None:
         """
         Handle DAG references on an operator.
 
@@ -2147,7 +2161,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
 
         return result
 
-    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | 
MappedTaskGroup]:
+    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | 
SerializedMappedTaskGroup]:
         """
         Return mapped nodes that are direct dependencies of the current task.
 
@@ -2164,7 +2178,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
         :meth:`iter_mapped_dependants` instead.
         """
 
-        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
+        def _walk_group(group: SerializedTaskGroup) -> Iterable[tuple[str, 
DAGNode]]:
             """
             Recursively walk children in a task group.
 
@@ -2173,7 +2187,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
             """
             for key, child in group.children.items():
                 yield key, child
-                if isinstance(child, TaskGroup):
+                if isinstance(child, SerializedTaskGroup):
                     yield from _walk_group(child)
 
         if not (dag := self.dag):
@@ -2181,12 +2195,12 @@ class SerializedBaseOperator(DAGNode, 
BaseSerialization):
         for key, child in _walk_group(dag.task_group):
             if key == self.node_id:
                 continue
-            if not isinstance(child, MappedOperator | MappedTaskGroup):
+            if not isinstance(child, MappedOperator | 
SerializedMappedTaskGroup):
                 continue
             if self.node_id in child.upstream_task_ids:
                 yield child
 
-    def iter_mapped_dependants(self) -> Iterator[MappedOperator | 
MappedTaskGroup]:
+    def iter_mapped_dependants(self) -> Iterator[MappedOperator | 
SerializedMappedTaskGroup]:
         """
         Return mapped nodes that depend on the current task the expansion.
 
@@ -2202,7 +2216,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
         )
 
     # TODO (GH-52141): Copied from sdk. Find a better place for this to live 
in.
-    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
+    def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
         """
         Return mapped task groups this task belongs to.
 
@@ -2215,7 +2229,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
         yield from group.iter_mapped_task_groups()
 
     # TODO (GH-52141): Copied from sdk. Find a better place for this to live 
in.
-    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
+    def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup | 
None:
         """
         Get the mapped task group "closest" to this task in the DAG.
 
@@ -2310,7 +2324,6 @@ def _create_orm_dagrun(
     return run
 
 
[email protected](hash=False, repr=False, eq=False, slots=False)
 class SerializedDAG(DAG, BaseSerialization):
     """
     A JSON serializable representation of DAG.
@@ -2322,10 +2335,15 @@ class SerializedDAG(DAG, BaseSerialization):
 
     _decorated_fields: ClassVar[set[str]] = {"default_args", "access_control"}
 
-    last_loaded: datetime.datetime | None = attrs.field(init=False, 
factory=utcnow)
+    # TODO (GH-52141): These should contain serialized containers, but 
currently
+    # this class inherits from an SDK one.
+    task_group: SerializedTaskGroup  # type: ignore[assignment]
+    task_dict: dict[str, SerializedBaseOperator | SerializedMappedOperator]  # 
type: ignore[assignment]
+
+    last_loaded: datetime.datetime
     # this will only be set at serialization time
     # it's only use is for determining the relative fileloc based only on the 
serialize dag
-    _processor_dags_folder: str = attrs.field(init=False)
+    _processor_dags_folder: str
 
     @staticmethod
     def __get_constructor_defaults():
@@ -2404,6 +2422,7 @@ class SerializedDAG(DAG, BaseSerialization):
     ) -> SerializedDAG:
         """Handle the main Dag deserialization logic."""
         dag = SerializedDAG(dag_id=encoded_dag["dag_id"], schedule=None)
+        dag.last_loaded = utcnow()
 
         # Note: Context is passed explicitly through method parameters, no 
class attributes needed
 
@@ -2449,18 +2468,24 @@ class SerializedDAG(DAG, BaseSerialization):
             tg = TaskGroupSerialization.deserialize_task_group(
                 encoded_dag["task_group"],
                 None,
-                # TODO (GH-52141): SerializedDAG's task_dict should contain
-                # scheduler types instead, but currently it inherits SDK's DAG.
-                cast("dict[str, SerializedOperator]", dag.task_dict),
+                dag.task_dict,
                 dag,
             )
             object.__setattr__(dag, "task_group", tg)
         else:
-            # This must be old data that had no task_group. Create a root 
TaskGroup and add
-            # all tasks to it.
-            object.__setattr__(dag, "task_group", TaskGroup.create_root(dag))
+            # This must be old data that had no task_group. Create a root
+            # task group and add all tasks to it.
+            tg = SerializedTaskGroup(
+                group_id=None,
+                group_display_name=None,
+                prefix_group_id=True,
+                parent_group=None,
+                dag=dag,
+                tooltip="",
+            )
+            object.__setattr__(dag, "task_group", tg)
             for task in dag.tasks:
-                dag.task_group.add(task)
+                tg.add(task)
 
         # Set has_on_*_callbacks to True if they exist in Serialized blob as 
False is the default
         if "has_on_success_callback" in encoded_dag:
@@ -2475,10 +2500,8 @@ class SerializedDAG(DAG, BaseSerialization):
         for k in keys_to_set_none:
             setattr(dag, k, None)
 
-        # TODO (GH-52141): SerializedDAG's task_dict should contain scheduler
-        # types instead, but currently it inherits SDK's DAG.
-        for task in dag.task_dict.values():
-            
SerializedBaseOperator.set_task_dag_references(cast("SerializedOperator", 
task), dag)
+        for t in dag.task_dict.values():
+            SerializedBaseOperator.set_task_dag_references(t, dag)
 
         return dag
 
@@ -2705,6 +2728,125 @@ class SerializedDAG(DAG, BaseSerialization):
         dag_op.update_dag_asset_expression(orm_dags=orm_dags, 
orm_assets=orm_assets)
         session.flush()
 
+    # TODO (GH-52141): This needs to take scheduler types, but currently it 
inherits SDK's DAG.
+    # TODO (GH-52141): This shouldn't need to be writable, but SDK's DAG 
defines it as such.
+    @property  # type: ignore[misc]
+    def tasks(self) -> Sequence[SerializedOperator]:  # type: ignore[override]
+        return list(self.task_dict.values())
+
+    def partial_subset(
+        self,
+        task_ids: str | Iterable[str],
+        include_downstream: bool = False,
+        include_upstream: bool = True,
+        include_direct_upstream: bool = False,
+    ):
+        from airflow.models.mappedoperator import MappedOperator as 
SerializedMappedOperator
+
+        def is_task(obj) -> TypeGuard[SerializedOperator]:
+            return isinstance(obj, (SerializedMappedOperator, 
SerializedBaseOperator))
+
+        # deep-copying self.task_dict and self.task_group takes a long time, 
and we don't want all
+        # the tasks anyway, so we copy the tasks manually later
+        memo = {id(self.task_dict): None, id(self.task_group): None}
+        dag = copy.deepcopy(self, memo)
+
+        if isinstance(task_ids, str):
+            matched_tasks = [t for t in self.tasks if task_ids in t.task_id]
+        else:
+            matched_tasks = [t for t in self.tasks if t.task_id in task_ids]
+
+        also_include_ids: set[str] = set()
+        for t in matched_tasks:
+            if include_downstream:
+                for rel in t.get_flat_relatives(upstream=False):
+                    also_include_ids.add(rel.task_id)
+                    if rel not in matched_tasks:  # if it's in there, we're 
already processing it
+                        # need to include setups and teardowns for tasks that 
are in multiple
+                        # non-collinear setup/teardown paths
+                        if not rel.is_setup and not rel.is_teardown:
+                            also_include_ids.update(
+                                x.task_id for x in 
rel.get_upstreams_only_setups_and_teardowns()
+                            )
+            if include_upstream:
+                also_include_ids.update(x.task_id for x in 
t.get_upstreams_follow_setups())
+            else:
+                if not t.is_setup and not t.is_teardown:
+                    also_include_ids.update(x.task_id for x in 
t.get_upstreams_only_setups_and_teardowns())
+            if t.is_setup and not include_downstream:
+                also_include_ids.update(x.task_id for x in t.downstream_list 
if x.is_teardown)
+
+        also_include: list[SerializedOperator] = [self.task_dict[x] for x in 
also_include_ids]
+        direct_upstreams: list[SerializedOperator] = []
+        if include_direct_upstream:
+            for t in itertools.chain(matched_tasks, also_include):
+                upstream = (u for u in t.upstream_list if is_task(u))
+                direct_upstreams.extend(upstream)
+
+        # Make sure to not recursively deepcopy the dag or task_group while 
copying the task.
+        # task_group is reset later
+        def _deepcopy_task(t) -> SerializedOperator:
+            memo.setdefault(id(t.task_group), None)
+            return copy.deepcopy(t, memo)
+
+        # Compiling the unique list of tasks that made the cut
+        dag.task_dict = {
+            t.task_id: _deepcopy_task(t)
+            for t in itertools.chain(matched_tasks, also_include, 
direct_upstreams)
+        }
+
+        def filter_task_group(group, parent_group):
+            """Exclude tasks not included in the partial dag from the given 
TaskGroup."""
+            # We want to deepcopy _most but not all_ attributes of the task 
group, so we create a shallow copy
+            # and then manually deep copy the instances. (memo argument to 
deepcopy only works for instances
+            # of classes, not "native" properties of an instance)
+            copied = copy.copy(group)
+
+            memo[id(group.children)] = {}
+            if parent_group:
+                memo[id(group.parent_group)] = parent_group
+            for attr in type(group).__slots__:
+                value = getattr(group, attr)
+                value = copy.deepcopy(value, memo)
+                object.__setattr__(copied, attr, value)
+
+            proxy = weakref.proxy(copied)
+
+            for child in group.children.values():
+                if is_task(child):
+                    if child.task_id in dag.task_dict:
+                        task = copied.children[child.task_id] = 
dag.task_dict[child.task_id]
+                        task.task_group = proxy
+                else:
+                    filtered_child = filter_task_group(child, proxy)
+
+                    # Only include this child TaskGroup if it is non-empty.
+                    if filtered_child.children:
+                        copied.children[child.group_id] = filtered_child
+
+            return copied
+
+        object.__setattr__(dag, "task_group", 
filter_task_group(self.task_group, None))
+
+        # Removing upstream/downstream references to tasks and TaskGroups that 
did not make
+        # the cut.
+        groups = dag.task_group.get_task_group_dict()
+        for g in groups.values():
+            g.upstream_group_ids.intersection_update(groups)
+            g.downstream_group_ids.intersection_update(groups)
+            g.upstream_task_ids.intersection_update(dag.task_dict)
+            g.downstream_task_ids.intersection_update(dag.task_dict)
+
+        for t in dag.tasks:
+            # Removing upstream/downstream references to tasks that did not
+            # make the cut
+            t.upstream_task_ids.intersection_update(dag.task_dict)
+            t.downstream_task_ids.intersection_update(dag.task_dict)
+
+        dag.partial = len(dag.tasks) < len(self.tasks)
+
+        return dag
+
     @cached_property
     def _time_restriction(self) -> TimeRestriction:
         start_dates = [t.start_date for t in self.tasks if t.start_date]
@@ -3416,10 +3558,10 @@ class TaskGroupSerialization(BaseSerialization):
     def deserialize_task_group(
         cls,
         encoded_group: dict[str, Any],
-        parent_group: TaskGroup | None,
+        parent_group: SerializedTaskGroup | None,
         task_dict: dict[str, SerializedOperator],
         dag: SerializedDAG,
-    ) -> TaskGroup:
+    ) -> SerializedTaskGroup:
         """Deserializes a TaskGroup from a JSON object."""
         group_id = cls.deserialize(encoded_group["_group_id"])
         kwargs = {
@@ -3429,10 +3571,10 @@ class TaskGroupSerialization(BaseSerialization):
         kwargs["group_display_name"] = 
cls.deserialize(encoded_group.get("group_display_name", ""))
 
         if not encoded_group.get("is_mapped"):
-            group = TaskGroup(group_id=group_id, parent_group=parent_group, 
dag=dag, **kwargs)
+            group = SerializedTaskGroup(group_id=group_id, 
parent_group=parent_group, dag=dag, **kwargs)
         else:
             xi = encoded_group["expand_input"]
-            group = MappedTaskGroup(
+            group = SerializedMappedTaskGroup(
                 group_id=group_id,
                 parent_group=parent_group,
                 dag=dag,
diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
index d5922074ef5..5e00d4b7b1b 100644
--- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, TypeAlias, cast
+from typing import TYPE_CHECKING, TypeAlias
 
 from sqlalchemy import select
 
@@ -63,9 +63,7 @@ class MappedTaskUpstreamDep(BaseTIDep):
         elif is_mapped(ti.task):
             mapped_dependencies = ti.task.iter_mapped_dependencies()
         elif (task_group := ti.task.get_closest_mapped_task_group()) is not 
None:
-            # TODO (GH-52141): Task group in scheduler needs to return 
scheduler
-            # types instead, but currently the scheduler uses SDK's TaskGroup.
-            mapped_dependencies = cast("Iterator[Operator]", 
task_group.iter_mapped_dependencies())
+            mapped_dependencies = task_group.iter_mapped_dependencies()
         else:
             return
 
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 1298518b96e..971b156a067 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -21,12 +21,11 @@ import collections.abc
 import functools
 from collections import Counter
 from collections.abc import Iterator, KeysView
-from typing import TYPE_CHECKING, NamedTuple, cast
+from typing import TYPE_CHECKING, NamedTuple
 
 from sqlalchemy import and_, func, or_, select
 
 from airflow.models.taskinstance import PAST_DEPENDS_MET
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup
 from airflow.task.trigger_rule import TriggerRule as TR
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.state import TaskInstanceState
@@ -35,10 +34,8 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
     from sqlalchemy.sql.expression import ColumnOperators
 
-    from airflow import DAG
-    from airflow.models.mappedoperator import MappedOperator
     from airflow.models.taskinstance import TaskInstance
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
+    from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
     from airflow.ti_deps.dep_context import DepContext
     from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
 
@@ -131,6 +128,7 @@ class TriggerRuleDep(BaseTIDep):
         """
         from airflow.exceptions import NotMapped
         from airflow.models.expandinput import NotFullyPopulated
+        from airflow.models.mappedoperator import is_mapped
         from airflow.models.taskinstance import TaskInstance
 
         @functools.lru_cache
@@ -148,9 +146,7 @@ class TriggerRuleDep(BaseTIDep):
 
             return get_mapped_ti_count(ti.task, ti.run_id, session=session)
 
-        def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> 
Iterator[str]:
-            from airflow.models.mappedoperator import is_mapped
-
+        def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup) -> Iterator[str]:
             if (task := ti.task) is not None and is_mapped(task):
                 for op in task.iter_mapped_dependencies():
                     yield op.task_id
@@ -172,9 +168,10 @@ class TriggerRuleDep(BaseTIDep):
             """
             if TYPE_CHECKING:
                 assert ti.task
-                assert isinstance(ti.task.dag, DAG)
+                assert ti.task.dag
+                assert ti.task.task_group
 
-            if isinstance(ti.task.task_group, MappedTaskGroup):
+            if is_mapped(ti.task.task_group):
                 is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, 
TR.ONE_FAILED, TR.ONE_DONE)
                 if is_fast_triggered and upstream_id not in set(
                     _iter_expansion_dependencies(task_group=ti.task.task_group)
@@ -186,9 +183,7 @@ class TriggerRuleDep(BaseTIDep):
             except (NotFullyPopulated, NotMapped):
                 return None
             return ti.get_relevant_upstream_map_indexes(
-                # TODO (GH-52141): task_dict in scheduler should contain
-                # scheduler types instead, but currently it inherits SDK's DAG.
-                upstream=cast("MappedOperator | SerializedBaseOperator", 
ti.task.dag.task_dict[upstream_id]),
+                upstream=ti.task.dag.task_dict[upstream_id],
                 ti_count=expanded_ti_count,
                 session=session,
             )
diff --git a/airflow-core/src/airflow/utils/dot_renderer.py 
b/airflow-core/src/airflow/utils/dot_renderer.py
index 66b83492269..259b4ced252 100644
--- a/airflow-core/src/airflow/utils/dot_renderer.py
+++ b/airflow-core/src/airflow/utils/dot_renderer.py
@@ -24,9 +24,10 @@ import warnings
 from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException
-from airflow.sdk import DAG, BaseOperator
+from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator
+from airflow.sdk import DAG, BaseOperator, TaskGroup
 from airflow.sdk.definitions.mappedoperator import MappedOperator
-from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.dag_edges import dag_edges
 from airflow.utils.state import State
@@ -69,7 +70,7 @@ def _refine_color(color: str):
 
 
 def _draw_task(
-    task: BaseOperator | MappedOperator | SerializedBaseOperator,
+    task: BaseOperator | MappedOperator | SerializedBaseOperator | 
SerializedMappedOperator,
     parent_graph: graphviz.Digraph,
     states_by_task_id: dict[Any, Any] | None,
 ) -> None:
@@ -95,7 +96,9 @@ def _draw_task(
 
 
 def _draw_task_group(
-    task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id: 
dict[str, str] | None
+    task_group: TaskGroup | SerializedTaskGroup,
+    parent_graph: graphviz.Digraph,
+    states_by_task_id: dict[str, str] | None,
 ) -> None:
     """Draw the given task_group and its children on the given parent_graph."""
     # Draw joins
@@ -136,10 +139,10 @@ def _draw_nodes(
     node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: 
dict[str, str] | None
 ) -> None:
     """Draw the node and its children on the given parent_graph recursively."""
-    if isinstance(node, (BaseOperator, MappedOperator, 
SerializedBaseOperator)):
+    if isinstance(node, (BaseOperator, MappedOperator, SerializedBaseOperator, 
SerializedMappedOperator)):
         _draw_task(node, parent_graph, states_by_task_id)
     else:
-        if not isinstance(node, TaskGroup):
+        if not isinstance(node, (SerializedTaskGroup, TaskGroup)):
             raise AirflowException(f"The node {node} should be TaskGroup and 
is not")
         # Draw TaskGroup
         if node.is_root:
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index 8d92eab5aa0..369f530077b 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -240,14 +240,11 @@ class TestDagRun:
             schedule=datetime.timedelta(days=1),
             start_date=timezone.datetime(2017, 1, 1),
         ) as dag:
-            ...
-        dag_task1 = ShortCircuitOperator(
-            task_id="test_short_circuit_false", dag=dag, 
python_callable=lambda: False
-        )
-        dag_task2 = EmptyOperator(task_id="test_state_skipped1", dag=dag)
-        dag_task3 = EmptyOperator(task_id="test_state_skipped2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
-        dag_task2.set_downstream(dag_task3)
+            dag_task1 = 
ShortCircuitOperator(task_id="test_short_circuit_false", python_callable=bool)
+            dag_task2 = EmptyOperator(task_id="test_state_skipped1")
+            dag_task3 = EmptyOperator(task_id="test_state_skipped2")
+            dag_task1.set_downstream(dag_task2)
+            dag_task2.set_downstream(dag_task3)
 
         initial_task_states = {
             "test_short_circuit_false": TaskInstanceState.SUCCESS,
@@ -268,14 +265,11 @@ class TestDagRun:
             schedule=datetime.timedelta(days=1),
             start_date=timezone.datetime(2017, 1, 1),
         ) as dag:
-            ...
-        dag_task1 = ShortCircuitOperator(
-            task_id="test_short_circuit_false", dag=dag, 
python_callable=lambda: False
-        )
-        dag_task2 = EmptyOperator(task_id="test_state_skipped1", dag=dag)
-        dag_task3 = EmptyOperator(task_id="test_state_skipped2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
-        dag_task2.set_downstream(dag_task3)
+            dag_task1 = 
ShortCircuitOperator(task_id="test_short_circuit_false", python_callable=bool)
+            dag_task2 = EmptyOperator(task_id="test_state_skipped1")
+            dag_task3 = EmptyOperator(task_id="test_state_skipped2")
+            dag_task1.set_downstream(dag_task2)
+            dag_task2.set_downstream(dag_task3)
 
         initial_task_states = {
             "test_short_circuit_false": TaskInstanceState.REMOVED,
@@ -397,19 +391,15 @@ class TestDagRun:
             start_date=datetime.datetime(2017, 1, 1),
             on_success_callback=on_success_callable,
         ) as dag:
-            ...
-        dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
+            dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+            dag_task2 = EmptyOperator(task_id="test_state_succeeded2")
+            dag_task1.set_downstream(dag_task2)
 
         initial_task_states = {
             "test_state_succeeded1": TaskInstanceState.SUCCESS,
             "test_state_succeeded2": TaskInstanceState.SUCCESS,
         }
 
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
         _, callback = dag_run.update_state()
         assert dag_run.state == DagRunState.SUCCESS
@@ -426,9 +416,8 @@ class TestDagRun:
             start_date=datetime.datetime(2017, 1, 1),
             on_failure_callback=on_failure_callable,
         ) as dag:
-            ...
-        dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag)
+            dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+            dag_task2 = EmptyOperator(task_id="test_state_failed2")
 
         initial_task_states = {
             "test_state_succeeded1": TaskInstanceState.SUCCESS,
@@ -436,9 +425,6 @@ class TestDagRun:
         }
         dag_task1.set_downstream(dag_task2)
 
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
         _, callback = dag_run.update_state()
         assert dag_run.state == DagRunState.FAILED
@@ -481,27 +467,21 @@ class TestDagRun:
         assert dag_run.state == DagRunState.SUCCESS
         mock_on_success.assert_called_once()
 
-    def test_start_dr_spans_if_needed_new_span(self, testing_dag_bundle, 
dag_maker, session):
+    def test_start_dr_spans_if_needed_new_span(self, dag_maker, session):
         with dag_maker(
             dag_id="test_start_dr_spans_if_needed_new_span",
             schedule=datetime.timedelta(days=1),
             start_date=datetime.datetime(2017, 1, 1),
         ) as dag:
-            ...
-        SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], 
session=session)
-
-        dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
+            dag_task1 = EmptyOperator(task_id="test_task1")
+            dag_task2 = EmptyOperator(task_id="test_task2")
+            dag_task1.set_downstream(dag_task2)
 
         initial_task_states = {
             "test_task1": TaskInstanceState.QUEUED,
             "test_task2": TaskInstanceState.QUEUED,
         }
 
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
 
         active_spans = ThreadSafeDict()
@@ -518,27 +498,21 @@ class TestDagRun:
         assert dag_run.span_status == SpanStatus.ACTIVE
         assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is not None
 
-    def test_start_dr_spans_if_needed_span_with_continuance(self, 
testing_dag_bundle, dag_maker, session):
+    def test_start_dr_spans_if_needed_span_with_continuance(self, dag_maker, 
session):
         with dag_maker(
             dag_id="test_start_dr_spans_if_needed_span_with_continuance",
             schedule=datetime.timedelta(days=1),
             start_date=datetime.datetime(2017, 1, 1),
         ) as dag:
-            ...
-        SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], 
session=session)
-
-        dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
+            dag_task1 = EmptyOperator(task_id="test_task1")
+            dag_task2 = EmptyOperator(task_id="test_task2")
+            dag_task1.set_downstream(dag_task2)
 
         initial_task_states = {
             "test_task1": TaskInstanceState.RUNNING,
             "test_task2": TaskInstanceState.QUEUED,
         }
 
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
 
         active_spans = ThreadSafeDict()
@@ -570,21 +544,15 @@ class TestDagRun:
             schedule=datetime.timedelta(days=1),
             start_date=datetime.datetime(2017, 1, 1),
         ) as dag:
-            ...
-        SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], 
session=session)
-
-        dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
+            dag_task1 = EmptyOperator(task_id="test_task1")
+            dag_task2 = EmptyOperator(task_id="test_task2")
+            dag_task1.set_downstream(dag_task2)
 
         initial_task_states = {
             "test_task1": TaskInstanceState.SUCCESS,
             "test_task2": TaskInstanceState.SUCCESS,
         }
 
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
 
         active_spans = ThreadSafeDict()
@@ -612,21 +580,15 @@ class TestDagRun:
             schedule=datetime.timedelta(days=1),
             start_date=datetime.datetime(2017, 1, 1),
         ) as dag:
-            ...
-        SerializedDAG.bulk_write_to_db("testing", None, dags=[dag], 
session=session)
-
-        dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
+            dag_task1 = EmptyOperator(task_id="test_task1")
+            dag_task2 = EmptyOperator(task_id="test_task2")
+            dag_task1.set_downstream(dag_task2)
 
         initial_task_states = {
             "test_task1": TaskInstanceState.SUCCESS,
             "test_task2": TaskInstanceState.SUCCESS,
         }
 
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
 
         active_spans = ThreadSafeDict()
@@ -652,23 +614,18 @@ class TestDagRun:
             start_date=datetime.datetime(2017, 1, 1),
             on_success_callback=on_success_callable,
         ) as dag:
-            ...
+            dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+            dag_task2 = EmptyOperator(task_id="test_state_succeeded2")
+            dag_task1.set_downstream(dag_task2)
         dm = DagModel.get_dagmodel(dag.dag_id, session=session)
         dm.relative_fileloc = relative_fileloc
         session.merge(dm)
         session.commit()
 
-        dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
-
         initial_task_states = {
             "test_state_succeeded1": TaskInstanceState.SUCCESS,
             "test_state_succeeded2": TaskInstanceState.SUCCESS,
         }
-
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
         dag.relative_fileloc = relative_fileloc
         SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), 
bundle_name="dag_maker")
         session.commit()
@@ -704,23 +661,18 @@ class TestDagRun:
             start_date=datetime.datetime(2017, 1, 1),
             on_failure_callback=on_failure_callable,
         ) as dag:
-            ...
+            dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+            dag_task2 = EmptyOperator(task_id="test_state_failed2")
+            dag_task1.set_downstream(dag_task2)
         dm = DagModel.get_dagmodel(dag.dag_id, session=session)
         dm.relative_fileloc = relative_fileloc
         session.merge(dm)
         session.commit()
 
-        dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
-
         initial_task_states = {
             "test_state_succeeded1": TaskInstanceState.SUCCESS,
             "test_state_failed2": TaskInstanceState.FAILED,
         }
-
-        # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
         dag.relative_fileloc = relative_fileloc
         SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), 
bundle_name="dag_maker")
         session.commit()
@@ -873,34 +825,32 @@ class TestDagRun:
                 assert dagrun.logical_date == timezone.datetime(2015, 1, 2)
 
     def test_removed_task_instances_can_be_restored(self, dag_maker, session):
-        def with_all_tasks_removed(dag):
-            with dag_maker(
-                dag_id=dag.dag_id,
+        def create_dag():
+            return dag_maker(
+                dag_id="test_task_restoration",
                 schedule=datetime.timedelta(days=1),
-                start_date=dag.start_date,
-            ) as dag:
-                pass
-            return dag
+                start_date=DEFAULT_DATE,
+            )
 
-        with dag_maker(
-            "test_task_restoration",
-            schedule=datetime.timedelta(days=1),
-            start_date=DEFAULT_DATE,
-        ) as ori_dag:
+        with create_dag() as dag:
             EmptyOperator(task_id="flaky_task", owner="test")
 
-        dagrun = self.create_dag_run(ori_dag, session=session)
+        dagrun = self.create_dag_run(dag, session=session)
         flaky_ti = dagrun.get_task_instances()[0]
         assert flaky_ti.task_id == "flaky_task"
         assert flaky_ti.state is None
 
-        dagrun.dag = with_all_tasks_removed(ori_dag)
-        dag_version_id = DagVersion.get_latest_version(ori_dag.dag_id, 
session=session).id
+        with create_dag() as dag:
+            pass
+
+        dagrun.dag = dag
+        dag_version_id = DagVersion.get_latest_version(dag.dag_id, 
session=session).id
         dagrun.verify_integrity(dag_version_id=dag_version_id)
         flaky_ti.refresh_from_db()
         assert flaky_ti.state is None
 
-        dagrun.dag.add_task(ori_dag.task_dict["flaky_task"])
+        with create_dag() as dag:
+            EmptyOperator(task_id="flaky_task", owner="test")
 
         dagrun.verify_integrity(dag_version_id=dag_version_id)
         flaky_ti.refresh_from_db()
@@ -1211,9 +1161,8 @@ class TestDagRun:
         with dag_maker(
             dag_id="test_dagrun_states", schedule=datetime.timedelta(days=1), 
start_date=DEFAULT_DATE
         ) as dag:
-            ...
-        dag_task_success = EmptyOperator(task_id="dummy", dag=dag)
-        dag_task_failed = EmptyOperator(task_id="dummy2", dag=dag)
+            dag_task_success = EmptyOperator(task_id="dummy")
+            dag_task_failed = EmptyOperator(task_id="dummy2")
 
         initial_task_states = {
             dag_task_success.task_id: TaskInstanceState.SUCCESS,
@@ -1319,10 +1268,9 @@ class TestDagRun:
                 callback=AsyncCallback(empty_callback_for_deadline),
             ),
         ) as dag:
-            ...
-        dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
-        dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
-        dag_task1.set_downstream(dag_task2)
+            dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+            dag_task2 = EmptyOperator(task_id="test_state_succeeded2")
+            dag_task1.set_downstream(dag_task2)
 
         initial_task_states = {
             "test_state_succeeded1": TaskInstanceState.SUCCESS,
@@ -1330,7 +1278,6 @@ class TestDagRun:
         }
 
         # Scheduler uses Serialized DAG -- so use that instead of the Actual 
DAG.
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
         dag_run = session.merge(dag_run)
         dag_run.dag = dag
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py 
b/airflow-core/tests/unit/models/test_taskinstance.py
index 094b8f362e8..0df775ca41e 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -151,6 +151,7 @@ class TestTaskInstance:
     def teardown_method(self):
         self.clean_db()
 
+    @pytest.mark.need_serialized_dag(False)
     def test_set_task_dates(self, dag_maker):
         """
         Test that tasks properly take start/end dates from DAGs
@@ -159,7 +160,6 @@ class TestTaskInstance:
             pass
 
         op1 = EmptyOperator(task_id="op_1")
-
         assert op1.start_date is None
         assert op1.end_date is None
 
@@ -190,6 +190,7 @@ class TestTaskInstance:
         assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)
         assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)
 
+    @pytest.mark.need_serialized_dag(False)
     def test_set_dag(self, dag_maker):
         """
         Test assigning Operators to Dags, including deferred assignment
@@ -2417,23 +2418,25 @@ class TestTaskInstance:
 
     def test_handle_failure_fail_fast(self, dag_maker, session):
         start_date = timezone.datetime(2016, 6, 1)
-        clear_db_runs()
 
         class CustomOp(BaseOperator):
             def execute(self, context): ...
 
+        reg_states = [State.RUNNING, State.FAILED, State.QUEUED, 
State.SCHEDULED, State.DEFERRED]
+
         with dag_maker(
             dag_id="test_handle_failure_fail_fast",
             start_date=start_date,
             schedule=None,
             fail_fast=True,
-        ) as dag:
-            task1 = CustomOp(task_id="task1", trigger_rule="all_success")
-
-        dag_maker.create_dagrun(run_type=DagRunType.MANUAL, 
start_date=start_date)
+        ):
+            CustomOp(task_id="task1", trigger_rule="all_success")
+            for i, _ in enumerate(reg_states):
+                CustomOp(task_id=f"reg_Task{i}")
+            CustomOp(task_id="fail_Task")
 
         logical_date = timezone.utcnow()
-        dr = dag.create_dagrun(
+        dr = dag_maker.create_dagrun(
             run_id="test_ff",
             run_type=DagRunType.MANUAL,
             logical_date=logical_date,
@@ -2445,31 +2448,23 @@ class TestTaskInstance:
         )
         dr.set_state(DagRunState.SUCCESS)
 
-        ti1 = dr.get_task_instance(task1.task_id, session=session)
-        ti1.task = task1
-        ti1.state = State.SUCCESS
-
-        states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, 
State.DEFERRED]
-        tasks = []
-        for i, state in enumerate(states):
-            op = CustomOp(task_id=f"reg_Task{i}", dag=dag)
-            ti = TI(task=op, run_id=dr.run_id, 
dag_version_id=ti1.dag_version_id)
-            ti.state = state
-            session.add(ti)
-            tasks.append(ti)
-
-        fail_task = CustomOp(task_id="fail_Task", dag=dag)
-        ti_ff = TI(task=fail_task, run_id=dr.run_id, 
dag_version_id=ti1.dag_version_id)
-        ti_ff.state = State.FAILED
-        session.add(ti_ff)
-        session.commit()
-        ti_ff.handle_failure("test retry handling")
+        tis = {ti.task_id: ti for ti in dr.task_instances}
+        tis["task1"].state = State.SUCCESS
+        for i, state in enumerate(reg_states):
+            tis[f"reg_Task{i}"].state = state
+        tis["fail_Task"].state = State.FAILED
+        session.flush()
 
-        assert ti1.state == State.SUCCESS
-        assert ti_ff.state == State.FAILED
-        exp_states = [State.FAILED, State.FAILED, State.SKIPPED, 
State.SKIPPED, State.SKIPPED]
-        for i in range(len(states)):
-            assert tasks[i].state == exp_states[i]
+        tis["fail_Task"].handle_failure("test retry handling")
+        assert {task_id: ti.state for task_id, ti in tis.items()} == {
+            "task1": State.SUCCESS,
+            "fail_Task": State.FAILED,
+            "reg_Task0": State.FAILED,
+            "reg_Task1": State.FAILED,
+            "reg_Task2": State.SKIPPED,
+            "reg_Task3": State.SKIPPED,
+            "reg_Task4": State.SKIPPED,
+        }
 
     def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
         def fail():
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py 
b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 6d71851dc87..96b26f69335 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -2914,7 +2914,7 @@ def test_taskflow_expand_kwargs_serde(strict):
 def test_mapped_task_group_serde():
     from airflow.models.expandinput import SchedulerDictOfListsExpandInput
     from airflow.sdk.definitions.decorators.task_group import task_group
-    from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+    from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
 
     with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as 
dag:
 
@@ -2955,7 +2955,7 @@ def test_mapped_task_group_serde():
 
     serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
     serde_tg = serde_dag.task_group.children["tg"]
-    assert isinstance(serde_tg, MappedTaskGroup)
+    assert isinstance(serde_tg, SerializedTaskGroup)
     assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a": 
[".", ".."]})
 
 
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py 
b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
index 75c6888509a..5d2eda6dfac 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
@@ -1419,12 +1419,13 @@ def 
test_upstream_in_mapped_group_when_mapped_tasks_list_is_empty(dag_maker, ses
 
 
 @pytest.mark.parametrize("flag_upstream_failed", [True, False])
[email protected]_serialized_dag
 def test_mapped_task_check_before_expand(dag_maker, session, 
flag_upstream_failed):
     """
     t3 depends on t2, which depends on t1 for expansion. Since t1 has not yet 
run, t2 has not expanded yet,
     and we need to guarantee this lack of expansion does not fail the 
dependency-checking logic.
     """
-    with dag_maker(session=session):
+    with dag_maker(session=session) as dag:
 
         @task
         def t(x):
@@ -1439,9 +1440,11 @@ def test_mapped_task_check_before_expand(dag_maker, 
session, flag_upstream_faile
         tg.expand(a=t([1, 2, 3]))
 
     dr: DagRun = dag_maker.create_dagrun()
+    ti = next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and 
ti.map_index == -1)
+    ti.refresh_from_task(dag.get_task(ti.task_id))
 
     _test_trigger_rule(
-        ti=next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and 
ti.map_index == -1),
+        ti=ti,
         session=session,
         flag_upstream_failed=flag_upstream_failed,
         expected_reason="requires all upstream tasks to have succeeded, but 
found 1",
@@ -1449,6 +1452,7 @@ def test_mapped_task_check_before_expand(dag_maker, 
session, flag_upstream_faile
 
 
 @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, 
SKIPPED), (False, None)])
[email protected]_serialized_dag
 def test_mapped_task_group_finished_upstream_before_expand(
     dag_maker, session, flag_upstream_failed, expected_ti_state
 ):
@@ -1456,7 +1460,7 @@ def 
test_mapped_task_group_finished_upstream_before_expand(
     t3 depends on t2, which was skipped before it was expanded. We need to 
guarantee this lack of expansion
     does not fail the dependency-checking logic.
     """
-    with dag_maker(session=session):
+    with dag_maker(session=session) as dag:
 
         @task
         def t(x):
@@ -1472,6 +1476,8 @@ def 
test_mapped_task_group_finished_upstream_before_expand(
     tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
     tis["t2"].set_state(SKIPPED, session=session)
     session.flush()
+
+    tis["tg.t3"].refresh_from_task(dag.get_task("tg.t3"))
     _test_trigger_rule(
         ti=tis["tg.t3"],
         session=session,
@@ -1734,6 +1740,7 @@ def 
test_setup_constraint_wait_for_past_depends_before_skipping(
 
 
 @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, 
SKIPPED), (False, None)])
[email protected]_serialized_dag
 def test_setup_mapped_task_group_finished_upstream_before_expand(
     dag_maker, session, flag_upstream_failed, expected_ti_state
 ):
@@ -1741,7 +1748,7 @@ def 
test_setup_mapped_task_group_finished_upstream_before_expand(
     t3 indirectly depends on t1, which was skipped before it was expanded. We 
need to guarantee this lack of
     expansion does not fail the dependency-checking logic.
     """
-    with dag_maker(session=session):
+    with dag_maker(session=session) as dag:
 
         @task(trigger_rule=TriggerRule.ALL_DONE)
         def t(x):
@@ -1760,6 +1767,8 @@ def 
test_setup_mapped_task_group_finished_upstream_before_expand(
     tis["t1"].set_state(SKIPPED, session=session)
     tis["t2"].set_state(SUCCESS, session=session)
     session.flush()
+
+    tis["tg.t3"].refresh_from_task(dag.get_task("tg.t3"))
     _test_trigger_rule(
         ti=tis["tg.t3"],
         session=session,
diff --git a/airflow-core/tests/unit/utils/test_task_group.py 
b/airflow-core/tests/unit/utils/test_task_group.py
index 447578329aa..52524ecda2f 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -21,20 +21,21 @@ import pendulum
 import pytest
 
 from airflow.api_fastapi.core_api.services.ui.task_group import 
task_group_to_dict
-from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import DAG
+from airflow.providers.standard.operators.bash import BashOperator
 from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import (
+    DAG,
+    BaseOperator,
+    TaskGroup,
     setup,
     task as task_decorator,
     task_group as task_group_decorator,
     teardown,
 )
-from airflow.sdk.definitions.taskgroup import TaskGroup
 from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils.dag_edges import dag_edges
 
-from tests_common.test_utils.compat import BashOperator, PythonOperator
 from unit.models import DEFAULT_DATE
 
 pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]
@@ -157,6 +158,7 @@ EXPECTED_JSON_LEGACY = {
 EXPECTED_JSON = {
     "children": [
         {"id": "task1", "label": "task1", "operator": "EmptyOperator", "type": 
"task"},
+        {"id": "task5", "label": "task5", "operator": "EmptyOperator", "type": 
"task"},
         {
             "children": [
                 {
@@ -195,11 +197,10 @@ EXPECTED_JSON = {
             "tooltip": "",
             "type": "task",
         },
-        {"id": "task5", "label": "task5", "operator": "EmptyOperator", "type": 
"task"},
     ],
     "id": None,
     "is_mapped": False,
-    "label": None,
+    "label": "",
     "tooltip": "",
     "type": "task",
 }
@@ -276,7 +277,10 @@ def test_task_group_to_dict_with_prefix(dag_maker):
     expected_node_id = {
         "children": [
             {"id": "task1", "label": "task1"},
+            {"id": "task5", "label": "task5"},
             {
+                "id": "group234",
+                "label": "group234",
                 "children": [
                     {
                         "children": [
@@ -294,13 +298,10 @@ def test_task_group_to_dict_with_prefix(dag_maker):
                     {"id": "task2", "label": "task2"},
                     {"id": "group234.upstream_join_id", "label": ""},
                 ],
-                "id": "group234",
-                "label": "group234",
             },
-            {"id": "task5", "label": "task5"},
         ],
         "id": None,
-        "label": None,
+        "label": "",
     }
 
     assert extract_node_id(task_group_to_dict(dag.task_group), 
include_label=True) == expected_node_id
@@ -346,6 +347,7 @@ def test_task_group_to_dict_with_task_decorator(dag_maker):
         "id": None,
         "children": [
             {"id": "task_1"},
+            {"id": "task_5"},
             {
                 "id": "group234",
                 "children": [
@@ -356,7 +358,6 @@ def test_task_group_to_dict_with_task_decorator(dag_maker):
                     {"id": "group234.downstream_join_id"},
                 ],
             },
-            {"id": "task_5"},
         ],
     }
 
@@ -402,6 +403,7 @@ def test_task_group_to_dict_sub_dag(dag_maker):
         "id": None,
         "children": [
             {"id": "task1"},
+            {"id": "task5"},
             {
                 "id": "group234",
                 "children": [
@@ -416,7 +418,6 @@ def test_task_group_to_dict_sub_dag(dag_maker):
                     {"id": "group234.upstream_join_id"},
                 ],
             },
-            {"id": "task5"},
         ],
     }
 
@@ -477,6 +478,16 @@ def test_task_group_to_dict_and_dag_edges(dag_maker):
     expected_node_id = {
         "id": None,
         "children": [
+            {
+                "id": "group_c",
+                "children": [
+                    {"id": "group_c.task6"},
+                    {"id": "group_c.task7"},
+                    {"id": "group_c.task8"},
+                    {"id": "group_c.upstream_join_id"},
+                    {"id": "group_c.downstream_join_id"},
+                ],
+            },
             {
                 "id": "group_d",
                 "children": [
@@ -486,6 +497,8 @@ def test_task_group_to_dict_and_dag_edges(dag_maker):
                 ],
             },
             {"id": "task1"},
+            {"id": "task10"},
+            {"id": "task9"},
             {
                 "id": "group_a",
                 "children": [
@@ -503,18 +516,6 @@ def test_task_group_to_dict_and_dag_edges(dag_maker):
                     {"id": "group_a.downstream_join_id"},
                 ],
             },
-            {
-                "id": "group_c",
-                "children": [
-                    {"id": "group_c.task6"},
-                    {"id": "group_c.task7"},
-                    {"id": "group_c.task8"},
-                    {"id": "group_c.upstream_join_id"},
-                    {"id": "group_c.downstream_join_id"},
-                ],
-            },
-            {"id": "task10"},
-            {"id": "task9"},
         ],
     }
 
@@ -783,6 +784,7 @@ def test_task_group_context_mix(dag_maker):
     node_ids = {
         "id": None,
         "children": [
+            {"id": "task_end"},
             {"id": "task_start"},
             {
                 "id": "section_1",
@@ -802,7 +804,6 @@ def test_task_group_context_mix(dag_maker):
                     {"id": "section_1.downstream_join_id"},
                 ],
             },
-            {"id": "task_end"},
         ],
     }
 
@@ -1184,7 +1185,7 @@ def test_task_group_display_name_used_as_label(dag_maker):
     assert tg.label == "my_custom_name"
     expected_node_id = {
         "id": None,
-        "label": None,
+        "label": "",
         "children": [
             {
                 "id": "tg",
diff --git a/devel-common/src/tests_common/pytest_plugin.py 
b/devel-common/src/tests_common/pytest_plugin.py
index af87248da46..ce588352d4e 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -43,7 +43,6 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.dagrun import DagRun, DagRunType
-    from airflow.models.mappedoperator import MappedOperator
     from airflow.models.taskinstance import TaskInstance
     from airflow.providers.standard.operators.empty import EmptyOperator
     from airflow.sdk import DAG, BaseOperator, Context, TriggerRule
@@ -51,8 +50,8 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.dag import ScheduleArg
     from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
     from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
-    from airflow.sdk.types import DagRunProtocol
-    from airflow.serialization.serialized_objects import 
SerializedBaseOperator, SerializedDAG
+    from airflow.sdk.types import DagRunProtocol, Operator
+    from airflow.serialization.serialized_objects import SerializedDAG
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Self
     from airflow.utils.state import DagRunState, TaskInstanceState
@@ -892,10 +891,17 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
 
             self.dag.__enter__()
             if self.want_serialized:
+                factory = self
 
                 class DAGProxy(lazy_object_proxy.Proxy):
-                    # Make `@dag.task` decorator work when need_serialized_dag 
marker is set
-                    task = self.dag.task
+                    """Wrapper to make test patterns work with serialized 
dag."""
+
+                    task = factory.dag.task  # Expose the @dag.task decorator.
+
+                    # When adding a task to the dag, automatically 
re-serialize.
+                    def add_task(self, task):
+                        factory.dag.add_task(task)
+                        factory._make_serdag(factory.dag)
 
                 return DAGProxy(self._serialized_dag)
             return self.dag
@@ -2310,13 +2316,12 @@ def create_runtime_ti(mocked_parse):
     from airflow.sdk import DAG
     from airflow.sdk.api.datamodels._generated import TaskInstance
     from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
-    from airflow.serialization.serialized_objects import SerializedDAG
     from airflow.timetables.base import TimeRestriction
 
     timezone = _import_timezone()
 
     def _create_task_instance(
-        task: MappedOperator | SerializedBaseOperator,
+        task: Operator,
         dag_id: str = "test_dag",
         run_id: str = "test_run",
         logical_date: str | datetime | None = "2024-12-01T01:00:00Z",
@@ -2353,14 +2358,7 @@ def create_runtime_ti(mocked_parse):
             ti_id = uuid7()
 
         if not task.has_dag():
-            dag = SerializedDAG.deserialize_dag(
-                SerializedDAG.serialize_dag(DAG(dag_id=dag_id, 
start_date=timezone.datetime(2024, 12, 3)))
-            )
-            # Fixture only helps in regular base operator tasks, so mypy is 
wrong here
-            task.dag = dag
-            # TODO (GH-52141): Scheduler DAG should contain scheduler tasks, 
but
-            # currently this inherits from SDK DAG.
-            task = dag.task_dict[task.task_id]  # type: ignore[assignment]
+            task.dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 
12, 3))
 
         if TYPE_CHECKING:
             assert task.dag is not None
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py 
b/task-sdk/src/airflow/sdk/definitions/dag.py
index 03b694f57d1..849de411cdf 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -29,13 +29,7 @@ from collections import abc, defaultdict, deque
 from collections.abc import Callable, Collection, Iterable, MutableSet
 from datetime import datetime, timedelta
 from inspect import signature
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    ClassVar,
-    cast,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast, overload
 from urllib.parse import urlsplit
 
 import attrs
@@ -829,15 +823,9 @@ class DAG:
         :param include_direct_upstream: Include all tasks directly upstream of 
matched
             and downstream (if include_downstream = True) tasks
         """
-        from typing import TypeGuard
-
-        from airflow.models.mappedoperator import MappedOperator as 
DbMappedOperator
         from airflow.sdk.definitions.mappedoperator import MappedOperator
-        from airflow.serialization.serialized_objects import 
SerializedBaseOperator
 
         def is_task(obj) -> TypeGuard[Operator]:
-            if isinstance(obj, (DbMappedOperator, SerializedBaseOperator)):
-                return True  # TODO (GH-52141): Split DAG implementation to 
straight this up.
             return isinstance(obj, (BaseOperator, MappedOperator))
 
         # deep-copying self.task_dict and self.task_group takes a long time, 
and we don't want all
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py 
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 45de93b7da8..b2f9aa6b909 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -40,8 +40,6 @@ from airflow.sdk.definitions._internal.node import DAGNode, 
validate_group_key
 from airflow.sdk.exceptions import AirflowDagCycleException
 
 if TYPE_CHECKING:
-    from airflow.models.expandinput import SchedulerExpandInput
-    from airflow.models.mappedoperator import MappedOperator
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.definitions._internal.abstractoperator import 
AbstractOperator
     from airflow.sdk.definitions._internal.expandinput import 
DictOfListsExpandInput, ListOfDictsExpandInput
@@ -50,7 +48,6 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.edges import EdgeModifier
     from airflow.sdk.types import Operator
     from airflow.serialization.enums import DagAttributeTypes
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
 
 
 def _default_parent_group() -> TaskGroup | None:
@@ -274,10 +271,14 @@ class TaskGroup(DAGNode):
     @property
     def group_id(self) -> str | None:
         """group_id of this TaskGroup."""
-        if self.parent_group and self.parent_group.prefix_group_id and 
self.parent_group._group_id:
+        if (
+            self._group_id
+            and self.parent_group
+            and self.parent_group.prefix_group_id
+            and self.parent_group._group_id
+        ):
             # defer to parent whether it adds a prefix
             return self.parent_group.child_id(self._group_id)
-
         return self._group_id
 
     @property
@@ -585,12 +586,9 @@ class TaskGroup(DAGNode):
                 yield group
             group = group.parent_group
 
-    # TODO (GH-52141): This should only return SDK operators. Have a db 
representation for db operators.
-    def iter_tasks(self) -> Iterator[AbstractOperator | MappedOperator | 
SerializedBaseOperator]:
+    def iter_tasks(self) -> Iterator[AbstractOperator]:
         """Return an iterator of the child tasks."""
-        from airflow.models.mappedoperator import MappedOperator
         from airflow.sdk.definitions._internal.abstractoperator import 
AbstractOperator
-        from airflow.serialization.serialized_objects import 
SerializedBaseOperator
 
         groups_to_visit = [self]
 
@@ -598,16 +596,18 @@ class TaskGroup(DAGNode):
             visiting = groups_to_visit.pop(0)
 
             for child in visiting.children.values():
-                if isinstance(child, (AbstractOperator, MappedOperator, 
SerializedBaseOperator)):
+                if isinstance(child, AbstractOperator):
                     yield child
                 elif isinstance(child, TaskGroup):
                     groups_to_visit.append(child)
                 else:
                     raise ValueError(
-                        f"Encountered a DAGNode that is not a TaskGroup or an 
AbstractOperator: {type(child).__module__}.{type(child)}"
+                        f"Encountered a DAGNode that is not a TaskGroup or an "
+                        f"AbstractOperator: 
{type(child).__module__}.{type(child)}"
                     )
 
 
[email protected](kw_only=True, repr=False)
 class MappedTaskGroup(TaskGroup):
     """
     A mapped task group.
@@ -619,22 +619,14 @@ class MappedTaskGroup(TaskGroup):
     a ``@task_group`` function instead.
     """
 
-    def __init__(
-        self,
-        *,
-        expand_input: SchedulerExpandInput | DictOfListsExpandInput | 
ListOfDictsExpandInput,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(**kwargs)
-        self._expand_input = expand_input
+    _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput = 
attrs.field(alias="expand_input")
 
     def __iter__(self):
-        from airflow.sdk.definitions._internal.abstractoperator import 
AbstractOperator
-
         for child in self.children.values():
-            if isinstance(child, AbstractOperator) and child.trigger_rule == 
TriggerRule.ALWAYS:
+            if getattr(child, "trigger_rule", None) == TriggerRule.ALWAYS:
                 raise ValueError(
-                    "Task-generated mapping within a mapped task group is not 
allowed with trigger rule 'always'"
+                    "Task-generated mapping within a mapped task group is not "
+                    "allowed with trigger rule 'always'"
                 )
             yield from self._iter_child(child)
 


Reply via email to