This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 05ac92fb0d0 Use SDK serde for trigger and next kwargs serialization (#59711)
05ac92fb0d0 is described below
commit 05ac92fb0d02de88770846e571b980cdeec7bd72
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Dec 30 12:47:43 2025 +0530
Use SDK serde for trigger and next kwargs serialization (#59711)
---
airflow-core/.pre-commit-config.yaml | 1 +
.../execution_api/datamodels/taskinstance.py | 4 +-
.../execution_api/routes/task_instances.py | 11 +--
.../api_fastapi/execution_api/versions/__init__.py | 2 +
.../execution_api/versions/v2026_03_31.py | 35 ++++++++
airflow-core/src/airflow/models/taskinstance.py | 5 +-
.../src/airflow/models/taskinstancehistory.py | 5 +-
airflow-core/src/airflow/models/trigger.py | 39 ++++++---
.../versions/head/test_task_instances.py | 93 ++++++++++++++++++----
.../tests/unit/standard/operators/test_hitl.py | 8 +-
.../src/airflow/sdk/api/datamodels/_generated.py | 6 +-
task-sdk/src/airflow/sdk/definitions/dag.py | 9 ++-
task-sdk/src/airflow/sdk/execution_time/comms.py | 15 +---
.../src/airflow/sdk/execution_time/task_runner.py | 25 +++++-
.../task_sdk/execution_time/test_supervisor.py | 19 +++--
.../task_sdk/execution_time/test_task_runner.py | 18 ++++-
16 files changed, 225 insertions(+), 70 deletions(-)
diff --git a/airflow-core/.pre-commit-config.yaml b/airflow-core/.pre-commit-config.yaml
index 38e740e27ae..106a7605b17 100644
--- a/airflow-core/.pre-commit-config.yaml
+++ b/airflow-core/.pre-commit-config.yaml
@@ -370,6 +370,7 @@ repos:
^src/airflow/models/taskmap\.py$|
^src/airflow/models/taskmixin\.py$|
^src/airflow/models/taskreschedule\.py$|
+ ^src/airflow/models/trigger\.py$|
^src/airflow/models/variable\.py$|
^src/airflow/models/xcom\.py$|
^src/airflow/models/xcom_arg\.py$|
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 134019b6e31..a4c5355a826 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -129,7 +129,7 @@ class TIDeferredStatePayload(StrictBaseModel):
),
]
classpath: str
- trigger_kwargs: Annotated[dict[str, Any] | str, Field(default_factory=dict)]
+ trigger_kwargs: Annotated[dict[str, JsonValue] | str, Field(default_factory=dict)]
"""
Kwargs to pass to the trigger constructor, either a plain dict or an encrypted string.
@@ -139,7 +139,7 @@ class TIDeferredStatePayload(StrictBaseModel):
trigger_timeout: timedelta | None = None
next_method: str
"""The name of the method on the operator to call in the worker after the
trigger has fired."""
- next_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
+ next_kwargs: Annotated[dict[str, JsonValue], Field(default_factory=dict)]
"""
Kwargs to pass to the above method, either a plain dict or an encrypted string.
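The switch from dict[str, Any] to dict[str, JsonValue] means the execution API now rejects payloads that are not JSON-encodable instead of accepting them silently. A minimal sketch of that difference using plain pydantic (these model names are illustrative, not the Airflow ones):

    from typing import Any

    from pydantic import BaseModel, JsonValue, ValidationError

    class LoosePayload(BaseModel):
        kwargs: dict[str, Any]

    class StrictPayload(BaseModel):
        kwargs: dict[str, JsonValue]

    LoosePayload(kwargs={"moment": object()})  # accepted, although it cannot be dumped to JSON
    try:
        StrictPayload(kwargs={"moment": object()})
    except ValidationError as err:
        print(err.error_count(), "validation error(s)")  # rejected at the model boundary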
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 80738cef12f..60711ea255e 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -507,19 +507,12 @@ def _create_ti_state_update_query_and_update_state(
query = update(TI).where(TI.id == ti_id_str)
- # This is slightly inefficient as we deserialize it to then right again serialize it in the sqla
- # TypeAdapter.
- next_kwargs = None
- if ti_patch_payload.next_kwargs:
- from airflow.serialization.serialized_objects import BaseSerialization
-
- next_kwargs = BaseSerialization.deserialize(ti_patch_payload.next_kwargs)
-
+ # Store next_kwargs directly (already serialized by worker)
query = query.values(
state=TaskInstanceState.DEFERRED,
trigger_id=trigger_row.id,
next_method=ti_patch_payload.next_method,
- next_kwargs=next_kwargs,
+ next_kwargs=ti_patch_payload.next_kwargs,
trigger_timeout=timeout,
)
updated_state = TaskInstanceState.DEFERRED
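With the round-trip removed, the endpoint persists exactly what the worker sent. A small illustration of the new flow, with the payload values copied from the tests later in this commit:

    import json

    # next_kwargs as shipped by the worker, already serde-encoded
    next_kwargs = {
        "bar": "abc",
        "foo": {
            "__classname__": "datetime.datetime",
            "__version__": 2,
            "__data__": {
                "timestamp": 1734480000.0,
                "tz": {
                    "__classname__": "builtins.tuple",
                    "__version__": 1,
                    "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
                },
            },
        },
    }

    # The encoded dict is JSON-safe by construction, so it can be stored verbatim
    # in the JSON/JSONB column with no BaseSerialization.deserialize() step first.
    json.dumps(next_kwargs)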
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 09c8f2d3e81..36c1b31b959 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -33,9 +33,11 @@ from airflow.api_fastapi.execution_api.versions.v2025_12_08 import (
AddDagRunDetailEndpoint,
MovePreviousRunEndpoint,
)
+from airflow.api_fastapi.execution_api.versions.v2026_03_31 import ModifyDeferredTaskKwargsToJsonValue
bundle = VersionBundle(
HeadVersion(),
+ Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue),
Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
Version("2025-11-07", AddPartitionKeyField),
Version("2025-11-05", AddTriggeringUserNameField),
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
new file mode 100644
index 00000000000..48630e0b50c
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from cadwyn import VersionChange, schema
+
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIDeferredStatePayload
+
+
+class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
+ """Change the types of `trigger_kwargs` and `next_kwargs` in
TIDeferredStatePayload to JsonValue."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version = (
+ schema(TIDeferredStatePayload).field("trigger_kwargs").had(type=dict[str, Any] | str),
+ schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str, Any]),
+ )
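For clients pinned to an API version older than 2026-03-31, cadwyn applies these instructions to present the previous field types. A rough sketch of the idea with plain pydantic stand-ins (OldPayload/NewPayload are hypothetical; the real downgrade is generated by cadwyn from the instructions above): any JSON-safe payload validates the same under both declarations, so the wire format itself does not change.

    from typing import Any

    from pydantic import BaseModel, JsonValue

    class OldPayload(BaseModel):  # hypothetical stand-in for the pre-2026-03-31 shape
        next_kwargs: dict[str, Any] = {}

    class NewPayload(BaseModel):  # hypothetical stand-in for the 2026-03-31 shape
        next_kwargs: dict[str, JsonValue] = {}

    wire = {"next_kwargs": {"bar": "abc"}}
    assert OldPayload(**wire).next_kwargs == NewPayload(**wire).next_kwargs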
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index d917bec159b..0914a635381 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -36,6 +36,7 @@ import dill
import lazy_object_proxy
import uuid6
from sqlalchemy import (
+ JSON,
Float,
ForeignKey,
ForeignKeyConstraint,
@@ -433,7 +434,9 @@ class TaskInstance(Base, LoggingMixin):
# The method to call next, and any extra arguments to pass to it.
# Usually used when resuming from DEFERRED.
next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
- next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+ next_kwargs: Mapped[dict | None] = mapped_column(
+ MutableDict.as_mutable(JSON().with_variant(postgresql.JSONB, "postgresql")), nullable=True
+ )
_task_display_property_value: Mapped[str | None] = mapped_column(
"task_display_name", String(2000), nullable=True
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py
index 3e885aec095..b69d2c96ff9 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
import dill
from sqlalchemy import (
+ JSON,
DateTime,
Float,
ForeignKeyConstraint,
@@ -109,7 +110,9 @@ class TaskInstanceHistory(Base):
trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
trigger_timeout: Mapped[DateTime | None] = mapped_column(DateTime, nullable=True)
next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
- next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
+ next_kwargs: Mapped[dict | None] = mapped_column(
+ MutableDict.as_mutable(JSON().with_variant(postgresql.JSONB, "postgresql")), nullable=True
+ )
task_display_name: Mapped[str | None] = mapped_column(String(2000), nullable=True)
dag_version_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), nullable=True)
diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py
index a900dbf44a7..a6016363742 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -142,9 +142,9 @@ class Trigger(Base):
import json
from airflow.models.crypto import get_fernet
- from airflow.serialization.serialized_objects import BaseSerialization
+ from airflow.sdk.serde import serialize
- serialized_kwargs = BaseSerialization.serialize(kwargs)
+ serialized_kwargs = serialize(kwargs)
return get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")
@staticmethod
@@ -153,7 +153,7 @@ class Trigger(Base):
import json
from airflow.models.crypto import get_fernet
- from airflow.serialization.serialized_objects import BaseSerialization
+ from airflow.sdk.serde import deserialize
# We weren't able to encrypt the kwargs in all migration paths,
# so we need to handle the case where they are not encrypted.
@@ -165,7 +165,16 @@ class Trigger(Base):
get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
)
- return BaseSerialization.deserialize(decrypted_kwargs)
+ try:
+ result = deserialize(decrypted_kwargs)
+ if TYPE_CHECKING:
+ assert isinstance(result, dict)
+ return result
+ except (ImportError, KeyError, AttributeError, TypeError):
+ # Backward compatibility: fall back to BaseSerialization for old format
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ return BaseSerialization.deserialize(decrypted_kwargs)
def rotate_fernet_key(self):
"""Encrypts data with a new key. See: :ref:`security/fernet`."""
@@ -417,16 +426,28 @@ def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, ses
:param task_instance: The task instance to handle the submit event for.
:param session: The session to be used for the database callback sink.
"""
+ from airflow.sdk.serde import deserialize, serialize
from airflow.utils.state import TaskInstanceState
- # Get the next kwargs of the task instance, or an empty dictionary if it doesn't exist
- next_kwargs = task_instance.next_kwargs or {}
+ next_kwargs_raw = task_instance.next_kwargs or {}
+
+ # deserialize first to provide a compat layer when there is mixed serialized (BaseSerialization and serde) data,
+ # which can happen if a deferred task resumes after an upgrade
+ try:
+ next_kwargs = deserialize(next_kwargs_raw)
+ except (ImportError, KeyError, AttributeError, TypeError):
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ next_kwargs = BaseSerialization.deserialize(next_kwargs_raw)
- # Add the event's payload into the kwargs for the task
+ # Add event to the plain dict, then serialize everything together. This ensures that the event is properly
+ # nested inside __var__ in the final serde serialized structure.
+ if TYPE_CHECKING:
+ assert isinstance(next_kwargs, dict)
next_kwargs["event"] = event.payload
- # Update the next kwargs of the task instance
- task_instance.next_kwargs = next_kwargs
+ # re-serialize the entire dict using serde to ensure consistent structure
+ task_instance.next_kwargs = serialize(next_kwargs)
# Remove ourselves as its trigger
task_instance.trigger_id = None
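The shape of that compat path, end to end: decode with serde first, fall back to BaseSerialization for rows written before the upgrade, merge the event payload into the plain dict, then re-encode once with serde. A condensed sketch using the same imports as this commit (the stored value and event payload are illustrative):

    from airflow.sdk.serde import deserialize, serialize

    stored = serialize({"moment": None})      # what the worker persisted for next_kwargs
    kwargs = deserialize(stored)              # a plain dict again
    kwargs["event"] = {"status": "success"}   # merge the trigger event payload
    stored = serialize(kwargs)                # one consistent serde envelope goes back to the DB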
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index d794cb4fe20..e95b09f1ea1 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -579,8 +579,23 @@ class TestTIRunState:
)
ti.next_method = "execute_complete"
- # ti.next_kwargs under the hood applies the serde encoding for us
- ti.next_kwargs = {"moment": instant}
+ # explicitly use a serde-serialized value before assigning: we use JSON/JSONB now,
+ # and this value arrives serde-serialized from the worker
+ expected_next_kwargs = {
+ "moment": {
+ "__classname__": "pendulum.datetime.DateTime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1727697600.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC", "pendulum.tz.timezone.Timezone",
1, True],
+ },
+ },
+ }
+ }
+ ti.next_kwargs = expected_next_kwargs
session.commit()
@@ -606,10 +621,7 @@ class TestTIRunState:
"connections": [],
"xcom_keys_to_clear": [],
"next_method": "execute_complete",
- "next_kwargs": {
- "__type": "dict",
- "__var": {"moment": {"__type": "datetime", "__var":
1727697600.0}},
- },
+ "next_kwargs": expected_next_kwargs,
}
@pytest.mark.parametrize("resume", [True, False])
@@ -632,14 +644,26 @@ class TestTIRunState:
second_start_time = orig_task_start_time.add(seconds=30)
second_start_time_str = second_start_time.isoformat()
- # ti.next_kwargs under the hood applies the serde encoding for us
+ # explicitly serialize using serde before assigning: we use JSON/JSONB now,
+ # and this value arrives serde-serialized from the worker
if resume:
- ti.next_kwargs = {"moment": second_start_time}
- expected_start_date = orig_task_start_time
+ # the expected value is now in serde-serialized format
expected_next_kwargs = {
- "__type": "dict",
- "__var": {"moment": {"__type": "datetime", "__var":
second_start_time.timestamp()}},
+ "moment": {
+ "__classname__": "pendulum.datetime.DateTime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1727697635.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
+ },
+ },
+ }
}
+ ti.next_kwargs = expected_next_kwargs
+ expected_start_date = orig_task_start_time
else:
expected_start_date = second_start_time
expected_next_kwargs = None
@@ -1123,17 +1147,40 @@ class TestTIUpdateState:
payload = {
"state": "deferred",
- # Raw payload is already "encoded", but not encrypted
+ # the expected value is now in serde-serialized format
"trigger_kwargs": {
- "__type": "dict",
- "__var": {"key": "value", "moment": {"__type": "datetime",
"__var": 1734480001.0}},
+ "key": "value",
+ "moment": {
+ "__classname__": "datetime.datetime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1734480001.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
+ },
+ },
+ },
},
"trigger_timeout": "P1D", # 1 day
"classpath": "my-classpath",
"next_method": "execute_callback",
+ # the expected value is now in serde-serialized format
"next_kwargs": {
- "__type": "dict",
- "__var": {"foo": {"__type": "datetime", "__var":
1734480000.0}, "bar": "abc"},
+ "foo": {
+ "__classname__": "datetime.datetime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1734480000.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
+ },
+ },
+ },
+ "bar": "abc",
},
}
@@ -1149,9 +1196,21 @@ class TestTIUpdateState:
assert tis[0].state == TaskInstanceState.DEFERRED
assert tis[0].next_method == "execute_callback"
+
assert tis[0].next_kwargs == {
+ "foo": {
+ "__classname__": "datetime.datetime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1734480000.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC", "pendulum.tz.timezone.Timezone",
1, True],
+ },
+ },
+ },
"bar": "abc",
- "foo": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}
assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 11, 23), timezone=timezone.utc)
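The expected values above show serde's envelope: __classname__ names the concrete type, __version__ its serializer version, and __data__ the reconstruction payload. Decoding the recurring datetime example by hand (stdlib plus pendulum; the field meanings are inferred from the payloads in these tests):

    import pendulum

    payload = {
        "__classname__": "pendulum.datetime.DateTime",
        "__version__": 2,
        "__data__": {
            "timestamp": 1727697600.0,
            "tz": {
                "__classname__": "builtins.tuple",
                "__version__": 1,
                "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
            },
        },
    }
    tz_name = payload["__data__"]["tz"]["__data__"][0]
    moment = pendulum.from_timestamp(payload["__data__"]["timestamp"], tz=tz_name)
    print(moment)  # 2024-09-30T12:00:00+00:00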
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py
index 3bce42df198..acfc1078f4a 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+from uuid import UUID
+
import pytest
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
@@ -252,8 +254,12 @@ class TestHITLOperator:
assert notifier.called is True
expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
+ # trigger_kwargs are encoded via BaseSerialization in versions < 3.2
+ expected_ti_id = ti.id
if AIRFLOW_V_3_2_PLUS:
expected_params_in_trigger_kwargs = expected_params
+ # trigger_kwargs are encoded via serde from task sdk in versions >= 3.2
+ expected_ti_id = UUID(ti.id)
else:
expected_params_in_trigger_kwargs = {"input_1": {"value": 1, "description": None, "schema": {}}}
@@ -262,7 +268,7 @@ class TestHITLOperator:
)
assert registered_trigger is not None
assert registered_trigger.kwargs == {
- "ti_id": ti.id,
+ "ti_id": expected_ti_id,
"options": ["1", "2", "3", "4", "5"],
"defaults": ["1"],
"params": expected_params_in_trigger_kwargs,
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 63e8edbf81c..5ffcd7a42d4 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@ from uuid import UUID
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel
-API_VERSION: Final[str] = "2025-12-08"
+API_VERSION: Final[str] = "2026-03-31"
class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -198,10 +198,10 @@ class TIDeferredStatePayload(BaseModel):
)
state: Annotated[Literal["deferred"] | None, Field(title="State")] =
"deferred"
classpath: Annotated[str, Field(title="Classpath")]
- trigger_kwargs: Annotated[dict[str, Any] | str | None,
Field(title="Trigger Kwargs")] = None
+ trigger_kwargs: Annotated[dict[str, JsonValue] | str | None,
Field(title="Trigger Kwargs")] = None
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger
Timeout")] = None
next_method: Annotated[str, Field(title="Next Method")]
- next_kwargs: Annotated[dict[str, Any] | None, Field(title="Next Kwargs")]
= None
+ next_kwargs: Annotated[dict[str, JsonValue] | None, Field(title="Next
Kwargs")] = None
rendered_map_index: Annotated[str | None, Field(title="Rendered Map
Index")] = None
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 479958b99d7..b0b8357b5e0 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -1429,6 +1429,7 @@ def _run_task(
ti.task = create_scheduler_operator(taskrun_result.ti.task)
if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
+ from airflow.sdk.serde import deserialize, serialize
from airflow.utils.session import create_session
# API Server expects the task instance to be in QUEUED state before
@@ -1436,10 +1437,14 @@ def _run_task(
ti.set_state(TaskInstanceState.QUEUED)
log.info("[DAG TEST] running trigger in line")
- trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
+ # trigger_kwargs need to be deserialized before being passed to the trigger class, since they are in serde-encoded format
+ kwargs = deserialize(msg.trigger_kwargs) # type: ignore[type-var] # needed to convince mypy that trigger_kwargs is a dict or a str, because it's unable to infer JsonValue
+ if TYPE_CHECKING:
+ assert isinstance(kwargs, dict)
+ trigger = import_string(msg.classpath)(**kwargs)
event = _run_inline_trigger(trigger, task_sdk_ti)
ti.next_method = msg.next_method
- ti.next_kwargs = {"event": event.payload} if event else
msg.next_kwargs
+ ti.next_kwargs = {"event": serialize(event.payload)} if event
else msg.next_kwargs
log.info("[DAG TEST] Trigger completed")
# Set the state to SCHEDULED so that the task can be resumed.
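In dag.test() the DeferTask now carries serde-encoded trigger kwargs, so they must be decoded back to plain values before the trigger class is instantiated. A compact sketch with a stand-in trigger (DummyTrigger and its kwargs are illustrative; the serde imports match this commit):

    from airflow.sdk.serde import deserialize, serialize

    class DummyTrigger:  # stand-in for the class import_string() would load
        def __init__(self, *, end_from_trigger: bool):
            self.end_from_trigger = end_from_trigger

    encoded = serialize({"end_from_trigger": False})  # as shipped in DeferTask.trigger_kwargs
    kwargs = deserialize(encoded)                     # plain values again
    trigger = DummyTrigger(**kwargs)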
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 7ef112a4921..52a96d0b665 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -60,7 +60,7 @@ from uuid import UUID
import attrs
import msgspec
import structlog
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer
+from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
from airflow.sdk.api.datamodels._generated import (
AssetEventDagRunReference,
@@ -714,19 +714,6 @@ class DeferTask(TIDeferredStatePayload):
type: Literal["DeferTask"] = "DeferTask"
- @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
- def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
- from airflow.serialization.serialized_objects import BaseSerialization
-
- if not isinstance(val, dict):
- # None, or an encrypted string
- return val
-
- if val.keys() == {"__type", "__var"}:
- # Already encoded.
- return val
- return BaseSerialization.serialize(val or {})
-
class RetryTask(TIRetryStatePayload):
"""Update a task instance state to up_for_retry."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 85fd83551d4..1d15b9b8778 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1005,12 +1005,21 @@ def _defer_task(
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id,
task_id=ti.task_id, run_id=ti.run_id)
classpath, trigger_kwargs = defer.trigger.serialize()
+ from airflow.sdk.serde import serialize as serde_serialize
+
+ trigger_kwargs = serde_serialize(trigger_kwargs)
+ next_kwargs = serde_serialize(defer.kwargs or {})
+
+ if TYPE_CHECKING:
+ assert isinstance(next_kwargs, dict)
+ assert isinstance(trigger_kwargs, dict)
+
msg = DeferTask(
classpath=classpath,
trigger_kwargs=trigger_kwargs,
trigger_timeout=defer.timeout,
next_method=defer.method_name,
- next_kwargs=defer.kwargs or {},
+ next_kwargs=next_kwargs,
)
state = TaskInstanceState.DEFERRED
@@ -1418,10 +1427,20 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
execute = task.execute
if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
- from airflow.serialization.serialized_objects import BaseSerialization
+ from airflow.sdk.serde import deserialize
- kwargs = BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})
+ next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
+ try:
+ if TYPE_CHECKING:
+ assert isinstance(next_kwargs_data, dict)
+ kwargs = deserialize(next_kwargs_data)
+ except (ImportError, KeyError, AttributeError, TypeError):
+ from airflow.serialization.serialized_objects import BaseSerialization
+ kwargs = BaseSerialization.deserialize(next_kwargs_data)
+
+ if TYPE_CHECKING:
+ assert isinstance(kwargs, dict)
execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)
ctx = contextvars.copy_context()
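Taken together, the two task_runner hunks make the worker encode exactly once on defer and decode (with a legacy fallback) on resume. A condensed round-trip sketch using the same serde imports (the kwargs values are illustrative):

    from airflow.sdk.serde import deserialize, serialize

    # _defer_task: encode once before building the DeferTask message
    next_kwargs = serialize({"attempts": 3})

    # _execute_task on resume: prefer serde, fall back for pre-upgrade rows
    try:
        kwargs = deserialize(next_kwargs)
    except (ImportError, KeyError, AttributeError, TypeError):
        from airflow.serialization.serialized_objects import BaseSerialization

        kwargs = BaseSerialization.deserialize(next_kwargs)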
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 63fbdfe7138..b45f57cd425 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -686,13 +686,22 @@ class TestWatchedSubprocess:
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
next_method="execute_complete",
trigger_kwargs={
- "__type": "dict",
- "__var": {
- "moment": {"__type": "datetime", "__var":
1730982899.0},
- "end_from_trigger": False,
+ "moment": {
+ "__classname__": "pendulum.datetime.DateTime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1730982899.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
+ },
+ },
},
+ "end_from_trigger": False,
},
- next_kwargs={"__type": "dict", "__var": {}},
+ trigger_timeout=None,
+ next_kwargs={},
),
)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 7830d8543db..90cbe6a4f2b 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -366,18 +366,30 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com
)
time_machine.move_to(instant, tick=False)
- # Expected DeferTask
+ # Expected DeferTask: it is constructed by _defer_task from the exception and sent to the supervisor
expected_defer_task = DeferTask(
state="deferred",
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
- # Since we are in the task process here, we expect this to have not been encoded by serde yet
trigger_kwargs={
+ "moment": {
+ "__classname__": "pendulum.datetime.DateTime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1732233603.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC", "pendulum.tz.timezone.Timezone",
1, True],
+ },
+ },
+ },
"end_from_trigger": False,
- "moment": instant + timedelta(seconds=3),
},
trigger_timeout=None,
next_method="execute_complete",
next_kwargs={},
+ rendered_map_index=None,
+ type="DeferTask",
)
# Run the task