amoghrajesh commented on code in PR #59711:
URL: https://github.com/apache/airflow/pull/59711#discussion_r2645754212
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -674,13 +674,22 @@ def test_supervise_handles_deferred_task(
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
next_method="execute_complete",
trigger_kwargs={
- "__type": "dict",
- "__var": {
- "moment": {"__type": "datetime", "__var":
1730982899.0},
- "end_from_trigger": False,
+ "moment": {
+ "__classname__": "pendulum.datetime.DateTime",
+ "__version__": 2,
+ "__data__": {
+ "timestamp": 1730982899.0,
+ "tz": {
+ "__classname__": "builtins.tuple",
+ "__version__": 1,
+ "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
+ },
+ },
},
+ "end_from_trigger": False,
},
- next_kwargs={"__type": "dict", "__var": {}},
+ trigger_timeout=None,
+ next_kwargs={},
Review Comment:
I was able to test it, and it works as I expect it to.
(Also added to the PR description.)
To test this, using the same DAG as above, I undid all the changes to the Task SDK:
```diff
Subject: [PATCH] Adding a fixture for creating connections
---
Index: task-sdk/tests/task_sdk/execution_time/test_task_runner.py
===================================================================
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py (date 1766584407206)
@@ -363,30 +363,18 @@
)
time_machine.move_to(instant, tick=False)
- # Expected DeferTask, it is constructed by _defer_task from exception and is sent to supervisor
+ # Expected DeferTask
expected_defer_task = DeferTask(
state="deferred",
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
+ # Since we are in the task process here, we expect this to have not been encoded by serde yet
trigger_kwargs={
- "moment": {
- "__classname__": "pendulum.datetime.DateTime",
- "__version__": 2,
- "__data__": {
- "timestamp": 1732233603.0,
- "tz": {
- "__classname__": "builtins.tuple",
- "__version__": 1,
- "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
- },
- },
- },
"end_from_trigger": False,
+ "moment": instant + timedelta(seconds=3),
},
trigger_timeout=None,
next_method="execute_complete",
next_kwargs={},
- rendered_map_index=None,
- type="DeferTask",
)
# Run the task
Index: task-sdk/src/airflow/sdk/definitions/dag.py
===================================================================
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
--- a/task-sdk/src/airflow/sdk/definitions/dag.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py (date 1766584407199)
@@ -1428,7 +1428,6 @@
ti.task = create_scheduler_operator(taskrun_result.ti.task)
if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
- from airflow.sdk.serde import deserialize, serialize
from airflow.utils.session import create_session
# API Server expects the task instance to be in QUEUED state before
@@ -1436,12 +1435,10 @@
ti.set_state(TaskInstanceState.QUEUED)
log.info("[DAG TEST] running trigger in line")
- # trigger_kwargs need to be deserialized before passing to the trigger class since they are in serde encoded format
- kwargs = deserialize(msg.trigger_kwargs) # type: ignore[type-var] # mypy doesn't like passing JsonValue | str to deserialize but its correct
- trigger = import_string(msg.classpath)(**kwargs)
+ trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
event = _run_inline_trigger(trigger, task_sdk_ti)
ti.next_method = msg.next_method
- ti.next_kwargs = {"event": serialize(event.payload)} if
event else msg.next_kwargs
+ ti.next_kwargs = {"event": event.payload} if event else
msg.next_kwargs
log.info("[DAG TEST] Trigger completed")
# Set the state to SCHEDULED so that the task can be resumed.
Index: task-sdk/src/airflow/sdk/execution_time/task_runner.py
===================================================================
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py (date 1766584407200)
@@ -962,19 +962,12 @@
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id,
task_id=ti.task_id, run_id=ti.run_id)
classpath, trigger_kwargs = defer.trigger.serialize()
- from typing import cast
-
- from airflow.sdk.serde import serialize as serde_serialize
-
- trigger_kwargs = cast("JsonValue", serde_serialize(trigger_kwargs))
- next_kwargs = cast("JsonValue", serde_serialize(defer.kwargs or {}))
-
msg = DeferTask(
classpath=classpath,
trigger_kwargs=trigger_kwargs,
trigger_timeout=defer.timeout,
next_method=defer.method_name,
- next_kwargs=next_kwargs,
+ next_kwargs=defer.kwargs or {},
)
state = TaskInstanceState.DEFERRED
@@ -1382,20 +1375,10 @@
execute = task.execute
if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
- from airflow.sdk.serde import deserialize
-
- next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
- try:
- if TYPE_CHECKING:
- assert isinstance(next_kwargs_data, dict)
- kwargs = deserialize(next_kwargs_data)
- except (ImportError, KeyError, AttributeError, TypeError):
- from airflow.serialization.serialized_objects import BaseSerialization
+ from airflow.serialization.serialized_objects import BaseSerialization
- kwargs = BaseSerialization.deserialize(next_kwargs_data)
+ kwargs = BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})
- if TYPE_CHECKING:
- assert isinstance(kwargs, dict)
execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)
ctx = contextvars.copy_context()
Index: task-sdk/tests/task_sdk/execution_time/test_supervisor.py
===================================================================
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py (date 1766584407200)
@@ -674,22 +674,13 @@
classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
next_method="execute_complete",
trigger_kwargs={
- "moment": {
- "__classname__": "pendulum.datetime.DateTime",
- "__version__": 2,
- "__data__": {
- "timestamp": 1730982899.0,
- "tz": {
- "__classname__": "builtins.tuple",
- "__version__": 1,
- "__data__": ["UTC",
"pendulum.tz.timezone.Timezone", 1, True],
- },
- },
- },
- "end_from_trigger": False,
+ "__type": "dict",
+ "__var": {
+ "moment": {"__type": "datetime", "__var":
1730982899.0},
+ "end_from_trigger": False,
+ },
},
- trigger_timeout=None,
- next_kwargs={},
+ next_kwargs={"__type": "dict", "__var": {}},
),
)
Index: task-sdk/src/airflow/sdk/execution_time/comms.py
===================================================================
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py (date 1766584407199)
@@ -60,7 +60,7 @@
import attrs
import msgspec
import structlog
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
+from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer
from airflow.sdk.api.datamodels._generated import (
AssetEventDagRunReference,
@@ -705,6 +705,19 @@
type: Literal["DeferTask"] = "DeferTask"
+ @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
+ def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ if not isinstance(val, dict):
+ # None, or an encrypted string
+ return val
+
+ if val.keys() == {"__type", "__var"}:
+ # Already encoded.
+ return val
+ return BaseSerialization.serialize(val or {})
+
class RetryTask(TIRetryStatePayload):
"""Update a task instance state to up_for_retry."""
Index: task-sdk/src/airflow/sdk/api/datamodels/_generated.py
===================================================================
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py (date 1766584407198)
@@ -27,7 +27,7 @@
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel
-API_VERSION: Final[str] = "2026-03-31"
+API_VERSION: Final[str] = "2025-12-08"
class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -181,10 +181,10 @@
)
state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred"
classpath: Annotated[str, Field(title="Classpath")]
- trigger_kwargs: Annotated[dict[str, JsonValue] | str | None, Field(title="Trigger Kwargs")] = None
+ trigger_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Trigger Kwargs")] = None
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None
next_method: Annotated[str, Field(title="Next Method")]
- next_kwargs: Annotated[dict[str, JsonValue] | None, Field(title="Next Kwargs")] = None
+ next_kwargs: Annotated[dict[str, Any] | None, Field(title="Next Kwargs")] = None
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
```
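The piece that matters in the patch above is the `field_serializer` restored on `DeferTask` in `comms.py`: it BaseSerialization-encodes dict kwargs on the way out and leaves `None`, encrypted strings, and already-encoded values alone. A minimal standalone sketch of that pattern, assuming pydantic v2 (this is not the actual `comms.py` class, and `_encode` is only a stand-in for `BaseSerialization.serialize`, which also recurses into nested values):

```python
from typing import Any

from pydantic import BaseModel, field_serializer


def _encode(val: dict) -> dict:
    # Stand-in for BaseSerialization.serialize; the real thing also encodes
    # nested datetimes, tuples, etc.
    return {"__type": "dict", "__var": val}


class DeferTaskSketch(BaseModel):
    trigger_kwargs: dict[str, Any] | str | None = None
    next_kwargs: dict[str, Any] | None = None

    @field_serializer("trigger_kwargs", "next_kwargs")
    def _serde_kwarg_fields(self, val: Any, _info):
        if not isinstance(val, dict):
            # None, or an encrypted string: pass through unchanged
            return val
        if val.keys() == {"__type", "__var"}:
            # Already encoded once; don't double-encode
            return val
        return _encode(val)


print(DeferTaskSketch(trigger_kwargs={"end_from_trigger": False}).model_dump())
# {'trigger_kwargs': {'__type': 'dict', '__var': {'end_from_trigger': False}}, 'next_kwargs': None}
```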
Once this was done, I am certain that the data sent to the API server is now BaseSerialization encoded.
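For context, here is a hypothetical deferrable DAG along the lines of the one I used (the actual DAG is in the PR description; the ids and values below are made up to mirror the `next_kwargs` seen in the TI table, and the imports assume an Airflow 3 Task SDK environment):

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone

from airflow.providers.standard.triggers.temporal import DateTimeTrigger
from airflow.sdk import BaseOperator, dag


class WaitAndResume(BaseOperator):
    def execute(self, context):
        # Defer on a DateTimeTrigger; the trigger kwargs and the kwargs below
        # are what end up BaseSerialization-encoded in the trigger / TI tables.
        self.defer(
            trigger=DateTimeTrigger(moment=datetime.now(timezone.utc) + timedelta(seconds=60)),
            method_name="execute_complete",
            kwargs={
                "message": "hello from deferred task",
                "numbers": [1, 2, 3, 4, 5],
                "timestamp": datetime.now(timezone.utc),
                "numbers_in_tuple": (6, 7, 8, 9, 10),
            },
        )

    def execute_complete(self, context, event=None, **kwargs):
        # Resumed by the triggerer; kwargs are the deserialized next_kwargs.
        self.log.info("resumed with event=%s kwargs=%s", event, kwargs)


@dag(schedule=None, start_date=datetime(2025, 1, 1), catchup=False)
def defer_serde_check():
    WaitAndResume(task_id="wait_and_resume")


defer_serde_check()
```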
1. Ran the DAG.
2. Checked the database (see the decode sketch after these steps):
Trigger table:
```json
{"__var": {"moment": {"__var": 1766584500.193306, "__type": "datetime"}, "end_from_trigger": false}, "__type": "dict"}
```
next_kwargs in TI table:
```json
{"__var": {"message": "hello from deferred task", "numbers": [1, 2, 3, 4, 5], "timestamp": {"__var": 1766584440.2096, "__type": "datetime"}, "numbers_in_tuple": {"__var": [6, 7, 8, 9, 10], "__type": "tuple"}}, "__type": "dict"}
```
3. Task deferred fine.
4. Task resumed and completed fine too.
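To double-check the round trip, the stored blob can also be decoded back by hand. A rough sketch, assuming a dev environment where `airflow.serialization` is importable (the blob is the trigger-table value from step 2):

```python
import json

from airflow.serialization.serialized_objects import BaseSerialization

# BaseSerialization-encoded blob copied from the trigger table above
blob = (
    '{"__var": {"moment": {"__var": 1766584500.193306, "__type": "datetime"},'
    ' "end_from_trigger": false}, "__type": "dict"}'
)

decoded = BaseSerialization.deserialize(json.loads(blob))
print(decoded)
# Expect a plain dict: a timezone-aware datetime under "moment" and
# end_from_trigger=False.
```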
<img width="2560" height="1261" alt="image"
src="https://github.com/user-attachments/assets/14d1d4ae-5ef9-478e-809d-7b892be503cc"
/>
<img width="2560" height="1261" alt="image"
src="https://github.com/user-attachments/assets/8e1a330d-7180-4386-a95b-1ef0cf2ce6d4"
/>
<img width="2560" height="1261" alt="image"
src="https://github.com/user-attachments/assets/f9d2f5a3-1b55-4438-8bcf-5f76117753a6"
/>
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]