This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d374f71bc8 Remove TaskInstance and TaskLogReader unused methods (#59922)
2d374f71bc8 is described below

commit 2d374f71bc81202204ac0208df07b07c280668fa
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jan 6 01:18:57 2026 +0800

    Remove TaskInstance and TaskLogReader unused methods (#59922)
---
 airflow-core/newsfragments/59835.significant.rst   |   7 +-
 .../src/airflow/cli/commands/task_command.py       |  56 +++++---
 airflow-core/src/airflow/models/taskinstance.py    | 160 +--------------------
 airflow-core/src/airflow/serialization/enums.py    |   1 -
 .../airflow/serialization/serialized_objects.py    |  24 +---
 airflow-core/src/airflow/utils/context.py          |  76 +++-------
 airflow-core/src/airflow/utils/helpers.py          |  36 +----
 airflow-core/src/airflow/utils/log/log_reader.py   |  26 ----
 airflow-core/tests/unit/core/test_core.py          |  10 +-
 .../tests/unit/models/test_taskinstance.py         |  37 +++--
 .../tests/unit/utils/log/test_log_reader.py        |  57 --------
 airflow-core/tests/unit/utils/test_helpers.py      |  19 ---
 devel-common/src/tests_common/pytest_plugin.py     |  42 +++++-
 .../src/tests_common/test_utils/taskinstance.py    |  64 +++++++--
 .../aws/executors/aws_lambda/lambda_executor.py    |   8 +-
 .../amazon/aws/executors/batch/batch_executor.py   |   7 +-
 .../amazon/aws/executors/ecs/ecs_executor.py       |   7 +-
 .../providers/amazon/aws/hooks/sagemaker.py        |   7 +-
 .../executors/aws_lambda/test_lambda_executor.py   |   6 +-
 .../tests/unit/amazon/aws/hooks/test_appflow.py    |   5 +-
 .../unit/amazon/aws/operators/test_appflow.py      |   5 +-
 .../tests/unit/amazon/aws/operators/test_athena.py |  10 +-
 .../amazon/aws/operators/test_cloud_formation.py   |   6 +-
 .../unit/amazon/aws/operators/test_datasync.py     |  28 ++--
 .../tests/unit/amazon/aws/operators/test_dms.py    |  18 +--
 .../amazon/aws/operators/test_emr_add_steps.py     |   8 +-
 .../aws/operators/test_emr_create_job_flow.py      |  10 +-
 .../aws/operators/test_emr_modify_cluster.py       |   6 +-
 .../tests/unit/amazon/aws/operators/test_rds.py    |   6 +-
 .../tests/unit/amazon/aws/operators/test_s3.py     |   2 +-
 .../amazon/aws/operators/test_sagemaker_base.py    |   8 +-
 .../tests/unit/amazon/aws/sensors/test_ecs.py      |   5 +-
 .../tests/unit/amazon/aws/sensors/test_rds.py      |   6 +-
 .../tests/unit/amazon/aws/sensors/test_s3.py       |  10 +-
 .../unit/amazon/aws/transfers/test_mongo_to_s3.py  |   2 +-
 .../unit/apache/druid/operators/test_druid.py      |  13 +-
 .../decorators/test_kubernetes_commons.py          |  10 +-
 .../unit/cncf/kubernetes/operators/test_pod.py     |   6 +-
 .../io/tests/unit/common/io/xcom/test_backend.py   |   6 +-
 .../cloud/tests/unit/dbt/cloud/hooks/test_dbt.py   |   5 +-
 .../tests/unit/docker/decorators/test_docker.py    |  30 +---
 .../tests/unit/github/operators/test_github.py     |   5 +-
 .../tests/unit/github/sensors/test_github.py       |   5 +-
 .../tests/unit/google/cloud/hooks/test_gcs.py      |   6 +-
 .../mysql/tests/unit/mysql/hooks/test_mysql.py     |  13 +-
 .../unit/slack/transfers/test_sql_to_slack.py      |   5 +-
 .../slack/transfers/test_sql_to_slack_webhook.py   |   5 +-
 .../tests/unit/standard/decorators/test_bash.py    |  30 +++-
 .../tests/unit/standard/decorators/test_python.py  |  74 +++++-----
 .../tests/unit/standard/operators/test_python.py   |  43 ++----
 .../tests/unit/standard/sensors/test_time.py       |   5 +-
 .../check_template_context_variable_in_sync.py     |  34 +----
 task-sdk/src/airflow/sdk/definitions/dag.py        |  12 +-
 .../src/airflow/sdk/definitions/mappedoperator.py  |  10 +-
 54 files changed, 335 insertions(+), 767 deletions(-)

diff --git a/airflow-core/newsfragments/59835.significant.rst b/airflow-core/newsfragments/59835.significant.rst
index c55ff910611..02b9d709a49 100644
--- a/airflow-core/newsfragments/59835.significant.rst
+++ b/airflow-core/newsfragments/59835.significant.rst
@@ -1,5 +1,6 @@
 Methods removed from TaskInstance
 
-On class ``TaskInstance``, functions ``run()``, ``render_templates()``, and
-private members related to them have been removed. The class has been
-considered internal since 3.0, and should not be relied on in user code.
+On class ``TaskInstance``, functions ``run()``, ``render_templates()``,
+``get_template_context()``, and private members related to them have been
+removed. The class has been considered internal since 3.0, and should not be
+relied on in user code.
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py
index 5e87912b5bc..3cdd5c9f619 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -62,6 +62,8 @@ if TYPE_CHECKING:
 
     from sqlalchemy.orm.session import Session
 
+    from airflow.sdk import Context
+    from airflow.sdk.types import Operator as SdkOperator
     from airflow.serialization.definitions.mappedoperator import Operator
 
     CreateIfNecessary = Literal[False, "db", "memory"]
@@ -224,6 +226,24 @@ def _get_ti(
     return ti, dr_created
 
 
+def _get_template_context(ti: TaskInstance, task: SdkOperator) -> Context:
+    from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TaskInstance, TIRunContext
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+    runtime_ti = RuntimeTaskInstance.model_construct(
+        **TaskInstance.model_validate(ti, from_attributes=True).model_dump(exclude_unset=True),
+        task=task,
+        _ti_context_from_server=TIRunContext(
+            dag_run=DagRun.model_validate(ti.dag_run, from_attributes=True),
+            max_tries=ti.max_tries,
+            variables=[],
+            connections=[],
+            xcom_keys_to_clear=[],
+        ),
+    )
+    return runtime_ti.get_template_context()
+
+
 class TaskCommandMarker:
     """Marker for listener hooks, to properly detect from which component they 
are called."""
 
@@ -441,27 +461,21 @@ def task_render(args, dag: DAG | None = None) -> None:
         create_if_necessary="memory",
     )
 
-    with create_session() as session:
-        context = ti.get_template_context(session=session)
-        task = sdk_dag.get_task(args.task_id)
-        # TODO (GH-52141): After sdk separation, ti.get_template_context() would
-        # contain serialized operators, but we need the real operators for
-        # rendering. This does not make sense and eventually we should rewrite
-        # this entire function so "ti" is a RuntimeTaskInstance instead, but for
-        # now we'll just manually fix it to contain the right objects.
-        context["task"] = context["ti"].task = task
-        task.render_template_fields(context)
-        for attr in context["task"].template_fields:
-            print(
-                textwrap.dedent(
-                    f"""\
-                    # ----------------------------------------------------------
-                    # property: {attr}
-                    # ----------------------------------------------------------
-                    """
-                )
-                + str(getattr(context["task"], attr))  # This shouldn't be dedented.
-            )
+    task = sdk_dag.get_task(args.task_id)
+    context = _get_template_context(ti, task)
+    task.render_template_fields(context)
+    for attr in task.template_fields:
+        print(
+            textwrap.dedent(
+                f"""\
+                # ----------------------------------------------------------
+                # property: {attr}
+                # ----------------------------------------------------------
+                """
+            ),
+            getattr(context["task"], attr),  # This shouldn't be dedented.
+            sep="",
+        )
 
 
 @cli_utils.action_cli(check_db=False)
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 61c6d6f3691..d41849f3e20 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -27,13 +27,11 @@ import uuid
 from collections import defaultdict
 from collections.abc import Collection, Iterable
 from datetime import datetime, timedelta
-from functools import cache
 from typing import TYPE_CHECKING, Any
 from urllib.parse import quote
 
 import attrs
 import dill
-import lazy_object_proxy
 import uuid6
 from sqlalchemy import (
     JSON,
@@ -72,7 +70,7 @@ from airflow._shared.timezones import timezone
 from airflow.assets.manager import asset_manager
 from airflow.configuration import conf
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.asset import AssetEvent, AssetModel
+from airflow.models.asset import AssetModel
 from airflow.models.base import Base, StringID, TaskInstanceDependencies
 from airflow.models.dag_version import DagVersion
 
@@ -106,7 +104,6 @@ if TYPE_CHECKING:
     from datetime import datetime
     from typing import Literal
 
-    import pendulum
     from sqlalchemy.engine import Connection as SAConnection, Engine
     from sqlalchemy.orm.session import Session
     from sqlalchemy.sql import Update
@@ -115,7 +112,6 @@ if TYPE_CHECKING:
     from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
     from airflow.models.dag import DagModel
     from airflow.models.dagrun import DagRun
-    from airflow.sdk import Context
     from airflow.serialization.definitions.dag import SerializedDAG
     from airflow.serialization.definitions.mappedoperator import Operator
     from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
@@ -1567,160 +1563,6 @@ class TaskInstance(Base, LoggingMixin):
 
         return bool(self.task.retries and self.try_number <= self.max_tries)
 
-    # TODO (GH-52141): We should remove this entire function (only makes sense at runtime).
-    def get_template_context(
-        self,
-        session: Session | None = None,
-        ignore_param_exceptions: bool = True,
-    ) -> Context:
-        """
-        Return TI Context.
-
-        :param session: SQLAlchemy ORM Session
-        :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
-        """
-        # Do not use provide_session here -- it expunges everything on exit!
-        if not session:
-            session = settings.get_session()()
-
-        from airflow.exceptions import NotMapped
-        from airflow.sdk.api.datamodels._generated import (
-            DagRun as DagRunSDK,
-            PrevSuccessfulDagRunResponse,
-            TIRunContext,
-        )
-        from airflow.sdk.definitions.param import process_params
-        from airflow.sdk.execution_time.context import InletEventsAccessors
-        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
-        from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
-        from airflow.utils.context import (
-            ConnectionAccessor,
-            OutletEventAccessors,
-            VariableAccessor,
-        )
-
-        if TYPE_CHECKING:
-            assert session
-
-        def _get_dagrun(session: Session) -> DagRun:
-            dag_run = self.get_dagrun(session)
-            if dag_run in session:
-                return dag_run
-            # The dag_run may not be attached to the session anymore since the
-            # code base is over-zealous with use of session.expunge_all().
-            # Re-attach it if the relation is not loaded so we can load it when needed.
-            info: Any = inspect(dag_run)
-            if info.attrs.consumed_asset_events.loaded_value is not NO_VALUE:
-                return dag_run
-            # If dag_run is not flushed to db at all (e.g. CLI commands using
-            # in-memory objects for ad-hoc operations), just set the value manually.
-            if not info.has_identity:
-                dag_run.consumed_asset_events = []
-                return dag_run
-            return session.merge(dag_run, load=False)
-
-        task: Any = self.task
-        dag = task.dag
-        dag_run = _get_dagrun(session)
-
-        validated_params = process_params(dag, task, dag_run.conf, suppress_exception=ignore_param_exceptions)
-        runtime_ti = RuntimeTaskInstance.model_construct(
-            id=self.id,
-            task_id=self.task_id,
-            dag_id=self.dag_id,
-            run_id=self.run_id,
-            try_numer=self.try_number,
-            map_index=self.map_index,
-            task=self.task,
-            max_tries=self.max_tries,
-            hostname=self.hostname,
-            _ti_context_from_server=TIRunContext(
-                dag_run=DagRunSDK.model_validate(dag_run, from_attributes=True),
-                max_tries=self.max_tries,
-                should_retry=self.is_eligible_to_retry(),
-            ),
-            start_date=self.start_date,
-            dag_version_id=self.dag_version_id,
-        )
-
-        context: Context = runtime_ti.get_template_context()
-
-        @cache  # Prevent multiple database access.
-        def _get_previous_dagrun_success() -> PrevSuccessfulDagRunResponse:
-            dr_from_db = self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
-            if dr_from_db:
-                return PrevSuccessfulDagRunResponse.model_validate(dr_from_db, from_attributes=True)
-            return PrevSuccessfulDagRunResponse()
-
-        def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
-            return timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_start)
-
-        def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
-            return timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_end)
-
-        def get_prev_start_date_success() -> pendulum.DateTime | None:
-            return timezone.coerce_datetime(_get_previous_dagrun_success().start_date)
-
-        def get_prev_end_date_success() -> pendulum.DateTime | None:
-            return timezone.coerce_datetime(_get_previous_dagrun_success().end_date)
-
-        def get_triggering_events() -> dict[str, list[AssetEvent]]:
-            asset_events = dag_run.consumed_asset_events
-            triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
-            for event in asset_events:
-                if event.asset:
-                    triggering_events[event.asset.uri].append(event)
-
-            return triggering_events
-
-        # NOTE: If you add to this dict, make sure to also update the following:
-        # * Context in task-sdk/src/airflow/sdk/definitions/context.py
-        # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
-        # * Table in docs/apache-airflow/templates-ref.rst
-
-        context.update(
-            {
-                "outlet_events": OutletEventAccessors(),
-                "inlet_events": InletEventsAccessors(task.inlets),
-                "params": validated_params,
-                "prev_data_interval_start_success": 
get_prev_data_interval_start_success(),
-                "prev_data_interval_end_success": 
get_prev_data_interval_end_success(),
-                "prev_start_date_success": get_prev_start_date_success(),
-                "prev_end_date_success": get_prev_end_date_success(),
-                "test_mode": self.test_mode,
-                # ti/task_instance are added here for ti.xcom_{push,pull}
-                "task_instance": self,
-                "ti": self,
-                "triggering_asset_events": 
lazy_object_proxy.Proxy(get_triggering_events),
-                "var": {
-                    "json": VariableAccessor(deserialize_json=True),
-                    "value": VariableAccessor(deserialize_json=False),
-                },
-                "conn": ConnectionAccessor(),
-            }
-        )
-
-        try:
-            expanded_ti_count: int | None = get_mapped_ti_count(task, self.run_id, session=session)
-            context["expanded_ti_count"] = expanded_ti_count
-            if expanded_ti_count:
-                setattr(
-                    self,
-                    "_upstream_map_indexes",
-                    {
-                        upstream.task_id: self.get_relevant_upstream_map_indexes(
-                            upstream,
-                            expanded_ti_count,
-                            session=session,
-                        )
-                        for upstream in task.upstream_list
-                    },
-                )
-        except NotMapped:
-            pass
-
-        return context
-
     def set_duration(self) -> None:
         """Set task instance duration."""
         if self.end_date and self.start_date:
diff --git a/airflow-core/src/airflow/serialization/enums.py b/airflow-core/src/airflow/serialization/enums.py
index 484b6ed5ad9..5fdd4d66987 100644
--- a/airflow-core/src/airflow/serialization/enums.py
+++ b/airflow-core/src/airflow/serialization/enums.py
@@ -63,7 +63,6 @@ class DagAttributeTypes(str, Enum):
     ASSET_UNIQUE_KEY = "asset_unique_key"
     ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
     CONNECTION = "connection"
-    TASK_CONTEXT = "task_context"
     ARG_NOT_SET = "arg_not_set"
     TASK_CALLBACK_REQUEST = "task_callback_request"
     DAG_CALLBACK_REQUEST = "dag_callback_request"
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 53b6377b96e..f3e5c74cfe4 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -40,7 +40,6 @@ import pydantic
 from dateutil import relativedelta
 from pendulum.tz.timezone import FixedTimezone, Timezone
 
-from airflow import macros
 from airflow._shared.module_loading import import_string, qualname
 from airflow._shared.timezones.timezone import from_timestamp, parse_timezone, utcnow
 from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
@@ -100,7 +99,6 @@ from airflow.task.priority_strategy import (
 from airflow.timetables.base import DagRunInfo, Timetable
 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
 from airflow.utils.code_utils import get_python_source
-from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
 from airflow.utils.db import LazySelectSequence
 
 if TYPE_CHECKING:
@@ -654,12 +652,6 @@ class BaseSerialization:
         elif isinstance(var, MappedArgument):
             data = {"input": encode_expand_input(var._input), "key": var._key}
             return cls._encode(data, type_=DAT.MAPPED_ARGUMENT)
-        elif var.__class__ == Context:
-            d = {}
-            for k, v in var.items():
-                obj = cls.serialize(v, strict=strict)
-                d[str(k)] = obj
-            return cls._encode(d, type_=DAT.TASK_CONTEXT)
         else:
             return cls.default_serialization(strict, var)
 
@@ -686,21 +678,7 @@ class BaseSerialization:
             raise ValueError(f"The encoded_var should be dict and is 
{type(encoded_var)}")
         var = encoded_var[Encoding.VAR]
         type_ = encoded_var[Encoding.TYPE]
-        if type_ == DAT.TASK_CONTEXT:
-            d = {}
-            for k, v in var.items():
-                if k == "task":  # todo: add `_encode` of Operator so we don't 
need this
-                    continue
-                d[k] = cls.deserialize(v)
-            d["task"] = d["task_instance"].task  # todo: add `_encode` of 
Operator so we don't need this
-            d["macros"] = macros
-            d["var"] = {
-                "json": VariableAccessor(deserialize_json=True),
-                "value": VariableAccessor(deserialize_json=False),
-            }
-            d["conn"] = ConnectionAccessor()
-            return Context(**d)
-        elif type_ == DAT.DICT:
+        if type_ == DAT.DICT:
             return {k: cls.deserialize(v) for k, v in var.items()}
         elif type_ == DAT.ASSET_EVENT_ACCESSORS:
             return decode_outlet_event_accessors(var)
diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py
index 02e324b083b..4783becb562 100644
--- a/airflow-core/src/airflow/utils/context.py
+++ b/airflow-core/src/airflow/utils/context.py
@@ -32,51 +32,16 @@ from airflow.sdk.execution_time.context import (
     VariableAccessor as VariableAccessorSDK,
 )
 from airflow.serialization.definitions.notset import NOTSET, is_arg_set
-from airflow.utils.deprecation_tools import DeprecatedImportWarning
+from airflow.utils.deprecation_tools import DeprecatedImportWarning, add_deprecated_classes
 from airflow.utils.session import create_session
 
-# NOTE: Please keep this in sync with the following:
-# * Context in task-sdk/src/airflow/sdk/definitions/context.py
-# * Table in docs/apache-airflow/templates-ref.rst
-KNOWN_CONTEXT_KEYS: set[str] = {
-    "conn",
-    "dag",
-    "dag_run",
-    "data_interval_end",
-    "data_interval_start",
-    "ds",
-    "ds_nodash",
-    "expanded_ti_count",
-    "exception",
-    "inlets",
-    "inlet_events",
-    "logical_date",
-    "macros",
-    "map_index_template",
-    "outlets",
-    "outlet_events",
-    "params",
-    "prev_data_interval_start_success",
-    "prev_data_interval_end_success",
-    "prev_start_date_success",
-    "prev_end_date_success",
-    "reason",
-    "run_id",
-    "start_date",
-    "task",
-    "task_reschedule_count",
-    "task_instance",
-    "task_instance_key_str",
-    "test_mode",
-    "templates_dict",
-    "ti",
-    "triggering_asset_events",
-    "ts",
-    "ts_nodash",
-    "ts_nodash_with_tz",
-    "try_number",
-    "var",
-}
+warnings.warn(
+    "Module airflow.utils.context is deprecated and will be removed in the "
+    "future. Use airflow.sdk.execution_time.context if you are using the "
+    "classes inside an Airflow task.",
+    DeprecatedImportWarning,
+    stacklevel=2,
+)
 
 
 class VariableAccessor(VariableAccessorSDK):
@@ -140,17 +105,14 @@ class OutletEventAccessors(OutletEventAccessorsSDK):
         return Asset(name=asset.name, uri=asset.uri, group=asset.group, extra=asset.extra)
 
 
-def __getattr__(name: str):
-    if name in ("Context", "context_copy_partial", "context_merge"):
-        warnings.warn(
-            "Importing Context from airflow.utils.context is deprecated and 
will "
-            "be removed in the future. Please import it from airflow.sdk 
instead.",
-            DeprecatedImportWarning,
-            stacklevel=2,
-        )
-
-        import airflow.sdk.definitions.context as sdk
-
-        return getattr(sdk, name)
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+add_deprecated_classes(
+    {
+        __name__: {
+            "KNOWN_CONTEXT_KEYS": "airflow.sdk.definitions.context",
+            "Context": "airflow.sdk.definitions.context",
+            "context_copy_partial": "airflow.sdk.definitions.context",
+            "context_merge": "airflow.sdk.definitions.context",
+        },
+    },
+    package=__name__,
+)
diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py
index 5b977ff75d1..a48780022a6 100644
--- a/airflow-core/src/airflow/utils/helpers.py
+++ b/airflow-core/src/airflow/utils/helpers.py
@@ -23,7 +23,7 @@ import re
 import signal
 from collections.abc import Callable, Generator, Iterable, MutableMapping
 from functools import cache
-from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, TypeVar, overload
 from urllib.parse import urljoin
 
 from lazy_object_proxy import Proxy
@@ -39,7 +39,6 @@ if TYPE_CHECKING:
     import jinja2
 
     from airflow.models.taskinstance import TaskInstance
-    from airflow.sdk.definitions.context import Context
 
     CT = TypeVar("CT", str, datetime)
 
@@ -160,39 +159,6 @@ def log_filename_template_renderer() -> Callable[..., str]:
     return f_str_format
 
 
-def _render_template_to_string(template: jinja2.Template, context: Context) -> str:
-    """
-    Render a Jinja template to string using the provided context.
-
-    This is a private utility function specifically for log filename rendering.
-    It ensures templates are rendered as strings rather than native Python objects.
-    """
-    return render_template(template, cast("MutableMapping[str, Any]", context), native=False)
-
-
-def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
-    """
-    Given task instance, try_number, filename_template, return the rendered log filename.
-
-    :param ti: task instance
-    :param try_number: try_number of the task
-    :param filename_template: filename template, which can be jinja template or
-        python string template
-    """
-    filename_template, filename_jinja_template = parse_template_string(filename_template)
-    if filename_jinja_template:
-        jinja_context = ti.get_template_context()
-        jinja_context["try_number"] = try_number
-        return _render_template_to_string(filename_jinja_template, jinja_context)
-
-    return filename_template.format(
-        dag_id=ti.dag_id,
-        task_id=ti.task_id,
-        logical_date=ti.logical_date.isoformat(),
-        try_number=try_number,
-    )
-
-
 def convert_camel_to_snake(camel_str: str) -> str:
     """Convert CamelCase to snake_case."""
     return CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str).lower()
diff --git a/airflow-core/src/airflow/utils/log/log_reader.py b/airflow-core/src/airflow/utils/log/log_reader.py
index 9ebc3a9050c..99576559e96 100644
--- a/airflow-core/src/airflow/utils/log/log_reader.py
+++ b/airflow-core/src/airflow/utils/log/log_reader.py
@@ -25,17 +25,13 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.configuration import conf
-from airflow.utils.helpers import render_log_filename
 from airflow.utils.log.file_task_handler import FileTaskHandler, StructuredLogMessage
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from typing import TypeAlias
 
-    from sqlalchemy.orm.session import Session
-
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancehistory import TaskInstanceHistory
     from airflow.utils.log.file_task_handler import LogHandlerOutputStream, LogMetadata
@@ -190,25 +186,3 @@ class TaskLogReader:
             return False
 
         return self.log_handler.supports_external_link
-
-    @provide_session
-    def render_log_filename(
-        self,
-        ti: TaskInstance | TaskInstanceHistory,
-        try_number: int | None = None,
-        *,
-        session: Session = NEW_SESSION,
-    ) -> str:
-        """
-        Render the log attachment filename.
-
-        :param ti: The task instance
-        :param try_number: The task try number
-        """
-        dagrun = ti.get_dagrun(session=session)
-        attachment_filename = render_log_filename(
-            ti=ti,
-            try_number="all" if try_number is None else try_number,
-            filename_template=dagrun.get_log_template(session=session).filename,
-        )
-        return attachment_filename
diff --git a/airflow-core/tests/unit/core/test_core.py b/airflow-core/tests/unit/core/test_core.py
index aa341560553..66b70ddcfa2 100644
--- a/airflow-core/tests/unit/core/test_core.py
+++ b/airflow-core/tests/unit/core/test_core.py
@@ -27,6 +27,7 @@ from airflow.sdk import BaseOperator
 from airflow.utils.types import DagRunType
 
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs
+from tests_common.test_utils.taskinstance import get_template_context
 
 pytestmark = pytest.mark.db_test
 
@@ -87,14 +88,13 @@ class TestCore:
                 params={"key_2": "value_2_new", "key_3": "value_3"},
             )
             task2 = EmptyOperator(task_id="task2")
-        dr = dag_maker.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-        )
+        dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
         ti1 = dag_maker.run_ti(task1.task_id, dr)
         ti2 = dag_maker.run_ti(task2.task_id, dr)
+        ti1.dag_run = ti2.dag_run = dr
 
-        context1 = ti1.get_template_context()
-        context2 = ti2.get_template_context()
+        context1 = get_template_context(ti1, task1)
+        context2 = get_template_context(ti2, task2)
 
         assert context1["params"] == {"key_1": "value_1", "key_2": 
"value_2_new", "key_3": "value_3"}
         assert context2["params"] == {"key_1": "value_1", "key_2": 
"value_2_old"}
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 5de8ca8db3c..e607efd1a44 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -71,9 +71,8 @@ from airflow.providers.standard.operators.hitl import (
 )
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.providers.standard.sensors.python import PythonSensor
-from airflow.sdk import DAG, BaseOperator, BaseSensorOperator, Metadata, task, task_group
+from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, BaseSensorOperator, Metadata, task, task_group
 from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse
-from airflow.sdk.definitions.asset import Asset, AssetAlias
 from airflow.sdk.definitions.param import process_params
 from airflow.sdk.definitions.taskgroup import TaskGroup
 from airflow.sdk.execution_time.comms import AssetEventsResult
@@ -444,6 +443,7 @@ class TestTaskInstance:
             start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
             dag=dag_maker.dag,
         )
+        dag_maker.sync_dag_to_db()
         ti2 = _create_task_instance(task=task2, run_id=ti.run_id, dag_version_id=ti.dag_version_id)
         session.add(ti2)
         session.flush()
@@ -552,18 +552,19 @@ class TestTaskInstance:
                 retry_delay=datetime.timedelta(seconds=0),
             )
 
-        def run_with_error(ti):
+        def run_with_error():
             with contextlib.suppress(AirflowException):
-                dag_maker.run_ti(ti.task_id, ti.dag_run)
-            ti.refresh_from_db(session)
+                dag_maker.run_ti(ti.task_id, dag_run)
+            return session.get(TaskInstance, ti.id)
 
-        ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
+        dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow())
+        ti = dag_run.task_instances[0]
         assert ti.try_number == 0
         session.get(TaskInstance, ti.id).try_number += 1
         session.commit()
 
         # first run -- up for retry
-        run_with_error(ti)
+        ti = run_with_error()
         assert ti.state == State.UP_FOR_RETRY
         assert ti.try_number == 1
 
@@ -571,7 +572,7 @@ class TestTaskInstance:
         session.commit()
 
         # second run -- fail
-        run_with_error(ti)
+        ti = run_with_error()
         assert ti.state == State.FAILED
         assert ti.try_number == 2
 
@@ -579,14 +580,13 @@ class TestTaskInstance:
         # clearing it first
         dag.clear()
 
-        ti.refresh_from_db(session)
+        ti.refresh_from_db()
         ti.try_number += 1
         session.add(ti)
         session.commit()
 
         # third run -- up for retry
-        run_with_error(ti)
-        ti.refresh_from_db()
+        ti = run_with_error()
         assert ti.state == State.UP_FOR_RETRY
         assert ti.try_number == 3
 
@@ -594,8 +594,7 @@ class TestTaskInstance:
         session.commit()
 
         # fourth run -- fail
-        run_with_error(ti)
-        ti.refresh_from_db()
+        ti = run_with_error()
         assert ti.state == State.FAILED
         assert ti.try_number == 4
         assert RenderedTaskInstanceFields.get_templated_fields(ti) == expected_rendered_ti_fields
@@ -2094,9 +2093,7 @@ class TestTaskInstance:
         Test that when a task that produces asset has ran, that changing the consumer
         dag asset will not cause primary key blank-out
         """
-        from airflow.sdk.definitions.asset import Asset
-
-        with dag_maker(schedule=None, serialized=True) as dag1:
+        with dag_maker(schedule=None, serialized=False) as dag1:
 
             @task(outlets=Asset("test/1"))
             def test_task1():
@@ -2107,7 +2104,7 @@ class TestTaskInstance:
         dr1 = dag_maker.create_dagrun()
         test_task1 = dag1.get_task("test_task1")
 
-        with dag_maker(dag_id="testdag", schedule=[Asset("test/1")], 
serialized=True):
+        with dag_maker(dag_id="testdag", schedule=[Asset("test/1")]):
 
             @task
             def test_task2():
@@ -2119,7 +2116,7 @@ class TestTaskInstance:
         run_task_instance(ti, dag1.get_task(ti.task_id))
 
         # Change the asset.
-        with dag_maker(dag_id="testdag", schedule=[Asset("test2/1")], 
serialized=True):
+        with dag_maker(dag_id="testdag", schedule=[Asset("test2/1")]):
 
             @task
             def test_task2():
@@ -2644,7 +2641,7 @@ class TestTaskInstanceRecordTaskMapXComPush:
     @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2}, 
"abc"])
     def test_not_recorded_if_leaf(self, dag_maker, xcom_value):
         """Return value should not be recorded if there are no downstreams."""
-        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+        with dag_maker(dag_id="test_not_recorded_for_unused", 
serialized=False) as dag:
 
             @dag.task()
             def push_something():
@@ -2660,7 +2657,7 @@ class TestTaskInstanceRecordTaskMapXComPush:
     @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2}, 
"abc"])
     def test_not_recorded_if_not_used(self, dag_maker, xcom_value):
         """Return value should not be recorded if no downstreams are mapped."""
-        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+        with dag_maker(dag_id="test_not_recorded_for_unused", 
serialized=False) as dag:
 
             @dag.task()
             def push_something():
diff --git a/airflow-core/tests/unit/utils/log/test_log_reader.py b/airflow-core/tests/unit/utils/log/test_log_reader.py
index d1279f71323..addcb5ebdee 100644
--- a/airflow-core/tests/unit/utils/log/test_log_reader.py
+++ b/airflow-core/tests/unit/utils/log/test_log_reader.py
@@ -17,23 +17,18 @@
 from __future__ import annotations
 
 import copy
-import datetime
 import os
 import sys
 import tempfile
 import types
-from typing import TYPE_CHECKING
 from unittest import mock
 
-import pendulum
 import pytest
 
 from airflow import settings
 from airflow._shared.timezones import timezone
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
 from airflow.models.tasklog import LogTemplate
-from airflow.providers.standard.operators.python import PythonOperator
-from airflow.timetables.base import DataInterval
 from airflow.utils.log.log_reader import TaskLogReader
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin
 from airflow.utils.state import TaskInstanceState
@@ -46,10 +41,6 @@ from tests_common.test_utils.file_task_handler import convert_list_to_stream
 pytestmark = pytest.mark.db_test
 
 
-if TYPE_CHECKING:
-    from airflow.models import DagRun
-
-
 class TestLogView:
     DAG_ID = "dag_log_reader"
     TASK_ID = "task_log_reader"
@@ -292,54 +283,6 @@ class TestLogView:
         mock_prop.return_value = True
         assert task_log_reader.supports_external_link
 
-    def test_task_log_filename_unique(self, dag_maker):
-        """
-        Ensure the default log_filename_template produces a unique filename.
-
-        See discussion in apache/airflow#19058 [1]_ for how uniqueness may
-        change in a future Airflow release. For now, the logical date is used
-        to distinguish DAG runs. This test should be modified when the logical
-        date is no longer used to ensure uniqueness.
-
-        [1]: https://github.com/apache/airflow/issues/19058
-        """
-        dag_id = "test_task_log_filename_ts_corresponds_to_logical_date"
-        task_id = "echo_run_type"
-
-        def echo_run_type(dag_run: DagRun, **kwargs):
-            print(dag_run.run_type)
-
-        with dag_maker(dag_id, start_date=self.DEFAULT_DATE, schedule="@daily"):
-            PythonOperator(task_id=task_id, python_callable=echo_run_type)
-
-        start = pendulum.datetime(2021, 1, 1)
-        end = start + datetime.timedelta(days=1)
-        trigger_time = end + datetime.timedelta(hours=4, minutes=29)  # Arbitrary.
-
-        # Create two DAG runs that have the same data interval, but not the same
-        # logical date, to check if they correctly use different log files.
-        scheduled_dagrun: DagRun = dag_maker.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            logical_date=start,
-            data_interval=DataInterval(start, end),
-        )
-        manual_dagrun: DagRun = dag_maker.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            logical_date=trigger_time,
-            data_interval=DataInterval(start, end),
-        )
-
-        scheduled_ti = scheduled_dagrun.get_task_instance(task_id)
-        manual_ti = manual_dagrun.get_task_instance(task_id)
-        assert scheduled_ti is not None
-        assert manual_ti is not None
-
-        scheduled_ti.refresh_from_task(dag_maker.serialized_dag.get_task(task_id))
-        manual_ti.refresh_from_task(dag_maker.serialized_dag.get_task(task_id))
-
-        reader = TaskLogReader()
-        assert reader.render_log_filename(scheduled_ti, 1) != reader.render_log_filename(manual_ti, 1)
-
     @pytest.mark.parametrize(
         ("state", "try_number", "expected_event", "use_self_ti"),
         [
diff --git a/airflow-core/tests/unit/utils/test_helpers.py b/airflow-core/tests/unit/utils/test_helpers.py
index 6179acadfb9..8e16d118698 100644
--- a/airflow-core/tests/unit/utils/test_helpers.py
+++ b/airflow-core/tests/unit/utils/test_helpers.py
@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING
 
 import pytest
 
-from airflow._shared.timezones import timezone
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.serialization.definitions.notset import NOTSET
@@ -55,24 +54,6 @@ def clear_db():
 
 
 class TestHelpers:
-    @pytest.mark.db_test
-    @pytest.mark.usefixtures("clear_db")
-    def test_render_log_filename(self, create_task_instance):
-        try_number = 1
-        dag_id = "test_render_log_filename_dag"
-        task_id = "test_render_log_filename_task"
-        logical_date = timezone.datetime(2016, 1, 1)
-
-        ti = create_task_instance(dag_id=dag_id, task_id=task_id, logical_date=logical_date)
-        filename_template = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log"
-
-        ts = ti.get_template_context()["ts"]
-        expected_filename = f"{dag_id}/{task_id}/{ts}/{try_number}.log"
-
-        rendered_filename = helpers.render_log_filename(ti, try_number, filename_template)
-
-        assert rendered_filename == expected_filename
-
     def test_chunks(self):
         with pytest.raises(ValueError, match=CHUNK_SIZE_POSITIVE_INT):
             list(helpers.chunks([1, 2, 3], 0))
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 3dc2cec736e..b7e79451a85 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -807,6 +807,8 @@ class DagMaker(Generic[Dag], Protocol):
 
     def get_serialized_data(self) -> dict[str, Any]: ...
 
+    def sync_dag_to_db(self) -> None: ...
+
     def create_dagrun(
         self,
         run_id: str = ...,
@@ -818,6 +820,15 @@ class DagMaker(Generic[Dag], Protocol):
 
     def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun: ...
 
+    def create_ti(
+        self,
+        task_id: str,
+        dag_run: DagRun | None = ...,
+        dag_run_kwargs: dict | None = ...,
+        map_index: int = ...,
+        **kwargs,
+    ) -> TaskInstance: ...
+
     def run_ti(
         self,
         task_id: str,
@@ -1047,6 +1058,9 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
             self._bag_dag_compat(dag)
             self.session.flush()
 
+        def sync_dag_to_db(self):
+            self._make_serdag(self.dag)
+
         def create_dagrun(self, *, logical_date=NOTSET, **kwargs):
             from airflow.utils.state import DagRunState
             from airflow.utils.types import DagRunType
@@ -1155,15 +1169,15 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                 **kwargs,
             )
 
-        def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None, map_index=-1, **kwargs):
+        def create_ti(self, task_id, dag_run=None, dag_run_kwargs=None, map_index=-1):
             """
-            Create a dagrun and run a specific task instance with proper task refresh.
+            Create a specific task instance with proper task refresh.
+
+            This is a convenience method for creating a single task instance:
 
-            This is a convenience method for running a single task instance:
             1. Create a dagrun if it does not exist
             2. Get the specific task instance by task_id
             3. Refresh the task instance from the DAG task
-            4. Run the task instance
 
             Returns the created TaskInstance.
             """
@@ -1178,7 +1192,23 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                     f"Task instance with task_id '{task_id}' not found in dag 
run. "
                     f"Available task_ids: {available_task_ids}"
                 )
+            if AIRFLOW_V_3_2_PLUS:
+                ti.refresh_from_task(self.serialized_dag.get_task(task_id))
+            else:
+                ti.refresh_from_task(self.dag.get_task(task_id))
+            return ti
+
+        def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None, map_index=-1, **kwargs):
+            """
+            Run a specific task instance with proper task refresh.
 
+            This is a convenience method for running a single task instance:
+
+            1. Call ``create_ti()`` to obtain the task instance
+            2. Run the task instance
+
+            Returns the created and run TaskInstance.
+            """
             task = self.dag.get_task(task_id)
             if not AIRFLOW_V_3_1_PLUS:
                # Airflow <3.1 has a bug for DecoratedOperator has an unused signature for
@@ -1192,13 +1222,13 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                #                                                                                     ^^^^^^^^^^^^^^
                # E   AttributeError: '_PythonDecoratedOperator' object has no attribute 'xcom_push'
                 task.xcom_push = lambda *args, **kwargs: None
+
+            ti = self.create_ti(task_id, dag_run=dag_run, dag_run_kwargs=dag_run_kwargs, map_index=map_index)
             if AIRFLOW_V_3_2_PLUS:
                 from tests_common.test_utils.taskinstance import run_task_instance
 
-                ti.refresh_from_task(self.serialized_dag.get_task(task_id))
                 run_task_instance(ti, task, **kwargs)
             else:
-                ti.refresh_from_task(task)
                 ti.run(**kwargs)
             return ti
 
diff --git a/devel-common/src/tests_common/test_utils/taskinstance.py b/devel-common/src/tests_common/test_utils/taskinstance.py
index 7f79819aee3..b82ad5a7fc7 100644
--- a/devel-common/src/tests_common/test_utils/taskinstance.py
+++ b/devel-common/src/tests_common/test_utils/taskinstance.py
@@ -22,6 +22,7 @@ import copy
 from typing import TYPE_CHECKING
 
 from airflow.models.taskinstance import TaskInstance
+from airflow.utils.session import NEW_SESSION
 
 from tests_common.test_utils.compat import SerializedBaseOperator, SerializedMappedOperator
 from tests_common.test_utils.dag import create_scheduler_dag
@@ -36,9 +37,10 @@ if TYPE_CHECKING:
     from uuid import UUID
 
     from jinja2 import Environment
+    from sqlalchemy.orm import Session
 
     from airflow.sdk import Context
-    from airflow.sdk.types import Operator as SdkOperator
+    from airflow.sdk.types import Operator as SdkOperator, RuntimeTaskInstanceProtocol
     from airflow.serialization.definitions.mappedoperator import Operator as SerializedOperator
 
 __all__ = ["TaskInstanceWrapper", "create_task_instance", 
"render_template_fields", "run_task_instance"]
@@ -62,16 +64,15 @@ class TaskInstanceWrapper:
     def __copy__(self):
         return TaskInstanceWrapper(copy.copy(self.__dict__["__ti"]), copy.copy(self.__dict__["__task"]))
 
-    def run(self, **kwargs) -> None:
-        from tests_common.test_utils.taskinstance import run_task_instance
-
-        run_task_instance(self.__dict__["__ti"], self.__dict__["__task"], **kwargs)
+    def run(self, **kwargs) -> RuntimeTaskInstanceProtocol:
+        return run_task_instance(self.__dict__["__ti"], self.__dict__["__task"], **kwargs)
 
     def render_templates(self, **kwargs) -> SdkOperator:
-        from tests_common.test_utils.taskinstance import render_template_fields
-
         return render_template_fields(self.__dict__["__ti"], self.__dict__["__task"], **kwargs)
 
+    def get_template_context(self) -> Context:
+        return get_template_context(self.__dict__["__ti"], self.__dict__["__task"])
+
 
 def create_task_instance(
     task: SdkOperator | SerializedOperator,
@@ -113,42 +114,75 @@ def run_task_instance(
     ignore_ti_state: bool = False,
     mark_success: bool = False,
     session=None,
-):
+) -> RuntimeTaskInstanceProtocol:
+    session_kwargs = {"session": session} if session else {}
     if not AIRFLOW_V_3_2_PLUS:
         ti.refresh_from_task(task)  # type: ignore[arg-type]
-        ti.run()
+        ti.run(**session_kwargs)
         return ti
 
-    kwargs = {"session": session} if session else {}
     if not ti.check_and_change_state_before_execution(
         ignore_depends_on_past=ignore_depends_on_past,
         ignore_task_deps=ignore_task_deps,
         ignore_ti_state=ignore_ti_state,
         mark_success=mark_success,
-        **kwargs,
+        **session_kwargs,
     ):
         return ti
 
     from airflow.sdk.definitions.dag import _run_task
 
-    taskrun_result = _run_task(ti=ti, task=task)
+    # Session handling is a mess in tests; use a fresh ti to run the task.
+    new_ti = TaskInstance.get_task_instance(
+        dag_id=ti.dag_id,
+        run_id=ti.run_id,
+        task_id=ti.task_id,
+        map_index=ti.map_index,
+        **session_kwargs,
+    )
+    # Some tests don't even save the ti at all, in which case new_ti is None.
+    taskrun_result = _run_task(ti=new_ti or ti, task=task)
+    ti.refresh_from_db(**session_kwargs)  # Some tests expect side effects.
     if not taskrun_result:
-        return None
+        raise RuntimeError("task failed to finish with a result")
     if error := taskrun_result.error:
         raise error
     return taskrun_result.ti
 
 
+def get_template_context(ti: TaskInstance, task: SdkOperator, *, session: Session = NEW_SESSION) -> Context:
+    if not AIRFLOW_V_3_2_PLUS:
+        ti.refresh_from_task(task)  # type: ignore[arg-type]
+        return ti.get_template_context(session=session)
+
+    from airflow.cli.commands.task_command import _get_template_context
+    from airflow.utils.context import ConnectionAccessor, VariableAccessor
+
+    # TODO: Move these to test_utils too.
+    context = _get_template_context(ti, task)
+    context["ti"].__dict__.update(xcom_push=ti.xcom_push, 
xcom_pull=ti.xcom_pull)  # Avoid execution API.
+    context.update(  # type: ignore[call-arg]  # 
https://github.com/python/mypy/issues/17750
+        conn=ConnectionAccessor(),
+        test_mode=ti.test_mode,
+        var={
+            "json": VariableAccessor(deserialize_json=True),
+            "value": VariableAccessor(deserialize_json=False),
+        },
+    )
+    return context
+
+
 def render_template_fields(
     ti: TaskInstance,
     task: SdkOperator,
     *,
     context: Context | None = None,
     jinja_env: Environment | None = None,
+    session: Session = NEW_SESSION,
 ) -> SdkOperator:
     if AIRFLOW_V_3_2_PLUS:
-        task.render_template_fields(context or ti.get_template_context(), jinja_env)
+        task.render_template_fields(context or get_template_context(ti, task), jinja_env)
         return task
     ti.refresh_from_task(task)  # type: ignore[arg-type]
-    ti.render_templates(context, jinja_env)
+    ti.render_templates(context or ti.get_template_context(session=session), jinja_env)
     return ti.task  # type: ignore[return-value]
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
index f65113ad2a6..67b23e0f0aa 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -40,14 +40,8 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
 )
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-from airflow.providers.common.compat.sdk import AirflowException, Stats, conf
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.compat.sdk import AirflowException, Stats, conf, timezone
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 8657bb07872..0b018a2ea0e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -35,12 +35,7 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
 )
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.providers.common.compat.sdk import AirflowException, Stats, conf
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from airflow.providers.common.compat.sdk import AirflowException, Stats, conf, timezone
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index efd54def2ff..0887dcd47a8 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -48,12 +48,7 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
 )
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.providers.common.compat.sdk import AirflowException, Stats
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
 from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
index 3929694ae37..f071ff844c9 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -35,12 +35,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.tags import format_tags
-from airflow.providers.common.compat.sdk import AirflowException
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from airflow.providers.common.compat.sdk import AirflowException, timezone
 
 
 class LogState:
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
index 4bce677a3c1..d1265b33a0b 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
@@ -31,14 +31,10 @@ from airflow.providers.amazon.aws.executors.aws_lambda import lambda_executor
 from airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor import AwsLambdaExecutor
 from airflow.providers.amazon.aws.executors.aws_lambda.utils import CONFIG_GROUP_NAME, AllLambdaConfigKeys
 from airflow.providers.common.compat.sdk import AirflowException
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 from airflow.utils.state import TaskInstanceState
 from airflow.version import version as airflow_version_str
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
index b8ba2d927cd..8f31e9caf30 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
@@ -25,10 +25,7 @@ import pytest
 
 from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 FLOW_NAME = "flow0"
 EXECUTION_ID = "ex_id"
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
index 9e21004385d..86f765808d5 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
@@ -32,10 +32,7 @@ from airflow.providers.amazon.aws.operators.appflow import (
     AppflowRunOperator,
 )
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 CONN_ID = "aws_default"
 DAG_ID = "dag_id"
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
index 67d945a19ea..4b0b91f0f3f 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
@@ -41,16 +41,12 @@ from airflow.providers.openlineage.extractors import OperatorLineage
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance
+from tests_common.test_utils.taskinstance import create_task_instance, get_template_context
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 TEST_DAG_ID = "unit_tests"
 DEFAULT_DATE = timezone.datetime(2018, 1, 1)
 ATHENA_QUERY_ID = "eac29bf8-daa1-4ffc-b19a-0db31dc3b784"
@@ -272,7 +268,7 @@ class TestAthenaOperator:
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID
+        assert self.athena.execute(get_template_context(ti, self.athena)) == ATHENA_QUERY_ID
 
     @mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",))
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
index 5643eb9c48b..f718220df7e 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
@@ -28,11 +28,7 @@ from airflow.providers.amazon.aws.operators.cloud_formation import (
     CloudFormationDeleteStackOperator,
 )
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
 DEFAULT_DATE = timezone.datetime(2019, 1, 1)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
index c47f948482e..2addb0ce539 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
@@ -30,16 +30,12 @@ from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance
+from tests_common.test_utils.taskinstance import create_task_instance, get_template_context
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 TEST_DAG_ID = "unit_tests"
 DEFAULT_DATE = timezone.datetime(2018, 1, 1)
 
@@ -373,7 +369,7 @@ class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
                 run_type=DagRunType.MANUAL,
                 state=DagRunState.RUNNING,
             )
-            ti = create_task_instance(task=self.datasync, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=self.datasync, run_id="test", dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=self.dag.dag_id,
@@ -386,7 +382,7 @@ class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        assert self.datasync.execute(ti.get_template_context()) is not None
+        assert self.datasync.execute(get_template_context(ti, self.datasync)) is not None
         # ### Check mocks:
         mock_get_conn.assert_called()
 
@@ -587,7 +583,7 @@ class TestDataSyncOperatorGetTasks(DataSyncTestCaseBase):
 
             sync_dag_to_db(self.dag)
             dag_version = DagVersion.get_latest_version(self.dag.dag_id)
-            ti = create_task_instance(task=self.datasync, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=self.datasync, run_id="test", dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.dag.dag_id,
                 logical_date=timezone.utcnow(),
@@ -607,7 +603,7 @@ class TestDataSyncOperatorGetTasks(DataSyncTestCaseBase):
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        result = self.datasync.execute(ti.get_template_context())
+        result = self.datasync.execute(get_template_context(ti, self.datasync))
         assert result["TaskArn"] == self.task_arn
         # ### Check mocks:
         mock_get_conn.assert_called()
@@ -710,7 +706,7 @@ class TestDataSyncOperatorUpdate(DataSyncTestCaseBase):
 
             sync_dag_to_db(self.dag)
             dag_version = DagVersion.get_latest_version(self.dag.dag_id)
-            ti = create_task_instance(task=self.datasync, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=self.datasync, run_id="test", dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.dag.dag_id,
                 logical_date=timezone.utcnow(),
@@ -730,7 +726,7 @@ class TestDataSyncOperatorUpdate(DataSyncTestCaseBase):
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        result = self.datasync.execute(ti.get_template_context())
+        result = self.datasync.execute(get_template_context(ti, self.datasync))
         assert result["TaskArn"] == self.task_arn
         # ### Check mocks:
         mock_get_conn.assert_called()
@@ -926,7 +922,7 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
 
             sync_dag_to_db(self.dag)
             dag_version = DagVersion.get_latest_version(self.dag.dag_id)
-            ti = create_task_instance(task=self.datasync, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=self.datasync, run_id="test", dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.dag.dag_id,
                 logical_date=timezone.utcnow(),
@@ -946,7 +942,7 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        assert self.datasync.execute(ti.get_template_context()) is not None
+        assert self.datasync.execute(get_template_context(ti, self.datasync)) is not None
         # ### Check mocks:
         mock_get_conn.assert_called()
 
@@ -1045,7 +1041,7 @@ class TestDataSyncOperatorDelete(DataSyncTestCaseBase):
 
             sync_dag_to_db(self.dag)
             dag_version = DagVersion.get_latest_version(self.dag.dag_id)
-            ti = create_task_instance(task=self.datasync, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=self.datasync, run_id="test", dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.dag.dag_id,
                 logical_date=timezone.utcnow(),
@@ -1065,7 +1061,7 @@ class TestDataSyncOperatorDelete(DataSyncTestCaseBase):
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        result = self.datasync.execute(ti.get_template_context())
+        result = self.datasync.execute(get_template_context(ti, self.datasync))
         assert result["TaskArn"] == self.task_arn
         # ### Check mocks:
         mock_get_conn.assert_called()
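
The Airflow 3 branch of these DataSync tests now follows one shape: the task instance is created with an explicit `run_id` that matches the `DagRun` built next to it. Schematically, with values taken from the hunks above (other constructor arguments elided for brevity):

    dag_version = DagVersion.get_latest_version(dag.dag_id)
    dag_run = DagRun(dag_id=dag.dag_id, run_id="test", run_type=DagRunType.MANUAL, state=DagRunState.RUNNING)
    ti = create_task_instance(task=op, run_id="test", dag_version_id=dag_version.id)
    ti.dag_run = dag_run  # the TI and the DagRun must agree on run_id
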
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
index 0a591cce72d..178a2043f98 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
@@ -47,16 +47,16 @@ from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance, render_template_fields
+from tests_common.test_utils.taskinstance import (
+    create_task_instance,
+    get_template_context,
+    render_template_fields,
+)
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 if AIRFLOW_V_3_0_PLUS:
     from airflow.models.dag_version import DagVersion
 
@@ -333,7 +333,7 @@ class TestDmsDescribeTasksOperator:
         if AIRFLOW_V_3_0_PLUS:
             sync_dag_to_db(self.dag)
             dag_version = DagVersion.get_latest_version(self.dag.dag_id)
-            ti = create_task_instance(task=describe_task, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=describe_task, run_id="test", dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.dag.dag_id,
                 logical_date=timezone.utcnow(),
@@ -353,7 +353,7 @@ class TestDmsDescribeTasksOperator:
         ti.dag_run = dag_run
         session.add(ti)
         session.commit()
-        marker, response = describe_task.execute(ti.get_template_context())
+        marker, response = describe_task.execute(get_template_context(ti, describe_task))
 
         assert marker is None
         assert response == self.MOCK_RESPONSE
@@ -535,7 +535,7 @@ class TestDmsDescribeReplicationConfigsOperator:
         if AIRFLOW_V_3_0_PLUS:
             sync_dag_to_db(dag)
             dag_version = DagVersion.get_latest_version(dag.dag_id)
-            ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=op, run_id="test", 
dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=dag.dag_id,
                 run_id="test",
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
index d1a32ee8d6f..e52eb36e7bb 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
@@ -107,7 +107,6 @@ class TestEmrAddStepsOperator:
 
             sync_dag_to_db(self.operator.dag)
             dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id)
-            ti = create_task_instance(task=self.operator, dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.operator.dag.dag_id,
                 logical_date=DEFAULT_DATE,
@@ -116,6 +115,11 @@ class TestEmrAddStepsOperator:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
+            ti = create_task_instance(
+                task=self.operator,
+                run_id=dag_run.run_id,
+                dag_version_id=dag_version.id,
+            )
         else:
             dag_run = DagRun(
                 dag_id=self.operator.dag.dag_id,
@@ -180,7 +184,6 @@ class TestEmrAddStepsOperator:
 
             sync_dag_to_db(dag)
             dag_version = DagVersion.get_latest_version(dag.dag_id)
-            ti = create_task_instance(task=test_task, dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=dag.dag_id,
                 logical_date=timezone.utcnow(),
@@ -189,6 +192,7 @@ class TestEmrAddStepsOperator:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
+            ti = create_task_instance(task=test_task, run_id=dag_run.run_id, dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=dag.dag_id,
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index 28b5952471f..e49f7ba7758 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -35,17 +35,13 @@ from airflow.providers.common.compat.sdk import TaskDeferred
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.dag import sync_dag_to_db
 from tests_common.test_utils.taskinstance import create_task_instance, render_template_fields
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 TASK_ID = "test_task"
 
 TEST_DAG_ID = "test_dag_id"
@@ -111,7 +107,6 @@ class TestEmrCreateJobFlowOperator:
 
             sync_dag_to_db(self.operator.dag)
             dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id)
-            ti = create_task_instance(task=self.operator, dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.operator.dag_id,
                 logical_date=DEFAULT_DATE,
@@ -120,6 +115,7 @@ class TestEmrCreateJobFlowOperator:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
+            ti = create_task_instance(task=self.operator, run_id="test", 
dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=self.operator.dag_id,
@@ -165,7 +161,6 @@ class TestEmrCreateJobFlowOperator:
 
             sync_dag_to_db(self.operator.dag)
             dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id)
-            ti = create_task_instance(task=self.operator, dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.operator.dag_id,
                 logical_date=DEFAULT_DATE,
@@ -174,6 +169,7 @@ class TestEmrCreateJobFlowOperator:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
+            ti = create_task_instance(task=self.operator, run_id="test", 
dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=self.operator.dag_id,
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
index a08ca1c5c99..fc012cad306 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
@@ -24,11 +24,7 @@ from airflow.models.dag import DAG
 from airflow.providers.amazon.aws.operators.emr import EmrModifyClusterOperator
 from airflow.providers.common.compat.sdk import AirflowException
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
index c98b057c54e..940137f1058 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
@@ -44,11 +44,7 @@ from airflow.providers.amazon.aws.operators.rds import (
 from airflow.providers.amazon.aws.triggers.rds import RdsDbAvailableTrigger, RdsDbStoppedTrigger
 from airflow.providers.common.compat.sdk import TaskDeferred
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
 if TYPE_CHECKING:
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
index 73b7cfdc679..ffd43a59c0c 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
@@ -680,7 +680,7 @@ class TestS3DeleteObjectsOperator:
 
             sync_dag_to_db(dag)
             dag_version = DagVersion.get_latest_version(dag.dag_id)
-            ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=op, run_id="test", 
dag_version_id=dag_version.id)
         else:
             ti = TaskInstance(task=op)
         ti.dag_run = dag_run
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
index 6ae2ad25b9b..e0269ae603f 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
@@ -33,16 +33,12 @@ from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.dag import sync_dag_to_db
 from tests_common.test_utils.taskinstance import create_task_instance, render_template_fields
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 CONFIG: dict = {
     "key1": "1",
     "key2": {"key3": "3", "key4": "4"},
@@ -221,7 +217,6 @@ class TestSageMakerExperimentOperator:
 
             sync_dag_to_db(dag)
             dag_version = DagVersion.get_latest_version(dag.dag_id)
-            ti = create_task_instance(task=op, dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=dag.dag_id,
                 logical_date=logical_date,
@@ -230,6 +225,7 @@ class TestSageMakerExperimentOperator:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
+            ti = create_task_instance(task=op, run_id="test", 
dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=dag.dag_id,
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
index 903171121c5..72891b91bc4 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
@@ -37,10 +37,7 @@ from airflow.providers.amazon.aws.sensors.ecs import (
 from airflow.providers.amazon.version_compat import NOTSET
 from airflow.providers.common.compat.sdk import AirflowException
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 _Operator = TypeVar("_Operator")
 TEST_CLUSTER_NAME = "fake-cluster"
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
index a097c5ed3d4..74f724ac398 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
@@ -32,11 +32,7 @@ from airflow.providers.amazon.aws.sensors.rds import (
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.common.compat.sdk import AirflowException
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
 
 DEFAULT_DATE = timezone.datetime(2019, 1, 1)
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
index f789f0cdb57..d5a697cbdb8 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
@@ -32,15 +32,11 @@ from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.dag import sync_dag_to_db
 from tests_common.test_utils.taskinstance import create_task_instance, render_template_fields
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 DEFAULT_DATE = datetime(2015, 1, 1)
 
 
@@ -149,7 +145,7 @@ class TestS3KeySensor:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
-            ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=op, run_id="test", 
dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=dag.dag_id,
@@ -206,7 +202,7 @@ class TestS3KeySensor:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
-            ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+            ti = create_task_instance(task=op, run_id="test", 
dag_version_id=dag_version.id)
         ti.dag_run = dag_run
         rendered = render_template_fields(ti, op)
         rendered.poke(None)
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
index f589c66fdbf..b2138d88769 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
@@ -90,7 +90,6 @@ class TestMongoToS3Operator:
 
             sync_dag_to_db(self.dag)
             dag_version = DagVersion.get_latest_version(self.mock_operator.dag_id)
-            ti = create_task_instance(self.mock_operator, dag_version_id=dag_version.id)
             dag_run = DagRun(
                 dag_id=self.mock_operator.dag_id,
                 logical_date=DEFAULT_DATE,
@@ -99,6 +98,7 @@ class TestMongoToS3Operator:
                 state=DagRunState.RUNNING,
                 run_after=timezone.utcnow(),
             )
+            ti = create_task_instance(self.mock_operator, run_id="test", 
dag_version_id=dag_version.id)
         else:
             dag_run = DagRun(
                 dag_id=self.mock_operator.dag_id,
diff --git a/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py b/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
index 41eb0cf06b2..bc2b2c52a07 100644
--- a/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
+++ b/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
@@ -24,8 +24,9 @@ import pytest
 
 from airflow.providers.apache.druid.hooks.druid import IngestionType
 from airflow.providers.apache.druid.operators.druid import DruidOperator
-from airflow.utils import timezone
-from airflow.utils.types import DagRunType
+
+from tests_common.test_utils.compat import timezone
+from tests_common.test_utils.taskinstance import get_template_context
 
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 
@@ -65,8 +66,8 @@ def test_render_template(dag_maker):
             params={"index_type": "index_hadoop", "datasource": 
"datasource_prd"},
         )
 
-    ti = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0]
-    operator.render_template_fields(ti.get_template_context())
+    ti = dag_maker.create_ti(operator.task_id)
+    operator.render_template_fields(get_template_context(ti, operator))
     assert json.loads(operator.json_index_file) == RENDERED_INDEX
 
 
@@ -89,8 +90,8 @@ def test_render_template_from_file(tmp_path, dag_maker):
             params={"index_type": "index_hadoop", "datasource": 
"datasource_prd"},
         )
 
-    ti = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0]
-    operator.render_template_fields(ti.get_template_context())
+    ti = dag_maker.create_ti(operator.task_id)
+    operator.render_template_fields(get_template_context(ti, operator))
     assert json.loads(operator.json_index_file) == RENDERED_INDEX
 
 
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
index 82c022918d3..b062ab3a310 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
@@ -21,6 +21,9 @@ from unittest import mock
 
 import pytest
 
+from tests_common.test_utils.compat import timezone
+from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
+from tests_common.test_utils.taskinstance import get_template_context
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
 if AIRFLOW_V_3_0_PLUS:
@@ -28,10 +31,6 @@ if AIRFLOW_V_3_0_PLUS:
 else:
     from airflow.decorators import setup, task, teardown  # type: ignore[attr-defined,no-redef]
 
-from airflow.utils import timezone
-
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
-
 TASK_FUNCTION_NAME_ID = "task_function_name"
 DEFAULT_DATE = timezone.datetime(2023, 1, 1)
 DAG_ID = "k8s_deco_test_dag"
@@ -124,8 +123,7 @@ class TestKubernetesDecoratorsBase:
         session = self.dag_maker.session
         dag_run = self.dag_maker.create_dagrun(run_id=f"k8s_decorator_test_{DEFAULT_DATE.date()}")
         ti = dag_run.get_task_instance(task.operator.task_id, session=session)
-        return_val = task.operator.execute(context=ti.get_template_context(session=session))
-
+        return_val = task.operator.execute(context=get_template_context(ti, task.operator, session=session))
         return ti, return_val
 
 
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
index 5209fabec05..5b0b247c938 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
@@ -51,7 +51,7 @@ from airflow.utils.types import DagRunType
 
 from tests_common.test_utils import db
 from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance
+from tests_common.test_utils.taskinstance import create_task_instance, get_template_context
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
 
 if AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_1_PLUS:
@@ -259,7 +259,7 @@ class TestKubernetesPodOperator:
         (ti,) = dr.task_instances
         ti.map_index = map_index
         self.dag_run = dr
-        context = ti.get_template_context(session=self.dag_maker.session)
+        context = get_template_context(ti, operator, session=self.dag_maker.session)
         self.dag_maker.session.commit()  # So 'execute' can read dr and ti.
 
         remote_pod_mock = MagicMock()
@@ -2371,7 +2371,7 @@ class TestKubernetesPodOperatorAsync:
         (ti,) = dr.task_instances
         ti.map_index = map_index
         self.dag_run = dr
-        context = ti.get_template_context(session=self.dag_maker.session)
+        context = get_template_context(ti, operator, session=self.dag_maker.session)
         self.dag_maker.session.commit()  # So 'execute' can read dr and ti.
 
         remote_pod_mock = MagicMock()
diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py
index 8cd2cef1b38..3ad796b7e63 100644
--- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py
+++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py
@@ -25,12 +25,8 @@ import airflow.models.xcom
 from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
 from airflow.providers.standard.operators.empty import EmptyOperator
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 from tests_common.test_utils import db
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, XCOM_RETURN_KEY
 
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index e201a73b6c6..45f940236b7 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -39,10 +39,7 @@ from airflow.providers.dbt.cloud.hooks.dbt import (
     fallback_to_default_account,
 )
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 ACCOUNT_ID_CONN = "account_id_conn"
 NO_ACCOUNT_ID_CONN = "no_account_id_conn"
diff --git a/providers/docker/tests/unit/docker/decorators/test_docker.py b/providers/docker/tests/unit/docker/decorators/test_docker.py
index 8d69f5de7bc..51ddc6fe876 100644
--- a/providers/docker/tests/unit/docker/decorators/test_docker.py
+++ b/providers/docker/tests/unit/docker/decorators/test_docker.py
@@ -20,24 +20,12 @@ from importlib.util import find_spec
 
 import pytest
 
-from airflow.models import TaskInstance
-from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.sdk import DAG, AirflowException, setup, task, teardown
 from airflow.utils.state import TaskInstanceState
 
+from tests_common.test_utils.compat import timezone
 from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
-from tests_common.test_utils.taskinstance import create_task_instance, render_template_fields
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import DAG, setup, task, teardown
-else:
-    from airflow.decorators import setup, task, teardown  # type: ignore[attr-defined,no-redef]
-    from airflow.models import DAG  # type: ignore[attr-defined,no-redef]
-
-if AIRFLOW_V_3_1_PLUS:
-    from airflow.sdk import timezone
-else:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.taskinstance import render_template_fields
 
 DEFAULT_DATE = timezone.datetime(2021, 9, 1)
 DILL_INSTALLED = find_spec("dill") is not None
@@ -99,17 +87,9 @@ class TestDockerDecorator:
         with dag_maker():
             ret = f()
 
-        dr = dag_maker.create_dagrun()
-        if AIRFLOW_V_3_0_PLUS:
-            ti = create_task_instance(
-                task=ret.operator,
-                run_id=dr.run_id,
-                dag_version_id=dr.created_dag_version_id,
-            )
-        else:
-            ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
+        ti = dag_maker.create_ti("f")
         rendered = render_template_fields(ti, ret.operator)
-        assert rendered.container_name == f"python_{dr.dag_id}"
+        assert rendered.container_name == f"python_{ti.dag_id}"
         assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"
 
     @pytest.mark.db_test
diff --git a/providers/github/tests/unit/github/operators/test_github.py b/providers/github/tests/unit/github/operators/test_github.py
index 25176e222a7..9b9844faeae 100644
--- a/providers/github/tests/unit/github/operators/test_github.py
+++ b/providers/github/tests/unit/github/operators/test_github.py
@@ -25,10 +25,7 @@ from airflow.models import Connection
 from airflow.models.dag import DAG
 from airflow.providers.github.operators.github import GithubOperator
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 github_client_mock = Mock(name="github_client_for_test")
diff --git a/providers/github/tests/unit/github/sensors/test_github.py b/providers/github/tests/unit/github/sensors/test_github.py
index a9a414bae37..8001df4edce 100644
--- a/providers/github/tests/unit/github/sensors/test_github.py
+++ b/providers/github/tests/unit/github/sensors/test_github.py
@@ -25,10 +25,7 @@ from airflow.models import Connection
 from airflow.models.dag import DAG
 from airflow.providers.github.sensors.github import GithubTagSensor
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 github_client_mock = Mock(name="github_client_for_test")
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_gcs.py b/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
index bdaee90cea2..6040115f220 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
@@ -39,13 +39,9 @@ from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks import gcs
 from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
 from airflow.providers.google.common.consts import CLIENT_INFO
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 from airflow.version import version
 
+from tests_common.test_utils.compat import timezone
 from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
 
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
diff --git a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
index 401df681ad6..3e5a731d79d 100644
--- a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
+++ b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
@@ -27,6 +27,10 @@ import sqlalchemy
 
 from airflow.models.dag import DAG
 from airflow.providers.common.compat.sdk import Connection
+from airflow.providers.mysql.hooks.mysql import MySqlHook
+
+from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces
+from tests_common.test_utils.compat import timezone
 
 try:
     import MySQLdb.cursors
@@ -35,15 +39,6 @@ try:
 except ImportError:
     MYSQL_AVAILABLE = False
 
-from airflow.providers.mysql.hooks.mysql import MySqlHook
-
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
-from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces
-
 SSL_DICT = {"cert": "/tmp/client-cert.pem", "ca": "/tmp/server-ca.pem", "key": 
"/tmp/client-key.pem"}
 INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, 
description, host, `schema`, login, password, port, is_encrypted, 
is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
 
diff --git a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
index 4ccec5b79cd..b33f9577784 100644
--- a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
+++ b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
@@ -23,10 +23,7 @@ import pytest
 from airflow.providers.common.compat.sdk import AirflowSkipException
 from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackApiFileOperator
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 TEST_DAG_ID = "sql_to_slack_unit_test"
 TEST_TASK_ID = "sql_to_slack_unit_test_task"
diff --git a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
index 54cffd37e9a..beb91bf255d 100644
--- a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
+++ b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
@@ -24,10 +24,7 @@ import pytest
 from airflow.models import Connection
 from airflow.providers.slack.transfers.sql_to_slack_webhook import SqlToSlackWebhookOperator
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 TEST_DAG_ID = "sql_to_slack_unit_test"
 TEST_TASK_ID = "sql_to_slack_unit_test_task"
diff --git a/providers/standard/tests/unit/standard/decorators/test_bash.py b/providers/standard/tests/unit/standard/decorators/test_bash.py
index ad899c6bcfc..b4ae3847f26 100644
--- a/providers/standard/tests/unit/standard/decorators/test_bash.py
+++ b/providers/standard/tests/unit/standard/decorators/test_bash.py
@@ -29,8 +29,12 @@ from airflow.models.renderedtifields import RenderedTaskInstanceFields
 from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException
 
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
-from tests_common.test_utils.taskinstance import render_template_fields, run_task_instance
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.taskinstance import (
+    get_template_context,
+    render_template_fields,
+    run_task_instance,
+)
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
 
 if TYPE_CHECKING:
     from airflow.models import TaskInstance
@@ -359,7 +363,10 @@ class TestBashDecorator:
         ti = dr.task_instances[0]
         with pytest.raises(AirflowException, match=f"Can not find the cwd: 
{cwd_path}"):
             run_task_instance(ti, bash_task.operator)
-        assert ti.task.bash_command == "echo"
+        if AIRFLOW_V_3_2_PLUS:
+            assert ti.task.bash_command == "DYNAMIC (set during execution)"
+        else:
+            assert ti.task.bash_command == "echo"
 
     def test_cwd_is_file(self, tmp_path):
         """Verify task failure for user-defined working directory that is 
actually a file."""
@@ -380,7 +387,10 @@ class TestBashDecorator:
         ti = dr.task_instances[0]
         with pytest.raises(AirflowException, match=f"The cwd {cwd_file} must 
be a directory"):
             run_task_instance(ti, bash_task.operator)
-        assert ti.task.bash_command == "echo"
+        if AIRFLOW_V_3_2_PLUS:
+            assert ti.task.bash_command == "DYNAMIC (set during execution)"
+        else:
+            assert ti.task.bash_command == "echo"
 
     def test_command_not_found(self):
         """Fail task if executed command is not found on path."""
@@ -400,7 +410,10 @@ class TestBashDecorator:
             AirflowException, match="Bash command failed\\. The command 
returned a non-zero exit code 127\\."
         ):
             run_task_instance(ti, bash_task.operator)
-        assert ti.task.bash_command == "set -e; something-that-isnt-on-path"
+        if AIRFLOW_V_3_2_PLUS:
+            assert ti.task.bash_command == "DYNAMIC (set during execution)"
+        else:
+            assert ti.task.bash_command == "set -e; 
something-that-isnt-on-path"
 
     def test_multiple_outputs_true(self):
         """Verify setting `multiple_outputs` for a @task.bash-decorated 
function is ignored."""
@@ -496,7 +509,10 @@ class TestBashDecorator:
         ti = dr.task_instances[0]
         with pytest.raises(AirflowException):
             run_task_instance(ti, bash_task.operator)
-        assert ti.task.bash_command == f"{DEFAULT_DATE.date()}; exit 1;"
+        if AIRFLOW_V_3_2_PLUS:
+            assert ti.task.bash_command == "DYNAMIC (set during execution)"
+        else:
+            assert ti.task.bash_command == f"{DEFAULT_DATE.date()}; exit 1;"
 
     @pytest.mark.db_test
     def test_templated_bash_script(self, dag_maker, tmp_path, session):
@@ -519,7 +535,7 @@ class TestBashDecorator:
             task_arg = test_templated_fields_task()
 
         ti: TaskInstance = dag_maker.create_dagrun().task_instances[0]
-        context = ti.get_template_context(session=session)
+        context = get_template_context(ti, task_arg.operator, session=session)
         op = render_template_fields(ti, task_arg.operator, context=context)
         result = op.execute(context=context)
         assert result == "test_templated_fields_task"
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py
index 8b82a737739..32b2fc2615d 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import sys
 import typing
 from collections import namedtuple
@@ -22,11 +23,10 @@ from datetime import date
 
 import pytest
 
-from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.providers.common.compat.sdk import AirflowException, XComNotFound
 
-from tests_common.test_utils.taskinstance import create_task_instance, render_template_fields
+from tests_common.test_utils.taskinstance import get_template_context, render_template_fields
 from tests_common.test_utils.version_compat import (
     AIRFLOW_V_3_0_1,
     AIRFLOW_V_3_0_PLUS,
@@ -68,6 +68,7 @@ else:
 
 if typing.TYPE_CHECKING:
     from airflow.models.dagrun import DagRun
+    from airflow.models.taskinstance import TaskInstance
 
 pytestmark = pytest.mark.db_test
 
@@ -76,6 +77,15 @@ PY38 = sys.version_info >= (3, 8)
 PY311 = sys.version_info >= (3, 11)
 
 
[email protected](autouse=True)
+def clear_current_task_session():
+    try:
+        import airflow.utils.task_instance_session
+    except ModuleNotFoundError:
+        return
+    airflow.utils.task_instance_session.__current_task_instance_session = None
+
+
 class TestAirflowTaskDecorator(BasePythonTest):
     default_date = DEFAULT_DATE
 
@@ -399,20 +409,13 @@ class TestAirflowTaskDecorator(BasePythonTest):
             ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)
 
         dr = self.create_dag_run()
-        if AIRFLOW_V_3_0_PLUS:
-            ti = create_task_instance(
-                task=ret.operator,
-                run_id=dr.run_id,
-                dag_version_id=dr.created_dag_version_id,
-            )
-        else:
-            ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
-        rendered_op_args = render_template_fields(ti, ret.operator).op_args
-        assert len(rendered_op_args) == 4
-        assert rendered_op_args[0] == 4
-        assert rendered_op_args[1] == date(2019, 1, 1)
-        assert rendered_op_args[2] == f"dag {self.dag_id} ran on 
{self.ds_templated}."
-        assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
+        ti = dr.get_task_instance("arg_task", session=self.dag_maker.session)
+        assert render_template_fields(ti, ret.operator).op_args == (
+            4,
+            date(2019, 1, 1),
+            f"dag {self.dag_id} ran on {self.ds_templated}.",
+            Named(self.ds_templated, "unchanged"),
+        )
 
     def test_python_callable_keyword_arguments_are_templatized(self):
         """Test PythonOperator op_kwargs are templatized"""
@@ -427,18 +430,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
             )
 
         dr = self.create_dag_run()
-        if AIRFLOW_V_3_0_PLUS:
-            ti = create_task_instance(
-                task=ret.operator,
-                run_id=dr.run_id,
-                dag_version_id=dr.created_dag_version_id,
-            )
-        else:
-            ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
-        rendered_op_kwargs = render_template_fields(ti, ret.operator).op_kwargs
-        assert rendered_op_kwargs["an_int"] == 4
-        assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
-        assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} 
ran on {self.ds_templated}."
+        ti = dr.get_task_instance("kwargs_task", 
session=self.dag_maker.session)
+        assert render_template_fields(ti, ret.operator).op_kwargs == {
+            "an_int": 4,
+            "a_date": date(2019, 1, 1),
+            "a_templated_string": f"dag {self.dag_id} ran on 
{self.ds_templated}.",
+        }
 
     def test_manual_task_id(self):
         """Test manually setting task_id"""
@@ -821,14 +818,10 @@ def test_mapped_decorator_unmap_merge_op_kwargs(dag_maker, session):
     assert [ti.task_id for ti in dec.schedulable_tis] == ["task2"]
     ti = dec.schedulable_tis[0]
 
-    # Use the real task for unmapping to mimic actual execution path
-    ti.task = dag_maker.dag.task_dict[ti.task_id]
-
-    if AIRFLOW_V_3_0_PLUS:
-        unmapped = ti.task.unmap((ti.get_template_context(session),))
-    else:
-        unmapped = ti.task.unmap((ti.get_template_context(session), session))
-    assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
+    task = dag_maker.dag.task_dict[ti.task_id]
+    context = get_template_context(ti, task, session=session)
+    render_template_fields(ti, task, context=context, session=session)
+    assert set(context["task"].op_kwargs) == {"arg1", "arg2"}
 
 
 @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
@@ -863,11 +856,12 @@ def test_mapped_render_template_fields(dag_maker, session):
     mapped_ti.map_index = 0
     mapped_ti.task = mapped.operator
     assert isinstance(mapped_ti.task, MappedOperator)
-    mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
-    assert isinstance(mapped_ti.task, BaseOperator)
+    context = get_template_context(mapped_ti, mapped.operator, session=session)
+    mapped.operator.render_template_fields(context)
+    assert isinstance(context["task"], BaseOperator)
 
-    assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
-    assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+    assert context["task"].op_kwargs["arg1"] == "{{ ds }}"
+    assert context["task"].op_kwargs["arg2"] == "fn"
 
 
 @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index f7bca58e4b9..a59c33b29dc 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -62,8 +62,9 @@ from airflow.utils.session import create_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.compat import TriggerRule, timezone
 from tests_common.test_utils.db import clear_db_runs
-from tests_common.test_utils.taskinstance import run_task_instance
+from tests_common.test_utils.taskinstance import get_template_context, run_task_instance
 from tests_common.test_utils.version_compat import (
     AIRFLOW_V_3_0_1,
     AIRFLOW_V_3_0_PLUS,
@@ -79,16 +80,6 @@ else:
     from airflow.models.baseoperator import BaseOperator  # type: ignore[no-redef]
     from airflow.models.taskinstance import set_current_context  # type: ignore[attr-defined,no-redef]
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-try:
-    from airflow.sdk import TriggerRule
-except ImportError:
-    # Compatibility for Airflow < 3.1
-    from airflow.utils.trigger_rule import TriggerRule  # type: ignore[no-redef,attr-defined]
-
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.dagrun import DagRun
@@ -205,7 +196,7 @@ class BasePythonTest:
         """Create TaskInstance and run it."""
         ti = self.create_ti(fn, **kwargs)
         assert ti.task is not None
-        ti.run()
+        ti = ti.run()
         if return_ti:
             return ti
         return ti.task
@@ -996,11 +987,8 @@ class TestDagBundleImportInSubprocess(BasePythonTest):
         dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(self.task_id)
 
-        mock_bundle_instance = mock.Mock()
-        mock_bundle_instance.path = str(bundle_root)
-        ti.bundle_instance = mock_bundle_instance
-
-        context = ti.get_template_context()
+        context = get_template_context(ti, op)
+        context["ti"].bundle_instance = mock.Mock(path=str(bundle_root))
 
         # Mock subprocess execution to avoid testing-environment related issues
         # on the ExternalPythonOperator (Socket operation on non-socket)
@@ -1011,13 +999,10 @@ class TestDagBundleImportInSubprocess(BasePythonTest):
         with mock.patch.object(op, "_read_result", return_value=None):
             op.execute(context)
 
-        assert mock_execute_subprocess.called, "_execute_in_subprocess should 
have been called"
-        call_kwargs = mock_execute_subprocess.call_args.kwargs
-        env = call_kwargs.get("env")
-        assert "PYTHONPATH" in env, "PYTHONPATH should be in env"
-
-        pythonpath = env["PYTHONPATH"]
-        assert str(bundle_root) in pythonpath, f"Bundle path {bundle_root} 
should be in PYTHONPATH"
+        pythonpath = mock_execute_subprocess.call_args.kwargs["env"]["PYTHONPATH"]
+        assert str(bundle_root) in pythonpath, (
+            f"Bundle path {str(bundle_root)!r} not in PYTHONPATH 
{pythonpath!r}"
+        )
 
 
 @pytest.mark.execution_timeout(120)
@@ -1060,7 +1045,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
             return None
 
         ti = self.run_as_task(f, return_ti=True)
-        assert ti.xcom_pull() is None
+        assert TaskInstance.xcom_pull(ti) is None
 
     def test_return_false(self):
         def f():
@@ -1068,7 +1053,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
 
         ti = self.run_as_task(f, return_ti=True)
 
-        assert ti.xcom_pull() is False
+        assert TaskInstance.xcom_pull(ti) is False
 
     def test_lambda(self):
         with pytest.raises(
@@ -1234,7 +1219,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
             return os.environ["MY_ENV_VAR"]
 
         ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"}, 
return_ti=True)
-        assert ti.xcom_pull() == "ABCDE"
+        assert TaskInstance.xcom_pull(ti) == "ABCDE"
 
     def test_environment_variables_with_inherit_env_true(self, monkeypatch):
         monkeypatch.setenv("MY_ENV_VAR", "QWERT")
@@ -1245,7 +1230,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
             return os.environ["MY_ENV_VAR"]
 
         ti = self.run_as_task(f, inherit_env=True, return_ti=True)
-        assert ti.xcom_pull() == "QWERT"
+        assert TaskInstance.xcom_pull(ti) == "QWERT"
 
     def test_environment_variables_with_inherit_env_false(self, monkeypatch):
         monkeypatch.setenv("MY_ENV_VAR", "TYUIO")
@@ -1267,7 +1252,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
             return os.environ["MY_ENV_VAR"]
 
         ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, 
inherit_env=True, return_ti=True)
-        assert ti.xcom_pull() == "EFGHI"
+        assert TaskInstance.xcom_pull(ti) == "EFGHI"
 
 
 venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")
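
One more pattern worth calling out from the hunks above: `ti.xcom_pull()` becomes `TaskInstance.xcom_pull(ti)`. The commit does not state the motivation, so read this as an inference: calling the method unbound through the class keeps the assertion working even when the object returned by `ti.run()` is not itself an ORM `TaskInstance`. Illustratively:

    ti = self.run_as_task(f, return_ti=True)
    assert TaskInstance.xcom_pull(ti) is None  # explicit class call instead of ti.xcom_pull()
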
diff --git a/providers/standard/tests/unit/standard/sensors/test_time.py b/providers/standard/tests/unit/standard/sensors/test_time.py
index 8e9097d6a64..1989c35d91b 100644
--- a/providers/standard/tests/unit/standard/sensors/test_time.py
+++ b/providers/standard/tests/unit/standard/sensors/test_time.py
@@ -28,10 +28,7 @@ from airflow.providers.common.compat.sdk import TaskDeferred
 from airflow.providers.standard.sensors.time import TimeSensor
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
 
 DEFAULT_TIMEZONE = pendulum.timezone("Asia/Singapore")  # UTC+08:00
 DEFAULT_DATE_WO_TZ = datetime(2015, 1, 1)
diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py
index 01f4c461c68..1d78b9d7242 100755
--- a/scripts/ci/prek/check_template_context_variable_in_sync.py
+++ b/scripts/ci/prek/check_template_context_variable_in_sync.py
@@ -32,14 +32,10 @@ import sys
 import typing
 
 sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))  # make sure common_prek_utils is imported
-from common_prek_utils import (
-    AIRFLOW_CORE_ROOT_PATH,
-    AIRFLOW_CORE_SOURCES_PATH,
-    AIRFLOW_TASK_SDK_SOURCES_PATH,
-)
+
+from common_prek_utils import AIRFLOW_CORE_ROOT_PATH, AIRFLOW_TASK_SDK_SOURCES_PATH
 
 TASKRUNNER_PY = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "execution_time" / "task_runner.py"
-CONTEXT_PY = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "utils" / "context.py"
 CONTEXT_HINT = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "context.py"
 TEMPLATES_REF_RST = AIRFLOW_CORE_ROOT_PATH / "docs" / "templates-ref.rst"
 
@@ -117,23 +113,6 @@ def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
                     yield from extract_keys_from_dict(sub_stmt.value)
 
 
-def _iter_template_context_keys_from_declaration() -> typing.Iterator[str]:
-    context_mod = ast.parse(CONTEXT_PY.read_text("utf-8"), str(CONTEXT_PY))
-    st_known_context_keys = next(
-        stmt.value
-        for stmt in context_mod.body
-        if isinstance(stmt, ast.AnnAssign)
-        and isinstance(stmt.target, ast.Name)
-        and stmt.target.id == "KNOWN_CONTEXT_KEYS"
-    )
-    if not isinstance(st_known_context_keys, ast.Set):
-        raise ValueError("'KNOWN_CONTEXT_KEYS' is not assigned a set literal")
-    for expr in st_known_context_keys.elts:
-        if not isinstance(expr, ast.Constant) or not isinstance(expr.value, str):
-            raise ValueError("item in 'KNOWN_CONTEXT_KEYS' set is not a str literal")
-        yield expr.value
-
-
 def _iter_template_context_keys_from_type_hints() -> typing.Iterator[str]:
     context_mod = ast.parse(CONTEXT_HINT.read_text("utf-8"), str(CONTEXT_HINT))
     cls_context = next(
@@ -158,7 +137,7 @@ def _iter_template_context_keys_from_documentation() -> typing.Iterator[str]:
         yield match.group("name")
 
 
-def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str], docs_keys: set[str]) -> int:
+def _compare_keys(retn_keys: set[str], hint_keys: set[str], docs_keys: set[str]) -> int:
     # Added by PythonOperator and commonly used.
     # Not listed in templates-ref (but in operator docs).
     retn_keys.add("templates_dict")
@@ -168,10 +147,7 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
     retn_keys.add("expanded_ti_count")
 
     # TODO: These are the keys that are yet to be ported over to the Task SDK.
-    retn_keys.add("inlet_events")
-    retn_keys.add("params")
     retn_keys.add("test_mode")
-    retn_keys.add("triggering_asset_events")
 
     # Only present in callbacks. Not listed in templates-ref (that doc is for task execution).
     retn_keys.update(("exception", "reason", "try_number"))
@@ -182,7 +158,6 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
 
     check_candidates = [
         ("get_template_context()", retn_keys),
-        ("KNOWN_CONTEXT_KEYS", decl_keys),
         ("Context type hint", hint_keys),
         ("templates-ref", docs_keys),
     ]
@@ -198,10 +173,9 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
 
 def main() -> str | int | None:
     retn_keys = set(_iter_template_context_keys_from_original_return())
-    decl_keys = set(_iter_template_context_keys_from_declaration())
     hint_keys = set(_iter_template_context_keys_from_type_hints())
     docs_keys = set(_iter_template_context_keys_from_documentation())
-    return _compare_keys(retn_keys, decl_keys, hint_keys, docs_keys)
+    return _compare_keys(retn_keys, hint_keys, docs_keys)
 
 
 if __name__ == "__main__":
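
With KNOWN_CONTEXT_KEYS gone, the script now cross-checks only three sources of context keys. The comparison itself is a pairwise set-difference check; a self-contained sketch of that pattern (function name and sample data below are made up for illustration, not taken from the script) is:

    # Illustrative sketch of the pairwise set comparison the script performs;
    # compare_key_sets and the sample key sets below are hypothetical.
    from __future__ import annotations

    from itertools import combinations


    def compare_key_sets(named_sets: dict[str, set[str]]) -> int:
        rc = 0
        for (a_name, a_keys), (b_name, b_keys) in combinations(named_sets.items(), 2):
            if missing := a_keys - b_keys:
                print(f"keys in {a_name} missing from {b_name}: {sorted(missing)}")
                rc = 1
            if missing := b_keys - a_keys:
                print(f"keys in {b_name} missing from {a_name}: {sorted(missing)}")
                rc = 1
        return rc


    # Example: "docs" is missing "ti", so this returns 1 (non-zero exit status).
    print(compare_key_sets({"return": {"ds", "ti"}, "hints": {"ds", "ti"}, "docs": {"ds"}}))
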
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index b0b8357b5e0..3351f591a95 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -1398,6 +1398,8 @@ def _run_task(
     possible.  This function is only meant for the `dag.test` function as a helper function.
     """
     from airflow.sdk._shared.module_loading import import_string
+    from airflow.sdk.serde import deserialize, serialize
+    from airflow.utils.session import create_session
 
     taskrun_result: TaskRunResult | None
     log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
@@ -1429,16 +1431,16 @@ def _run_task(
             ti.task = create_scheduler_operator(taskrun_result.ti.task)
 
             if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
-                from airflow.sdk.serde import deserialize, serialize
-                from airflow.utils.session import create_session
-
                 # API Server expects the task instance to be in QUEUED state before
                 # resuming from deferral.
                 ti.set_state(TaskInstanceState.QUEUED)
 
                 log.info("[DAG TEST] running trigger in line")
-                # trigger_kwargs need to be deserialized before passing to the trigger class since they are in serde encoded format
-                kwargs = deserialize(msg.trigger_kwargs)  # type: ignore[type-var]  # needed to convince mypy that trigger_kwargs is a dict or a str because its unable to infer JsonValue
+                # trigger_kwargs need to be deserialized before passing to the
+                # trigger class since they are in serde encoded format.
+                # Ignore needed to convince mypy that trigger_kwargs is a dict
+                # or a str because it's unable to infer JsonValue.
+                kwargs = deserialize(msg.trigger_kwargs)  # type: ignore[type-var]
                 if TYPE_CHECKING:
                     assert isinstance(kwargs, dict)
                 trigger = import_string(msg.classpath)(**kwargs)
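
The deferral branch above rebuilds the trigger from a dotted class path plus the serde-decoded kwargs. A stdlib-only sketch of that classpath-plus-kwargs instantiation pattern (the class path and kwargs below are arbitrary examples, not Airflow objects) is:

    # Minimal illustration of the import_string(classpath)(**kwargs) pattern; hypothetical inputs.
    from importlib import import_module


    def load_by_classpath(classpath: str):
        module_name, _, attr = classpath.rpartition(".")
        return getattr(import_module(module_name), attr)


    trigger_cls = load_by_classpath("datetime.timedelta")  # stand-in for a trigger class
    instance = trigger_cls(**{"seconds": 30})
    print(instance)  # prints "0:00:30"
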
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 6aff9ca3c68..ef4c4aeb59a 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -752,19 +752,13 @@ class MappedOperator(AbstractOperator):
             "params": params,
         }
 
-    def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
+    def unmap(self, resolve: Mapping[str, Any]) -> BaseOperator:
         """
         Get the "normal" Operator after applying the current mapping.
 
         :meta private:
         """
-        if isinstance(resolve, Mapping):
-            kwargs = resolve
-        elif resolve is not None:
-            kwargs, _ = self._expand_mapped_kwargs(*resolve)
-        else:
-            raise RuntimeError("cannot unmap a non-serialized operator without 
context")
-        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+        kwargs = self._get_unmap_kwargs(resolve, strict=self._disallow_kwargs_override)
         is_setup = kwargs.pop("is_setup", False)
         is_teardown = kwargs.pop("is_teardown", False)
         on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
