This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 05ac92fb0d0 Use SDK serde for trigger and next kwargs serialization 
(#59711)
05ac92fb0d0 is described below

commit 05ac92fb0d02de88770846e571b980cdeec7bd72
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Dec 30 12:47:43 2025 +0530

    Use SDK serde for trigger and next kwargs serialization (#59711)
---
 airflow-core/.pre-commit-config.yaml               |  1 +
 .../execution_api/datamodels/taskinstance.py       |  4 +-
 .../execution_api/routes/task_instances.py         | 11 +--
 .../api_fastapi/execution_api/versions/__init__.py |  2 +
 .../execution_api/versions/v2026_03_31.py          | 35 ++++++++
 airflow-core/src/airflow/models/taskinstance.py    |  5 +-
 .../src/airflow/models/taskinstancehistory.py      |  5 +-
 airflow-core/src/airflow/models/trigger.py         | 39 ++++++---
 .../versions/head/test_task_instances.py           | 93 ++++++++++++++++++----
 .../tests/unit/standard/operators/test_hitl.py     |  8 +-
 .../src/airflow/sdk/api/datamodels/_generated.py   |  6 +-
 task-sdk/src/airflow/sdk/definitions/dag.py        |  9 ++-
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 15 +---
 .../src/airflow/sdk/execution_time/task_runner.py  | 25 +++++-
 .../task_sdk/execution_time/test_supervisor.py     | 19 +++--
 .../task_sdk/execution_time/test_task_runner.py    | 18 ++++-
 16 files changed, 225 insertions(+), 70 deletions(-)

diff --git a/airflow-core/.pre-commit-config.yaml 
b/airflow-core/.pre-commit-config.yaml
index 38e740e27ae..106a7605b17 100644
--- a/airflow-core/.pre-commit-config.yaml
+++ b/airflow-core/.pre-commit-config.yaml
@@ -370,6 +370,7 @@ repos:
           ^src/airflow/models/taskmap\.py$|
           ^src/airflow/models/taskmixin\.py$|
           ^src/airflow/models/taskreschedule\.py$|
+          ^src/airflow/models/trigger\.py$|
           ^src/airflow/models/variable\.py$|
           ^src/airflow/models/xcom\.py$|
           ^src/airflow/models/xcom_arg\.py$|
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 134019b6e31..a4c5355a826 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -129,7 +129,7 @@ class TIDeferredStatePayload(StrictBaseModel):
         ),
     ]
     classpath: str
-    trigger_kwargs: Annotated[dict[str, Any] | str, 
Field(default_factory=dict)]
+    trigger_kwargs: Annotated[dict[str, JsonValue] | str, 
Field(default_factory=dict)]
     """
     Kwargs to pass to the trigger constructor, either a plain dict or an 
encrypted string.
 
@@ -139,7 +139,7 @@ class TIDeferredStatePayload(StrictBaseModel):
     trigger_timeout: timedelta | None = None
     next_method: str
     """The name of the method on the operator to call in the worker after the 
trigger has fired."""
-    next_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
+    next_kwargs: Annotated[dict[str, JsonValue], Field(default_factory=dict)]
     """
     Kwargs to pass to the above method, either a plain dict or an encrypted 
string.
 
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 80738cef12f..60711ea255e 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -507,19 +507,12 @@ def _create_ti_state_update_query_and_update_state(
 
         query = update(TI).where(TI.id == ti_id_str)
 
-        # This is slightly inefficient as we deserialize it to then right 
again serialize it in the sqla
-        # TypeAdapter.
-        next_kwargs = None
-        if ti_patch_payload.next_kwargs:
-            from airflow.serialization.serialized_objects import 
BaseSerialization
-
-            next_kwargs = 
BaseSerialization.deserialize(ti_patch_payload.next_kwargs)
-
+        # Store next_kwargs directly (already serialized by worker)
         query = query.values(
             state=TaskInstanceState.DEFERRED,
             trigger_id=trigger_row.id,
             next_method=ti_patch_payload.next_method,
-            next_kwargs=next_kwargs,
+            next_kwargs=ti_patch_payload.next_kwargs,
             trigger_timeout=timeout,
         )
         updated_state = TaskInstanceState.DEFERRED
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 09c8f2d3e81..36c1b31b959 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -33,9 +33,11 @@ from airflow.api_fastapi.execution_api.versions.v2025_12_08 
import (
     AddDagRunDetailEndpoint,
     MovePreviousRunEndpoint,
 )
+from airflow.api_fastapi.execution_api.versions.v2026_03_31 import 
ModifyDeferredTaskKwargsToJsonValue
 
 bundle = VersionBundle(
     HeadVersion(),
+    Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue),
     Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
     Version("2025-11-07", AddPartitionKeyField),
     Version("2025-11-05", AddTriggeringUserNameField),
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
new file mode 100644
index 00000000000..48630e0b50c
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from cadwyn import VersionChange, schema
+
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TIDeferredStatePayload
+
+
+class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
+    """Change the types of `trigger_kwargs` and `next_kwargs` in 
TIDeferredStatePayload to JsonValue."""
+
+    description = __doc__
+
+    instructions_to_migrate_to_previous_version = (
+        
schema(TIDeferredStatePayload).field("trigger_kwargs").had(type=dict[str, Any] 
| str),
+        schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str, 
Any]),
+    )
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index d917bec159b..0914a635381 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -36,6 +36,7 @@ import dill
 import lazy_object_proxy
 import uuid6
 from sqlalchemy import (
+    JSON,
     Float,
     ForeignKey,
     ForeignKeyConstraint,
@@ -433,7 +434,9 @@ class TaskInstance(Base, LoggingMixin):
     # The method to call next, and any extra arguments to pass to it.
     # Usually used when resuming from DEFERRED.
     next_method: Mapped[str | None] = mapped_column(String(1000), 
nullable=True)
-    next_kwargs: Mapped[dict | None] = 
mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+    next_kwargs: Mapped[dict | None] = mapped_column(
+        MutableDict.as_mutable(JSON().with_variant(postgresql.JSONB, 
"postgresql")), nullable=True
+    )
 
     _task_display_property_value: Mapped[str | None] = mapped_column(
         "task_display_name", String(2000), nullable=True
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py 
b/airflow-core/src/airflow/models/taskinstancehistory.py
index 3e885aec095..b69d2c96ff9 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
 
 import dill
 from sqlalchemy import (
+    JSON,
     DateTime,
     Float,
     ForeignKeyConstraint,
@@ -109,7 +110,9 @@ class TaskInstanceHistory(Base):
     trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
     trigger_timeout: Mapped[DateTime | None] = mapped_column(DateTime, 
nullable=True)
     next_method: Mapped[str | None] = mapped_column(String(1000), 
nullable=True)
-    next_kwargs: Mapped[dict | None] = 
mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+    next_kwargs: Mapped[dict | None] = mapped_column(
+        MutableDict.as_mutable(JSON().with_variant(postgresql.JSONB, 
"postgresql")), nullable=True
+    )
 
     task_display_name: Mapped[str | None] = mapped_column(String(2000), 
nullable=True)
     dag_version_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), 
nullable=True)
diff --git a/airflow-core/src/airflow/models/trigger.py 
b/airflow-core/src/airflow/models/trigger.py
index a900dbf44a7..a6016363742 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -142,9 +142,9 @@ class Trigger(Base):
         import json
 
         from airflow.models.crypto import get_fernet
-        from airflow.serialization.serialized_objects import BaseSerialization
+        from airflow.sdk.serde import serialize
 
-        serialized_kwargs = BaseSerialization.serialize(kwargs)
+        serialized_kwargs = serialize(kwargs)
         return 
get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")
 
     @staticmethod
@@ -153,7 +153,7 @@ class Trigger(Base):
         import json
 
         from airflow.models.crypto import get_fernet
-        from airflow.serialization.serialized_objects import BaseSerialization
+        from airflow.sdk.serde import deserialize
 
         # We weren't able to encrypt the kwargs in all migration paths,
         # so we need to handle the case where they are not encrypted.
@@ -165,7 +165,16 @@ class Trigger(Base):
                 
get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
             )
 
-        return BaseSerialization.deserialize(decrypted_kwargs)
+        try:
+            result = deserialize(decrypted_kwargs)
+            if TYPE_CHECKING:
+                assert isinstance(result, dict)
+            return result
+        except (ImportError, KeyError, AttributeError, TypeError):
+            # Backward compatibility: fall back to BaseSerialization for old 
format
+            from airflow.serialization.serialized_objects import 
BaseSerialization
+
+            return BaseSerialization.deserialize(decrypted_kwargs)
 
     def rotate_fernet_key(self):
         """Encrypts data with a new key. See: :ref:`security/fernet`."""
@@ -417,16 +426,28 @@ def handle_event_submit(event: TriggerEvent, *, 
task_instance: TaskInstance, ses
     :param task_instance: The task instance to handle the submit event for.
     :param session: The session to be used for the database callback sink.
     """
+    from airflow.sdk.serde import deserialize, serialize
     from airflow.utils.state import TaskInstanceState
 
-    # Get the next kwargs of the task instance, or an empty dictionary if it 
doesn't exist
-    next_kwargs = task_instance.next_kwargs or {}
+    next_kwargs_raw = task_instance.next_kwargs or {}
+
+    # Deserialize first to provide a compat layer for data that may be serialized with either BaseSerialization or serde,
+    # which can happen if a deferred task resumes after upgrade
+    try:
+        next_kwargs = deserialize(next_kwargs_raw)
+    except (ImportError, KeyError, AttributeError, TypeError):
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        next_kwargs = BaseSerialization.deserialize(next_kwargs_raw)
 
-    # Add the event's payload into the kwargs for the task
+    # Add the event to the plain (deserialized) dict, then serialize everything together so that the event payload
+    # is serde-encoded consistently with the rest of next_kwargs in the final structure.
+    if TYPE_CHECKING:
+        assert isinstance(next_kwargs, dict)
     next_kwargs["event"] = event.payload
 
-    # Update the next kwargs of the task instance
-    task_instance.next_kwargs = next_kwargs
+    # re-serialize the entire dict using serde to ensure consistent structure
+    task_instance.next_kwargs = serialize(next_kwargs)
 
     # Remove ourselves as its trigger
     task_instance.trigger_id = None
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index d794cb4fe20..e95b09f1ea1 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -579,8 +579,23 @@ class TestTIRunState:
         )
 
         ti.next_method = "execute_complete"
-        # ti.next_kwargs under the hood applies the serde encoding for us
-        ti.next_kwargs = {"moment": instant}
+        # Explicitly assign a serde-serialized value: next_kwargs is now stored as JSON/JSONB
+        # and arrives already serde-serialized from the worker.
+        expected_next_kwargs = {
+            "moment": {
+                "__classname__": "pendulum.datetime.DateTime",
+                "__version__": 2,
+                "__data__": {
+                    "timestamp": 1727697600.0,
+                    "tz": {
+                        "__classname__": "builtins.tuple",
+                        "__version__": 1,
+                        "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 
1, True],
+                    },
+                },
+            }
+        }
+        ti.next_kwargs = expected_next_kwargs
 
         session.commit()
 
@@ -606,10 +621,7 @@ class TestTIRunState:
             "connections": [],
             "xcom_keys_to_clear": [],
             "next_method": "execute_complete",
-            "next_kwargs": {
-                "__type": "dict",
-                "__var": {"moment": {"__type": "datetime", "__var": 
1727697600.0}},
-            },
+            "next_kwargs": expected_next_kwargs,
         }
 
     @pytest.mark.parametrize("resume", [True, False])
@@ -632,14 +644,26 @@ class TestTIRunState:
         second_start_time = orig_task_start_time.add(seconds=30)
         second_start_time_str = second_start_time.isoformat()
 
-        # ti.next_kwargs under the hood applies the serde encoding for us
+        # Explicitly assign a serde-serialized value: next_kwargs is now stored as JSON/JSONB
+        # and arrives already serde-serialized from the worker.
         if resume:
-            ti.next_kwargs = {"moment": second_start_time}
-            expected_start_date = orig_task_start_time
+            # expected format is now in serde serialized format
             expected_next_kwargs = {
-                "__type": "dict",
-                "__var": {"moment": {"__type": "datetime", "__var": 
second_start_time.timestamp()}},
+                "moment": {
+                    "__classname__": "pendulum.datetime.DateTime",
+                    "__version__": 2,
+                    "__data__": {
+                        "timestamp": 1727697635.0,
+                        "tz": {
+                            "__classname__": "builtins.tuple",
+                            "__version__": 1,
+                            "__data__": ["UTC", 
"pendulum.tz.timezone.Timezone", 1, True],
+                        },
+                    },
+                }
             }
+            ti.next_kwargs = expected_next_kwargs
+            expected_start_date = orig_task_start_time
         else:
             expected_start_date = second_start_time
             expected_next_kwargs = None
@@ -1123,17 +1147,40 @@ class TestTIUpdateState:
 
         payload = {
             "state": "deferred",
-            # Raw payload is already "encoded", but not encrypted
+            # expected format is now in serde serialized format
             "trigger_kwargs": {
-                "__type": "dict",
-                "__var": {"key": "value", "moment": {"__type": "datetime", 
"__var": 1734480001.0}},
+                "key": "value",
+                "moment": {
+                    "__classname__": "datetime.datetime",
+                    "__version__": 2,
+                    "__data__": {
+                        "timestamp": 1734480001.0,
+                        "tz": {
+                            "__classname__": "builtins.tuple",
+                            "__version__": 1,
+                            "__data__": ["UTC", 
"pendulum.tz.timezone.Timezone", 1, True],
+                        },
+                    },
+                },
             },
             "trigger_timeout": "P1D",  # 1 day
             "classpath": "my-classpath",
             "next_method": "execute_callback",
+            # expected format is now in serde serialized format
             "next_kwargs": {
-                "__type": "dict",
-                "__var": {"foo": {"__type": "datetime", "__var": 
1734480000.0}, "bar": "abc"},
+                "foo": {
+                    "__classname__": "datetime.datetime",
+                    "__version__": 2,
+                    "__data__": {
+                        "timestamp": 1734480000.0,
+                        "tz": {
+                            "__classname__": "builtins.tuple",
+                            "__version__": 1,
+                            "__data__": ["UTC", 
"pendulum.tz.timezone.Timezone", 1, True],
+                        },
+                    },
+                },
+                "bar": "abc",
             },
         }
 
@@ -1149,9 +1196,21 @@ class TestTIUpdateState:
 
         assert tis[0].state == TaskInstanceState.DEFERRED
         assert tis[0].next_method == "execute_callback"
+
         assert tis[0].next_kwargs == {
+            "foo": {
+                "__classname__": "datetime.datetime",
+                "__version__": 2,
+                "__data__": {
+                    "timestamp": 1734480000.0,
+                    "tz": {
+                        "__classname__": "builtins.tuple",
+                        "__version__": 1,
+                        "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 
1, True],
+                    },
+                },
+            },
             "bar": "abc",
-            "foo": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
         }
         assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 
11, 23), timezone=timezone.utc)
 
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py 
b/providers/standard/tests/unit/standard/operators/test_hitl.py
index 3bce42df198..acfc1078f4a 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+from uuid import UUID
+
 import pytest
 
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, 
AIRFLOW_V_3_2_PLUS
@@ -252,8 +254,12 @@ class TestHITLOperator:
         assert notifier.called is True
 
         expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
+        # trigger_kwargs are encoded via BaseSerialization in versions < 3.2
+        expected_ti_id = ti.id
         if AIRFLOW_V_3_2_PLUS:
             expected_params_in_trigger_kwargs = expected_params
+            # trigger_kwargs are encoded via serde from task sdk in versions 
>= 3.2
+            expected_ti_id = UUID(ti.id)
         else:
             expected_params_in_trigger_kwargs = {"input_1": {"value": 1, 
"description": None, "schema": {}}}
 
@@ -262,7 +268,7 @@ class TestHITLOperator:
         )
         assert registered_trigger is not None
         assert registered_trigger.kwargs == {
-            "ti_id": ti.id,
+            "ti_id": expected_ti_id,
             "options": ["1", "2", "3", "4", "5"],
             "defaults": ["1"],
             "params": expected_params_in_trigger_kwargs,
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 63e8edbf81c..5ffcd7a42d4 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@ from uuid import UUID
 
 from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, 
RootModel
 
-API_VERSION: Final[str] = "2025-12-08"
+API_VERSION: Final[str] = "2026-03-31"
 
 
 class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -198,10 +198,10 @@ class TIDeferredStatePayload(BaseModel):
     )
     state: Annotated[Literal["deferred"] | None, Field(title="State")] = 
"deferred"
     classpath: Annotated[str, Field(title="Classpath")]
-    trigger_kwargs: Annotated[dict[str, Any] | str | None, 
Field(title="Trigger Kwargs")] = None
+    trigger_kwargs: Annotated[dict[str, JsonValue] | str | None, 
Field(title="Trigger Kwargs")] = None
     trigger_timeout: Annotated[timedelta | None, Field(title="Trigger 
Timeout")] = None
     next_method: Annotated[str, Field(title="Next Method")]
-    next_kwargs: Annotated[dict[str, Any] | None, Field(title="Next Kwargs")] 
= None
+    next_kwargs: Annotated[dict[str, JsonValue] | None, Field(title="Next 
Kwargs")] = None
     rendered_map_index: Annotated[str | None, Field(title="Rendered Map 
Index")] = None
 
 
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py 
b/task-sdk/src/airflow/sdk/definitions/dag.py
index 479958b99d7..b0b8357b5e0 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -1429,6 +1429,7 @@ def _run_task(
             ti.task = create_scheduler_operator(taskrun_result.ti.task)
 
             if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, 
DeferTask) and run_triggerer:
+                from airflow.sdk.serde import deserialize, serialize
                 from airflow.utils.session import create_session
 
                 # API Server expects the task instance to be in QUEUED state 
before
@@ -1436,10 +1437,14 @@ def _run_task(
                 ti.set_state(TaskInstanceState.QUEUED)
 
                 log.info("[DAG TEST] running trigger in line")
-                trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
+                # trigger_kwargs need to be deserialized before passing to the 
trigger class since they are in serde encoded format
+                kwargs = deserialize(msg.trigger_kwargs)  # type: 
ignore[type-var]  # needed to convince mypy that trigger_kwargs is a dict or a 
str because it's unable to infer JsonValue
+                if TYPE_CHECKING:
+                    assert isinstance(kwargs, dict)
+                trigger = import_string(msg.classpath)(**kwargs)
                 event = _run_inline_trigger(trigger, task_sdk_ti)
                 ti.next_method = msg.next_method
-                ti.next_kwargs = {"event": event.payload} if event else 
msg.next_kwargs
+                ti.next_kwargs = {"event": serialize(event.payload)} if event 
else msg.next_kwargs
                 log.info("[DAG TEST] Trigger completed")
 
                 # Set the state to SCHEDULED so that the task can be resumed.
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 7ef112a4921..52a96d0b665 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -60,7 +60,7 @@ from uuid import UUID
 import attrs
 import msgspec
 import structlog
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, 
TypeAdapter, field_serializer
+from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, 
TypeAdapter
 
 from airflow.sdk.api.datamodels._generated import (
     AssetEventDagRunReference,
@@ -714,19 +714,6 @@ class DeferTask(TIDeferredStatePayload):
 
     type: Literal["DeferTask"] = "DeferTask"
 
-    @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
-    def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
-        from airflow.serialization.serialized_objects import BaseSerialization
-
-        if not isinstance(val, dict):
-            # None, or an encrypted string
-            return val
-
-        if val.keys() == {"__type", "__var"}:
-            # Already encoded.
-            return val
-        return BaseSerialization.serialize(val or {})
-
 
 class RetryTask(TIRetryStatePayload):
     """Update a task instance state to up_for_retry."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 85fd83551d4..1d15b9b8778 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1005,12 +1005,21 @@ def _defer_task(
     log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, 
task_id=ti.task_id, run_id=ti.run_id)
     classpath, trigger_kwargs = defer.trigger.serialize()
 
+    from airflow.sdk.serde import serialize as serde_serialize
+
+    trigger_kwargs = serde_serialize(trigger_kwargs)
+    next_kwargs = serde_serialize(defer.kwargs or {})
+
+    if TYPE_CHECKING:
+        assert isinstance(next_kwargs, dict)
+        assert isinstance(trigger_kwargs, dict)
+
     msg = DeferTask(
         classpath=classpath,
         trigger_kwargs=trigger_kwargs,
         trigger_timeout=defer.timeout,
         next_method=defer.method_name,
-        next_kwargs=defer.kwargs or {},
+        next_kwargs=next_kwargs,
     )
     state = TaskInstanceState.DEFERRED
 
@@ -1418,10 +1427,20 @@ def _execute_task(context: Context, ti: 
RuntimeTaskInstance, log: Logger):
     execute = task.execute
 
     if ti._ti_context_from_server and (next_method := 
ti._ti_context_from_server.next_method):
-        from airflow.serialization.serialized_objects import BaseSerialization
+        from airflow.sdk.serde import deserialize
 
-        kwargs = 
BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})
+        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
+        try:
+            if TYPE_CHECKING:
+                assert isinstance(next_kwargs_data, dict)
+            kwargs = deserialize(next_kwargs_data)
+        except (ImportError, KeyError, AttributeError, TypeError):
+            from airflow.serialization.serialized_objects import 
BaseSerialization
 
+            kwargs = BaseSerialization.deserialize(next_kwargs_data)
+
+        if TYPE_CHECKING:
+            assert isinstance(kwargs, dict)
         execute = functools.partial(task.resume_execution, 
next_method=next_method, next_kwargs=kwargs)
 
     ctx = contextvars.copy_context()
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 63fbdfe7138..b45f57cd425 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -686,13 +686,22 @@ class TestWatchedSubprocess:
                 
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
                 next_method="execute_complete",
                 trigger_kwargs={
-                    "__type": "dict",
-                    "__var": {
-                        "moment": {"__type": "datetime", "__var": 
1730982899.0},
-                        "end_from_trigger": False,
+                    "moment": {
+                        "__classname__": "pendulum.datetime.DateTime",
+                        "__version__": 2,
+                        "__data__": {
+                            "timestamp": 1730982899.0,
+                            "tz": {
+                                "__classname__": "builtins.tuple",
+                                "__version__": 1,
+                                "__data__": ["UTC", 
"pendulum.tz.timezone.Timezone", 1, True],
+                            },
+                        },
                     },
+                    "end_from_trigger": False,
                 },
-                next_kwargs={"__type": "dict", "__var": {}},
+                trigger_timeout=None,
+                next_kwargs={},
             ),
         )
 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 7830d8543db..90cbe6a4f2b 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -366,18 +366,30 @@ def test_run_deferred_basic(time_machine, 
create_runtime_ti, mock_supervisor_com
     )
     time_machine.move_to(instant, tick=False)
 
-    # Expected DeferTask
+    # Expected DeferTask, it is constructed by _defer_task from exception and 
is sent to supervisor
     expected_defer_task = DeferTask(
         state="deferred",
         
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
-        # Since we are in the task process here, we expect this to have not 
been encoded by serde yet
         trigger_kwargs={
+            "moment": {
+                "__classname__": "pendulum.datetime.DateTime",
+                "__version__": 2,
+                "__data__": {
+                    "timestamp": 1732233603.0,
+                    "tz": {
+                        "__classname__": "builtins.tuple",
+                        "__version__": 1,
+                        "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 
1, True],
+                    },
+                },
+            },
             "end_from_trigger": False,
-            "moment": instant + timedelta(seconds=3),
         },
         trigger_timeout=None,
         next_method="execute_complete",
         next_kwargs={},
+        rendered_map_index=None,
+        type="DeferTask",
     )
 
     # Run the task

Reply via email to