Dev-iL commented on code in PR #57952:
URL: https://github.com/apache/airflow/pull/57952#discussion_r2533060136
##########
task-sdk/src/airflow/sdk/definitions/dag.py:
##########
@@ -1414,7 +1415,15 @@ def _run_task(
trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
event = _run_inline_trigger(trigger, task_sdk_ti)
ti.next_method = msg.next_method
- ti.next_kwargs = {"event": event.payload} if event else
msg.next_kwargs
+
+ # Deserialize next_kwargs if it's a string (encrypted dict),
similar to what the API server does
+ next_kwargs_value = {"event": event.payload} if event else
msg.next_kwargs
+ if isinstance(next_kwargs_value, str):
+ from airflow.serialization.serialized_objects import
BaseSerialization
+
+ ti.next_kwargs =
BaseSerialization.deserialize(next_kwargs_value)
+ else:
+ ti.next_kwargs = next_kwargs_value
Review Comment:
mypy complains about the `str` member of the `next_kwargs` union in
```python
#
airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#331
next_kwargs: dict[str, Any] | str | None = None
"""
Args to pass to ``next_method``.
Can either be a "decorated" dict, or a string encrypted with the shared
Fernet key.
"""
```
not about the `dict` member. [TP remarked](https://github.com/apache/airflow/pull/57952#discussion_r2512591103) that perhaps this should never be a dict at all.
----
I had an LLM analyze it further, and it came up with this:
```md
## The Issue
The original mypy error occurred because `ti.next_kwargs` expects `dict | None`, but `msg.next_kwargs` from `DeferTask` can be `dict | str | None` (where the `str` represents an encrypted value).
## The Solution
Update the deserialization logic in `dag.py` to properly handle all cases:
- None → assign `None` directly
- Serialized dict (with `__type` and `__var` keys) → call
`BaseSerialization.deserialize()`
- Regular dict → use as-is
- Encrypted string → log warning and assign `None` (encryption not supported
in `dag.test()` context)
## Key Discovery
`BaseSerialization.serialize()` returns a dict with special keys (not a JSON
string), and `deserialize()` expects this dict format. The "encrypted string"
case mentioned in comments refers to when Fernet encryption is enabled, which
doesn't apply to `dag.test()`.
```
<details><summary>So now it thinks the code should instead look like
this:</summary>
<p>
```python
# Normalize next_kwargs into what ``ti.next_kwargs`` accepts (``dict | None``),
# matching what the API server does. ``msg.next_kwargs`` comes from ``DeferTask``
# and may be ``dict | str | None``, where a ``str`` is a Fernet-encrypted value.
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if next_kwargs_value is None:
    ti.next_kwargs = None
elif isinstance(next_kwargs_value, dict):
    if next_kwargs_value.keys() == {"__type", "__var"}:
        # BaseSerialization envelope: decode it back into the original value.
        from airflow.serialization.serialized_objects import BaseSerialization

        ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
    else:
        # Plain dict (e.g. {"event": payload}): use as-is.
        ti.next_kwargs = next_kwargs_value
elif isinstance(next_kwargs_value, str):
    # Encrypted string: the API server would decrypt this before deserializing,
    # but without the Fernet key dag.test() cannot, so the kwargs are dropped.
    log.warning(
        "[DAG TEST] Received encrypted next_kwargs string, cannot decrypt in dag.test() context"
    )
    # None is a valid value for ti.next_kwargs, so no ``type: ignore`` is needed.
    ti.next_kwargs = None
else:
    raise TypeError(f"Unexpected type for next_kwargs: {type(next_kwargs_value)}")
```
</p>
</details>
<details><summary>...with a couple of accompanying unit tests that
illustrate the different behavior of the existing and proposed
implementations:</summary>
<p>
```python
class TestDeferredTaskNextKwargs:
    """Tests for how ``_run_task`` fills ``ti.next_kwargs`` after an inline-triggered deferral."""

    @pytest.fixture
    def mock_task_instance(self, mocker):
        """Create a mock scheduler TaskInstance in the DEFERRED state."""
        from airflow.sdk import TaskInstanceState

        ti = mocker.MagicMock()
        ti.task_id = "test_task"
        ti.dag_id = "test_dag"
        ti.run_id = "test_run"
        ti.map_index = -1
        ti.id = "123e4567-e89b-12d3-a456-426614174000"
        ti.dag_version_id = "223e4567-e89b-12d3-a456-426614174000"
        ti.try_number = 1
        ti.state = TaskInstanceState.DEFERRED
        return ti

    @pytest.fixture
    def mock_task(self, mocker):
        """Create a mock task."""
        task = mocker.MagicMock()
        task.task_id = "test_task"
        return task

    @staticmethod
    def _patch_run_task_dependencies(monkeypatch, mocker, *, task, defer_task, event):
        """
        Patch the collaborators ``_run_task`` uses so it takes the deferral path.

        :param task: task returned by the mocked run and by ``create_scheduler_operator``.
        :param defer_task: the ``DeferTask`` message the mocked task run reports.
        :param event: value returned by the mocked ``_run_inline_trigger`` (``None`` = no event).
        """
        task_run_result = mocker.MagicMock()
        task_run_result.msg = defer_task
        task_run_result.ti.state = "deferred"
        task_run_result.ti.task = task

        session_cm = mocker.MagicMock()
        session_cm.__enter__ = mocker.MagicMock(return_value=mocker.MagicMock())
        session_cm.__exit__ = mocker.MagicMock(return_value=False)

        replacements = {
            "airflow.sdk.execution_time.supervisor.run_task_in_process": mocker.MagicMock(
                return_value=task_run_result
            ),
            "airflow.sdk.definitions.dag._run_inline_trigger": mocker.MagicMock(return_value=event),
            "airflow.serialization.serialized_objects.create_scheduler_operator": mocker.MagicMock(
                return_value=task
            ),
            # The trigger instance itself is never used because _run_inline_trigger is mocked.
            "airflow.sdk.module_loading.import_string": mocker.MagicMock(
                return_value=lambda **kwargs: mocker.MagicMock()
            ),
            "airflow.utils.session.create_session": mocker.MagicMock(return_value=session_cm),
        }
        # monkeypatch undoes every patch at teardown, even if the test body raises.
        for target, replacement in replacements.items():
            monkeypatch.setattr(target, replacement, raising=False)

    def test_next_kwargs_deserialized_when_encrypted_string(
        self, mock_task_instance, mock_task, monkeypatch, mocker
    ):
        """
        A BaseSerialization-encoded ``msg.next_kwargs`` must be decoded back into the original dict.

        Despite the test name, the input is not a string: ``BaseSerialization.serialize()``
        returns a dict envelope such as ``{"__type": "dict", "__var": {...}}``, and that is
        what ``DeferTask`` carries here. The trigger is mocked to return no event, so
        ``msg.next_kwargs`` (not an event payload) is what gets assigned.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask
        from airflow.serialization.serialized_objects import BaseSerialization

        original_kwargs = {"custom_param": "value", "count": 42}
        serialized_kwargs = BaseSerialization.serialize(original_kwargs)

        defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs=serialized_kwargs,
        )
        self._patch_run_task_dependencies(
            monkeypatch, mocker, task=mock_task, defer_task=defer_task, event=None
        )

        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # MagicMock stores plain attribute assignments, so this is the last value _run_task
        # assigned. Had it never been assigned, it would be an auto-created child MagicMock.
        final_value = mock_task_instance.next_kwargs
        assert isinstance(final_value, dict), (
            f"next_kwargs should be a dict after deserialization, not {type(final_value).__name__}"
        )
        assert final_value == original_kwargs

    def test_next_kwargs_with_trigger_event(self, mock_task_instance, mock_task, monkeypatch, mocker):
        """
        When the trigger returns an event, ``next_kwargs`` must be ``{"event": event.payload}``.

        ``msg.next_kwargs`` is deliberately set to a value that must not leak through,
        verifying that the event path takes precedence over it.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask

        event = mocker.MagicMock()
        event.payload = {"event_data": "from_trigger", "timestamp": "2024-01-01"}

        defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs='{"should": "be_ignored"}',  # must be ignored since an event is returned
        )
        self._patch_run_task_dependencies(
            monkeypatch, mocker, task=mock_task, defer_task=defer_task, event=event
        )

        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Exact equality with a plain dict also rules out msg.next_kwargs having been used.
        assert mock_task_instance.next_kwargs == {"event": event.payload}
```
</p>
</details>
-------
I'll revert my change to this bit since I don't see how to solve this in an
acceptable way.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]