Dev-iL commented on code in PR #57952:
URL: https://github.com/apache/airflow/pull/57952#discussion_r2533060136


##########
task-sdk/src/airflow/sdk/definitions/dag.py:
##########
@@ -1414,7 +1415,15 @@ def _run_task(
                 trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
                 event = _run_inline_trigger(trigger, task_sdk_ti)
                 ti.next_method = msg.next_method
-                ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
+
+                # Deserialize next_kwargs if it's a string (encrypted dict), similar to what the API server does
+                next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
+                if isinstance(next_kwargs_value, str):
+                    from airflow.serialization.serialized_objects import BaseSerialization
+
+                    ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
+                else:
+                    ti.next_kwargs = next_kwargs_value

Review Comment:
   mypy has an issue with the `| str |` case of 
   ```python
   # airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#331
       next_kwargs: dict[str, Any] | str | None = None
       """
       Args to pass to ``next_method``.
   
       Can either be a "decorated" dict, or a string encrypted with the shared Fernet key.
       """
   ```
   not the dict case. [TP remarked](https://github.com/apache/airflow/pull/57952#discussion_r2512591103) that perhaps this should never be a dict at all.
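
For reference, the mismatch can be reproduced without Airflow. This is only a minimal sketch: `narrow_next_kwargs` is a hypothetical helper, not something in the codebase, and raising on the `str` branch is a placeholder, since what to do with that case is exactly the open question here.

```python
from __future__ import annotations

from typing import Any


def narrow_next_kwargs(value: dict[str, Any] | str | None) -> dict[str, Any] | None:
    """Narrow the wider message-side union to the dict-or-None the attribute expects."""
    if value is None or isinstance(value, dict):
        # A type checker narrows `value` to `dict[str, Any] | None` in this branch.
        return value
    # Only `str` remains here. Raising is a hypothetical choice for this sketch,
    # not the behavior decided on in the PR.
    raise TypeError("string next_kwargs (possibly encrypted) is not handled in this sketch")
```

Assigning `msg.next_kwargs` directly skips this narrowing, so the checker still sees the `str` member of the union and complains.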
   
   ----
   I let the LLM analyze it some more and it came up with this:
   ```md
   ## The Issue
   The original mypy error occurred because `ti.next_kwargs` expects `dict | None`, but `msg.next_kwargs` from `DeferTask` can be `dict | str | None` (where the `str` represents an encrypted value).
   
   ## The Solution
   Update the deserialization logic in `dag.py` to properly handle all cases:
   
   - None → assign `None` directly
   - Serialized dict (with `__type` and `__var` keys) → call `BaseSerialization.deserialize()`
   - Regular dict → use as-is
   - Encrypted string → log warning and assign `None` (encryption not supported in `dag.test()` context)
   
   ## Key Discovery
   `BaseSerialization.serialize()` returns a dict with special keys (not a JSON string), and `deserialize()` expects this dict format. The "encrypted string" case mentioned in comments refers to when Fernet encryption is enabled, which doesn't apply to `dag.test()`.
   ```
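
The four cases in that analysis can be told apart with plain type and key checks. A rough, Airflow-free sketch follows; `classify_next_kwargs` and its labels are made up for illustration, and the `__type`/`__var` key set is taken from the analysis above rather than verified against `BaseSerialization`.

```python
from typing import Any

# Key set assumed from the analysis above, not checked against Airflow itself.
ENVELOPE_KEYS = frozenset({"__type", "__var"})


def classify_next_kwargs(value: Any) -> str:
    """Label a next_kwargs value with the case it falls into (hypothetical helper)."""
    if value is None:
        return "none"
    if isinstance(value, dict):
        # A dict whose keys are exactly __type and __var is treated as serialized.
        if set(value) == ENVELOPE_KEYS:
            return "serialized"
        return "plain"
    if isinstance(value, str):
        # Per the data model docstring, a string is a Fernet-encrypted payload.
        return "encrypted"
    raise TypeError(f"Unexpected type for next_kwargs: {type(value)}")
```

Note that an ordinary dict that happens to carry exactly these two keys would be misclassified as serialized, which is one reason the key check alone may not be an acceptable fix.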
   
   <details><summary>So now it thinks the code should instead look like this:</summary>
   <p>
   
   ```python
   # Deserialize next_kwargs if needed, matching what the API server does
   next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
   if next_kwargs_value is None:
       ti.next_kwargs = None
   elif isinstance(next_kwargs_value, dict):
       if set(next_kwargs_value.keys()) == {"__type", "__var"}:
           # Serialized format - deserialize it
           from airflow.serialization.serialized_objects import BaseSerialization
   
           ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
       else:
           # Regular dict - use as-is
           ti.next_kwargs = next_kwargs_value
   else:
       # String (encrypted) - in dag.test() context, encryption is not used,
       # but we need to handle this for type checking
       # The API server would decrypt this before calling deserialize
       if not isinstance(next_kwargs_value, str):
           raise TypeError(f"Unexpected type for next_kwargs: {type(next_kwargs_value)}")
       # For now, we can't decrypt without the Fernet key, so log a warning
       log.warning(
           "[DAG TEST] Received encrypted next_kwargs string, cannot decrypt in dag.test() context"
       )
       # Type-ignore needed because we can't decrypt in dag.test()
       ti.next_kwargs = None  # type: ignore[assignment]
   ```
   
   </p>
   </details> 
   
   <details><summary>...with a couple of accompanying unit tests that illustrate the different behavior of the existing and proposed implementations:</summary>
   <p>
   
   ```python
   class TestDeferredTaskNextKwargs:
       """Test that next_kwargs is properly deserialized when it's an encrypted string."""
   
       @pytest.fixture
       def mock_task_instance(self, mocker):
           """Create a mock scheduler TaskInstance."""
           from airflow.sdk import TaskInstanceState
   
           ti = mocker.MagicMock()
           ti.task_id = "test_task"
           ti.dag_id = "test_dag"
           ti.run_id = "test_run"
           ti.map_index = -1
           ti.id = "123e4567-e89b-12d3-a456-426614174000"
           ti.dag_version_id = "223e4567-e89b-12d3-a456-426614174000"
           ti.try_number = 1
           ti.state = TaskInstanceState.DEFERRED
           return ti
   
       @pytest.fixture
       def mock_task(self, mocker):
           """Create a mock task."""
   
           task = mocker.MagicMock()
           task.task_id = "test_task"
           return task
   
       def test_next_kwargs_deserialized_when_encrypted_string(
           self, mock_task_instance, mock_task, monkeypatch, mocker,
       ):
           """
           Test that when msg.next_kwargs is a serialized string, it gets properly deserialized to a dict.
   
           This simulates the case where DeferTask contains an encrypted/serialized next_kwargs.
           The bug we're testing: without the fix, string next_kwargs would be assigned directly to ti.next_kwargs,
           causing a type error. With the fix, it should be deserialized to a dict.
           """
           from airflow.sdk.definitions.dag import _run_task
           from airflow.sdk.execution_time.comms import DeferTask
           from airflow.serialization.serialized_objects import BaseSerialization
   
           # Setup: Create a dict that will be serialized (simulating real behavior)
           original_kwargs = {"custom_param": "value", "count": 42}
           # Serialize it to get the format with __type and __var keys
           serialized_kwargs = BaseSerialization.serialize(original_kwargs)
   
           # serialized_kwargs should now be a dict like:
           # {'__type': 'dict', '__var': {'custom_param': 'value', 'count': 42}}
   
           mock_defer_task = DeferTask(
               classpath="airflow.triggers.base.BaseTrigger",
               trigger_kwargs={},
               next_method="execute_complete",
               next_kwargs=serialized_kwargs,
           )
   
           # Create a mock TaskRunResult with the DeferTask message
           mock_task_run_result = mocker.MagicMock()
           mock_task_run_result.msg = mock_defer_task
           mock_task_run_result.ti.state = "deferred"
           mock_task_run_result.ti.task = mock_task
   
           # Mock the run_task_in_process to return our prepared result
           mock_run_task = mocker.MagicMock(return_value=mock_task_run_result)
           monkeypatch.setattr(
               "airflow.sdk.execution_time.supervisor.run_task_in_process",
               mock_run_task,
               raising=False,
           )
   
           # Mock _run_inline_trigger to return no event (so msg.next_kwargs is used)
           mock_inline_trigger = mocker.MagicMock(return_value=None)
           monkeypatch.setattr(
               "airflow.sdk.definitions.dag._run_inline_trigger",
               mock_inline_trigger,
               raising=False,
           )
   
           # Mock create_scheduler_operator
           monkeypatch.setattr(
               "airflow.serialization.serialized_objects.create_scheduler_operator",
               mocker.MagicMock(return_value=mock_task),
               raising=False,
           )
   
           # Mock import_string to return a mock trigger
           mock_trigger = mocker.MagicMock()
           monkeypatch.setattr(
               "airflow.sdk.module_loading.import_string",
               mocker.MagicMock(return_value=lambda **kwargs: mock_trigger),
               raising=False,
           )
   
           # Mock create_session
           mock_session = mocker.MagicMock()
           mock_create_session = mocker.MagicMock()
           mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
           mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
           monkeypatch.setattr(
               "airflow.utils.session.create_session",
               mocker.MagicMock(return_value=mock_create_session),
               raising=False,
           )
   
           # Track assignments to next_kwargs using a property descriptor
           assigned_values = []
           original_next_kwargs = mock_task_instance.next_kwargs
   
           def track_next_kwargs_setter(value):
               assigned_values.append(value)
               # Update the mock's return value for the property
               type(mock_task_instance).next_kwargs = mocker.PropertyMock(return_value=value)
   
           # Replace next_kwargs with a property that tracks assignments
           type(mock_task_instance).next_kwargs = mocker.PropertyMock(
               return_value=original_next_kwargs,
               side_effect=lambda: assigned_values[-1] if assigned_values else original_next_kwargs
           )
           mocker.patch.object(
               type(mock_task_instance),
               'next_kwargs',
               new_callable=mocker.PropertyMock,
               return_value=original_next_kwargs
           )
   
           # We need to use __setattr__ to capture the assignment
           original_setattr = type(mock_task_instance).__setattr__
   
           def capturing_setattr(obj, name, value):
               if name == 'next_kwargs':
                   assigned_values.append(value)
                   # Store it in the mock's internal state
                   object.__setattr__(obj, '_next_kwargs_value', value)
               else:
                   original_setattr(obj, name, value)
   
           type(mock_task_instance).__setattr__ = capturing_setattr
   
           # Execute _run_task with run_triggerer=True to trigger the deferred path
           _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)
   
           # Restore original behavior
           type(mock_task_instance).__setattr__ = original_setattr
   
           # Verify: The key behavior is that next_kwargs is now a dict, not a string
           # This is what the fix ensures - regardless of HOW it's done
           assert len(assigned_values) > 0, "next_kwargs should have been assigned at least once"
           final_value = assigned_values[-1]  # Get the last assigned value
   
           assert isinstance(final_value, dict), (
               f"next_kwargs should be a dict after deserialization, not {type(final_value).__name__}"
           )
           # Verify the dict contains the expected keys (proves it was properly deserialized)
           assert "custom_param" in final_value
           assert "count" in final_value
   
       def test_next_kwargs_with_trigger_event(self, mock_task_instance, mock_task, monkeypatch, mocker):
           """
           Test that when a trigger returns an event, event.payload is used for next_kwargs.
   
           This verifies that the event path takes precedence over msg.next_kwargs.
           """
           from airflow.sdk.definitions.dag import _run_task
           from airflow.sdk.execution_time.comms import DeferTask
   
           # Create a mock event with payload
           mock_event = mocker.MagicMock()
           mock_event.payload = {"event_data": "from_trigger", "timestamp": "2024-01-01"}
   
           # Create mock DeferTask with a string next_kwargs that should be ignored
           mock_defer_task = DeferTask(
               classpath="airflow.triggers.base.BaseTrigger",
               trigger_kwargs={},
               next_method="execute_complete",
               next_kwargs='{"should": "be_ignored"}',  # Should be ignored since we have an event
           )
   
           # Create mock TaskRunResult
           mock_task_run_result = mocker.MagicMock()
           mock_task_run_result.msg = mock_defer_task
           mock_task_run_result.ti.state = "deferred"
           mock_task_run_result.ti.task = mock_task
   
           # Mock dependencies
           monkeypatch.setattr(
               "airflow.sdk.execution_time.supervisor.run_task_in_process",
               mocker.MagicMock(return_value=mock_task_run_result),
               raising=False,
           )
           monkeypatch.setattr(
               "airflow.sdk.definitions.dag._run_inline_trigger",
               mocker.MagicMock(return_value=mock_event),
               raising=False,
           )
           monkeypatch.setattr(
               "airflow.serialization.serialized_objects.create_scheduler_operator",
               mocker.MagicMock(return_value=mock_task),
               raising=False,
           )
           monkeypatch.setattr(
               "airflow.sdk.module_loading.import_string",
               mocker.MagicMock(return_value=lambda **kwargs: mocker.MagicMock()),
               raising=False,
           )
   
           mock_session = mocker.MagicMock()
           mock_create_session = mocker.MagicMock()
           mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
           mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
           monkeypatch.setattr(
               "airflow.utils.session.create_session",
               mocker.MagicMock(return_value=mock_create_session),
               raising=False,
           )
   
           # Execute
           _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)
   
           # Verify: next_kwargs should be set to {"event": event.payload}, not msg.next_kwargs
           assert mock_task_instance.next_kwargs == {"event": mock_event.payload}
           assert isinstance(mock_task_instance.next_kwargs, dict)
           # Ensure msg.next_kwargs was not used
           assert mock_task_instance.next_kwargs.get("should") != "be_ignored"
   ```
   
   </p>
   </details> 
   
   -------
   
   I'll revert my change to this bit since I don't see how to solve this in an acceptable way.


