amoghrajesh commented on code in PR #59711:
URL: https://github.com/apache/airflow/pull/59711#discussion_r2645754212


##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -674,13 +674,22 @@ def test_supervise_handles_deferred_task(
                 classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
                 next_method="execute_complete",
                 trigger_kwargs={
-                    "__type": "dict",
-                    "__var": {
-                        "moment": {"__type": "datetime", "__var": 1730982899.0},
-                        "end_from_trigger": False,
+                    "moment": {
+                        "__classname__": "pendulum.datetime.DateTime",
+                        "__version__": 2,
+                        "__data__": {
+                            "timestamp": 1730982899.0,
+                            "tz": {
+                                "__classname__": "builtins.tuple",
+                                "__version__": 1,
+                                "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
+                            },
+                        },
                     },
+                    "end_from_trigger": False,
                 },
-                next_kwargs={"__type": "dict", "__var": {}},
+                trigger_timeout=None,
+                next_kwargs={},

Review Comment:
   I was able to test it, and it works as I expect:
   
   (Also added to the PR description.)
   
   To test this with the same DAG as above, I undid all the changes to the Task SDK:
   
   ```diff
   Subject: [PATCH] Adding a fixture for creating connections
   ---
   Index: task-sdk/tests/task_sdk/execution_time/test_task_runner.py
   IDEA additional info:
   Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
   <+>UTF-8
   ===================================================================
   diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
   --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py     (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
   +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py     (date 1766584407206)
   @@ -363,30 +363,18 @@
        )
        time_machine.move_to(instant, tick=False)
    
   -    # Expected DeferTask, it is constructed by _defer_task from exception and is sent to supervisor
   +    # Expected DeferTask
        expected_defer_task = DeferTask(
            state="deferred",
            classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
   +        # Since we are in the task process here, we expect this to have not been encoded by serde yet
            trigger_kwargs={
   -            "moment": {
   -                "__classname__": "pendulum.datetime.DateTime",
   -                "__version__": 2,
   -                "__data__": {
   -                    "timestamp": 1732233603.0,
   -                    "tz": {
   -                        "__classname__": "builtins.tuple",
   -                        "__version__": 1,
   -                        "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
   -                    },
   -                },
   -            },
                "end_from_trigger": False,
   +            "moment": instant + timedelta(seconds=3),
            },
            trigger_timeout=None,
            next_method="execute_complete",
            next_kwargs={},
   -        rendered_map_index=None,
   -        type="DeferTask",
        )
    
        # Run the task
   Index: task-sdk/src/airflow/sdk/definitions/dag.py
   IDEA additional info:
   Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
   <+>UTF-8
   ===================================================================
   diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
   --- a/task-sdk/src/airflow/sdk/definitions/dag.py    (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
   +++ b/task-sdk/src/airflow/sdk/definitions/dag.py    (date 1766584407199)
   @@ -1428,7 +1428,6 @@
                ti.task = create_scheduler_operator(taskrun_result.ti.task)
    
                if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
   -                from airflow.sdk.serde import deserialize, serialize
                    from airflow.utils.session import create_session
    
                    # API Server expects the task instance to be in QUEUED state before
   @@ -1436,12 +1435,10 @@
                    ti.set_state(TaskInstanceState.QUEUED)
    
                    log.info("[DAG TEST] running trigger in line")
   -                # trigger_kwargs need to be deserialized before passing to the trigger class since they are in serde encoded format
   -                kwargs = deserialize(msg.trigger_kwargs)  # type: ignore[type-var]  # mypy doesn't like passing JsonValue | str to deserialize but its correct
   -                trigger = import_string(msg.classpath)(**kwargs)
   +                trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
                    event = _run_inline_trigger(trigger, task_sdk_ti)
                    ti.next_method = msg.next_method
   -                ti.next_kwargs = {"event": serialize(event.payload)} if event else msg.next_kwargs
   +                ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
                    log.info("[DAG TEST] Trigger completed")
    
                    # Set the state to SCHEDULED so that the task can be resumed.
   Index: task-sdk/src/airflow/sdk/execution_time/task_runner.py
   IDEA additional info:
   Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
   <+>UTF-8
   ===================================================================
   diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
   --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
   +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py (date 1766584407200)
   @@ -962,19 +962,12 @@
        log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
        classpath, trigger_kwargs = defer.trigger.serialize()
    
   -    from typing import cast
   -
   -    from airflow.sdk.serde import serialize as serde_serialize
   -
   -    trigger_kwargs = cast("JsonValue", serde_serialize(trigger_kwargs))
   -    next_kwargs = cast("JsonValue", serde_serialize(defer.kwargs or {}))
   -
        msg = DeferTask(
            classpath=classpath,
            trigger_kwargs=trigger_kwargs,
            trigger_timeout=defer.timeout,
            next_method=defer.method_name,
   -        next_kwargs=next_kwargs,
   +        next_kwargs=defer.kwargs or {},
        )
        state = TaskInstanceState.DEFERRED
    
   @@ -1382,20 +1375,10 @@
        execute = task.execute
    
        if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
   -        from airflow.sdk.serde import deserialize
   -
   -        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
   -        try:
   -            if TYPE_CHECKING:
   -                assert isinstance(next_kwargs_data, dict)
   -            kwargs = deserialize(next_kwargs_data)
   -        except (ImportError, KeyError, AttributeError, TypeError):
   -            from airflow.serialization.serialized_objects import BaseSerialization
   +        from airflow.serialization.serialized_objects import BaseSerialization
    
   -            kwargs = BaseSerialization.deserialize(next_kwargs_data)
   +        kwargs = BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})
    
   -        if TYPE_CHECKING:
   -            assert isinstance(kwargs, dict)
            execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)
    
        ctx = contextvars.copy_context()
   Index: task-sdk/tests/task_sdk/execution_time/test_supervisor.py
   IDEA additional info:
   Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
   <+>UTF-8
   ===================================================================
   diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
   --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py      (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
   +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py      (date 1766584407200)
   @@ -674,22 +674,13 @@
                    classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
                    next_method="execute_complete",
                    trigger_kwargs={
   -                    "moment": {
   -                        "__classname__": "pendulum.datetime.DateTime",
   -                        "__version__": 2,
   -                        "__data__": {
   -                            "timestamp": 1730982899.0,
   -                            "tz": {
   -                                "__classname__": "builtins.tuple",
   -                                "__version__": 1,
   -                                "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
   -                            },
   -                        },
   -                    },
   -                    "end_from_trigger": False,
   +                    "__type": "dict",
   +                    "__var": {
   +                        "moment": {"__type": "datetime", "__var": 1730982899.0},
   +                        "end_from_trigger": False,
   +                    },
                    },
   -                trigger_timeout=None,
   -                next_kwargs={},
   +                next_kwargs={"__type": "dict", "__var": {}},
                ),
            )
    
   Index: task-sdk/src/airflow/sdk/execution_time/comms.py
   IDEA additional info:
   Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
   <+>UTF-8
   ===================================================================
   diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
   --- a/task-sdk/src/airflow/sdk/execution_time/comms.py       (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
   +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py       (date 1766584407199)
   @@ -60,7 +60,7 @@
    import attrs
    import msgspec
    import structlog
   -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
   +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer
    
    from airflow.sdk.api.datamodels._generated import (
        AssetEventDagRunReference,
   @@ -705,6 +705,19 @@
    
        type: Literal["DeferTask"] = "DeferTask"
    
   +    @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
   +    def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
   +        from airflow.serialization.serialized_objects import BaseSerialization
   +
   +        if not isinstance(val, dict):
   +            # None, or an encrypted string
   +            return val
   +
   +        if val.keys() == {"__type", "__var"}:
   +            # Already encoded.
   +            return val
   +        return BaseSerialization.serialize(val or {})
   +
    
    class RetryTask(TIRetryStatePayload):
        """Update a task instance state to up_for_retry."""
   Index: task-sdk/src/airflow/sdk/api/datamodels/_generated.py
   IDEA additional info:
   Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
   <+>UTF-8
   ===================================================================
   diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
   --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py  (revision c73862f948cae4890d8506782f2d1a6dc072fa44)
   +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py  (date 1766584407198)
   @@ -27,7 +27,7 @@
    
    from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel
    
   -API_VERSION: Final[str] = "2026-03-31"
   +API_VERSION: Final[str] = "2025-12-08"
    
    
    class AssetAliasReferenceAssetEventDagRun(BaseModel):
   @@ -181,10 +181,10 @@
        )
        state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred"
        classpath: Annotated[str, Field(title="Classpath")]
   -    trigger_kwargs: Annotated[dict[str, JsonValue] | str | None, Field(title="Trigger Kwargs")] = None
   +    trigger_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Trigger Kwargs")] = None
        trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None
        next_method: Annotated[str, Field(title="Next Method")]
   -    next_kwargs: Annotated[dict[str, JsonValue] | None, Field(title="Next Kwargs")] = None
   +    next_kwargs: Annotated[dict[str, Any] | None, Field(title="Next Kwargs")] = None
        rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
   ```
   
   With this patch applied, I am certain that the data sent to the API server is now `BaseSerialization`-encoded.
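
For readers skimming the patch, the guard in `_serde_kwarg_fields` can be sketched in isolation. This is only an illustration, not the actual Airflow code: `toy_encode` and `encode_kwargs_once` are hypothetical names, and `toy_encode` is a stand-in that wraps the dict without encoding nested values the way `BaseSerialization.serialize` would.

```python
from typing import Any, Callable, Optional, Union

# The two keys that mark a value as already wrapped in the envelope.
ENVELOPE_KEYS = {"__type", "__var"}

KwargsValue = Optional[Union[str, dict]]


def toy_encode(val: dict) -> dict:
    """Hypothetical stand-in for the real encoder: wraps the dict, nothing more."""
    return {"__type": "dict", "__var": val}


def encode_kwargs_once(
    val: KwargsValue,
    encode: Callable[[dict], Any] = toy_encode,
) -> Any:
    """Mirror the guard from the patch: encode plain dicts exactly once."""
    if not isinstance(val, dict):
        # None, or an encrypted string: pass through untouched.
        return val
    if val.keys() == ENVELOPE_KEYS:
        # Already wrapped in the envelope, so do not wrap it again.
        return val
    return encode(val)


plain = {"end_from_trigger": False}
wrapped = encode_kwargs_once(plain)
wrapped_again = encode_kwargs_once(wrapped)
```

The point of the second check is idempotence: `wrapped_again` is the same object as `wrapped`, so serializing the message more than once does not nest envelopes. One caveat of a key-based check like this is that a user dict whose only keys happen to be `__type` and `__var` would be treated as already encoded.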
   
   1. Ran the DAG
   
   2. Checked the database:
   
   Trigger table:
   ```json
   {"__var": {"moment": {"__var": 1766584500.193306, "__type": "datetime"}, "end_from_trigger": false}, "__type": "dict"}
   ```
   
   next_kwargs in TI table:
   
   ```json
   {"__var": {"message": "hello from deferred task", "numbers": [1, 2, 3, 4, 5], "timestamp": {"__var": 1766584440.2096, "__type": "datetime"}, "numbers_in_tuple": {"__var": [6, 7, 8, 9, 10], "__type": "tuple"}}, "__type": "dict"}
   ```
   
   3. The task deferred without issues.
   4. The task resumed and ran fine as well.
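
To make the stored format concrete, here is a rough sketch of how the `next_kwargs` value from the TI table unwraps back into Python objects. `toy_decode` is a hypothetical helper written for this comment, not `BaseSerialization.deserialize`; it only handles the three `__type` tags that appear in the rows above, and it assumes the `datetime` payload is a POSIX timestamp read as UTC.

```python
import json
from datetime import datetime, timezone
from typing import Any


def toy_decode(node: Any) -> Any:
    """Recursively unwrap {"__type": ..., "__var": ...} envelopes (illustrative only)."""
    if isinstance(node, list):
        return [toy_decode(item) for item in node]
    if not isinstance(node, dict):
        return node
    if node.keys() != {"__type", "__var"}:
        # A plain mapping: decode its values and keep the keys.
        return {key: toy_decode(value) for key, value in node.items()}
    kind, payload = node["__type"], node["__var"]
    if kind == "dict":
        return {key: toy_decode(value) for key, value in payload.items()}
    if kind == "tuple":
        return tuple(toy_decode(item) for item in payload)
    if kind == "datetime":
        # Assumption: the float is a POSIX timestamp, interpreted here as UTC.
        return datetime.fromtimestamp(payload, tz=timezone.utc)
    raise ValueError(f"unsupported __type: {kind!r}")


# The next_kwargs row copied from the TI table above.
row = json.loads(
    '{"__var": {"message": "hello from deferred task", "numbers": [1, 2, 3, 4, 5], '
    '"timestamp": {"__var": 1766584440.2096, "__type": "datetime"}, '
    '"numbers_in_tuple": {"__var": [6, 7, 8, 9, 10], "__type": "tuple"}}, "__type": "dict"}'
)
kwargs = toy_decode(row)
```

With this sketch, `kwargs["numbers"]` stays a list while `kwargs["numbers_in_tuple"]` comes back as a tuple, which is the distinction the envelope exists to preserve across the JSON round trip.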
   
   <img width="2560" height="1261" alt="image" src="https://github.com/user-attachments/assets/14d1d4ae-5ef9-478e-809d-7b892be503cc" />
   
   <img width="2560" height="1261" alt="image" src="https://github.com/user-attachments/assets/8e1a330d-7180-4386-a95b-1ef0cf2ce6d4" />
   
   <img width="2560" height="1261" alt="image" src="https://github.com/user-attachments/assets/f9d2f5a3-1b55-4438-8bcf-5f76117753a6" />



