This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 69b6c54ef17 Do not deserialize trigger_kwargs when loading serialized
DAGs (#66002)
69b6c54ef17 is described below
commit 69b6c54ef17b048ebb0e9a656d0d9e1c99480d6b
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 7 16:12:52 2026 +0530
Do not deserialize trigger_kwargs when loading serialized DAGs (#66002)
---
airflow-core/src/airflow/models/taskinstance.py | 3 +-
airflow-core/src/airflow/models/trigger.py | 3 +-
airflow-core/src/airflow/serialization/enums.py | 26 +++++++
.../airflow/serialization/serialized_objects.py | 10 +--
airflow-core/src/airflow/triggers/base.py | 17 +++--
airflow-core/tests/unit/jobs/test_triggerer_job.py | 10 +--
airflow-core/tests/unit/models/test_dagrun.py | 82 ++++++++++++++++++++++
.../unit/serialization/test_dag_serialization.py | 52 ++++++++++++++
8 files changed, 180 insertions(+), 23 deletions(-)
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index e5b19f2768d..ba0b75747c4 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -88,6 +88,7 @@ from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComSelectSequence,
XComModel
+from airflow.serialization.enums import stringify_encoding_keys
from airflow.settings import task_instance_mutation_hook
from airflow.task.priority_strategy import
validate_and_load_priority_weight_strategy
from airflow.ti_deps.dep_context import DepContext
@@ -1691,7 +1692,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
self.state = TaskInstanceState.DEFERRED
self.trigger_id = trigger_row.id
self.next_method = start_trigger_args.next_method
- self.next_kwargs = start_trigger_args.next_kwargs or {}
+ self.next_kwargs =
stringify_encoding_keys(start_trigger_args.next_kwargs or {})
self.start_date = timezone.utcnow()
# If an execution_timeout is set, set the timeout to the minimum of
diff --git a/airflow-core/src/airflow/models/trigger.py
b/airflow-core/src/airflow/models/trigger.py
index 2949262532e..a6ee583032a 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -35,6 +35,7 @@ from airflow.configuration import conf
from airflow.models.asset import AssetWatcherModel
from airflow.models.base import Base
from airflow.models.taskinstance import TaskInstance
+from airflow.serialization.enums import stringify_encoding_keys as
_stringify_encoding_keys
from airflow.triggers.base import BaseTaskEndEvent
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
@@ -146,7 +147,7 @@ class Trigger(Base):
from airflow.models.crypto import get_fernet
from airflow.sdk.serde import serialize
- serialized_kwargs = serialize(kwargs)
+ serialized_kwargs = serialize(_stringify_encoding_keys(kwargs))
return
get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")
@staticmethod
diff --git a/airflow-core/src/airflow/serialization/enums.py
b/airflow-core/src/airflow/serialization/enums.py
index 5fdd4d66987..ae4c1249cab 100644
--- a/airflow-core/src/airflow/serialization/enums.py
+++ b/airflow-core/src/airflow/serialization/enums.py
@@ -20,6 +20,7 @@
from __future__ import annotations
from enum import Enum, unique
+from typing import Any
# Fields of an encoded object in serialization.
@@ -31,6 +32,31 @@ class Encoding(str, Enum):
VAR = "__var"
+def stringify_encoding_keys(d: Any) -> Any:
+ """
+ Convert BaseSerialization Encoding enum keys to their string values
recursively.
+
+ Python 3.10 compatibility: str(Encoding.TYPE) returns "Encoding.TYPE" on
3.10
+ instead of "__type" (3.10 is still the default CI target).
serde.serialize
+ uses str(k) for dict keys, so without this conversion the encrypted blob
ends up
+ with "Encoding.TYPE" keys that neither serde._convert nor the
BaseSerialization
+ fallback can read back.
+ """
+ if isinstance(d, dict):
+ return {
+ (k.value if isinstance(k, Encoding) else str(k)):
stringify_encoding_keys(v) for k, v in d.items()
+ }
+ if isinstance(d, list):
+ return [stringify_encoding_keys(i) for i in d]
+ if isinstance(d, tuple):
+ converted = [stringify_encoding_keys(i) for i in d]
+ # namedtuples require positional args, not a single list argument
+ if hasattr(d, "_fields"):
+ return type(d)(*converted)
+ return tuple(converted)
+ return d
+
+
# Supported types for encoding. primitives and list are not encoded.
@unique
class DagAttributeTypes(str, Enum):
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index a0295dc2370..7764323fca7 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -251,17 +251,11 @@ def _encode_start_trigger_args(var: StartTriggerArgs) ->
dict[str, Any]:
def _decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
"""Decode a StartTriggerArgs."""
-
- def deserialize_kwargs(key: str) -> Any:
- if (val := var[key]) is None:
- return None
- return BaseSerialization.deserialize(val)
-
return StartTriggerArgs(
trigger_cls=var["trigger_cls"],
- trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
+ trigger_kwargs=var["trigger_kwargs"],
next_method=var["next_method"],
- next_kwargs=deserialize_kwargs("next_kwargs"),
+ next_kwargs=var["next_kwargs"],
timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"]
else None,
)
diff --git a/airflow-core/src/airflow/triggers/base.py
b/airflow-core/src/airflow/triggers/base.py
index b37448edfea..f39b62facf7 100644
--- a/airflow-core/src/airflow/triggers/base.py
+++ b/airflow-core/src/airflow/triggers/base.py
@@ -120,9 +120,16 @@ class BaseTrigger(abc.ABC, Templater, LoggingMixin):
# does not build a template context, so render_template_fields is
# never called and empty template_fields is safe.
start_trigger_args = getattr(self.task, "start_trigger_args", None)
- trigger_kwarg_keys = (
- set((start_trigger_args.trigger_kwargs or {}).keys()) if
start_trigger_args else set()
- )
+ if start_trigger_args:
+ from airflow.serialization.enums import Encoding
+
+ raw = start_trigger_args.trigger_kwargs or {}
+ # trigger_kwargs may be BaseSerialization-encoded; extract
inner dict keys
+ if isinstance(raw, dict) and Encoding.TYPE in raw:
+ raw = raw.get(Encoding.VAR) or {}
+ trigger_kwarg_keys = set(raw.keys())
+ else:
+ trigger_kwarg_keys = set()
if trigger_kwarg_keys:
self.template_fields = tuple(
f for f in self.task.template_fields if f in
trigger_kwarg_keys and hasattr(self, f)
@@ -256,9 +263,11 @@ class BaseEventTrigger(BaseTrigger):
We do not want to have this logic in ``BaseTrigger`` because, when
used to defer tasks, 2 triggers
can have the same classpath and kwargs. This is not true for event
driven scheduling.
"""
+ from airflow.serialization.encoders import encode_trigger
from airflow.serialization.serialized_objects import BaseSerialization
- return hash((classpath,
json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
+ normalized = encode_trigger({"classpath": classpath, "kwargs":
kwargs})["kwargs"]
+ return hash((classpath,
json.dumps(BaseSerialization.serialize(normalized)).encode("utf-8")))
class TriggerEvent(BaseModel):
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py
b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 968858be336..27247a42952 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -851,15 +851,7 @@ class TestTriggerRunner:
session.commit()
stored_kwargs = trigger_orm.kwargs
- assert stored_kwargs == {
- "Encoding.TYPE": "dict",
- "Encoding.VAR": {
- "dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}},
- "list": [],
- "simple": "test",
- "tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []},
- },
- }
+ assert stored_kwargs == kw
runner = TriggerRunner()
runner.to_create.append(
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 34e8acbaacf..478e4f5ea4d 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -54,6 +54,7 @@ from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstanceNote,
clear_task_instances
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
+from airflow.models.trigger import Trigger
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator,
ShortCircuitOperator
@@ -2298,6 +2299,87 @@ def test_schedule_tis_start_trigger(dag_maker, session):
assert ti.state == TaskInstanceState.DEFERRED
[email protected]_serialized_dag
+def test_schedule_tis_start_trigger_next_kwargs_round_trip(dag_maker, session):
+ """next_kwargs with encoded values (timedelta) must survive the defer_task
round-trip."""
+ import datetime
+
+ from airflow.sdk.serde import deserialize
+
+ class TestOperator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ trigger_kwargs={},
+ next_method="execute_complete",
+ next_kwargs={"delay": datetime.timedelta(seconds=30)},
+ timeout=None,
+ )
+ start_from_trigger = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def execute_complete(self):
+ pass
+
+ with dag_maker(session=session):
+ TestOperator(task_id="test_task")
+
+ dr: DagRun = dag_maker.create_dagrun()
+ ti = dr.get_task_instance("test_task")
+ ti.task = dr.dag.get_task("test_task")
+ dr.schedule_tis((ti,), session=session)
+
+ assert ti.state == TaskInstanceState.DEFERRED
+ assert deserialize(ti.next_kwargs) == {"delay":
datetime.timedelta(seconds=30)}
+
+
[email protected]_serialized_dag
+def test_schedule_tis_start_trigger_kwargs_e2e(dag_maker, session):
+ """
+ End to end test of scheduler defer_task with non-trivial trigger_kwargs
(timedelta) ->
+ Trigger row -> Trigger.kwargs returns correct Python objects.
+
+ Covers the path: BaseSerialization encodes trigger_kwargs with Encoding
enum keys,
+ defer_task passes them to Trigger as kwargs which calls encrypt_kwargs ->
_stringify_encoding_keys
+ -> serde.serialize -> stores them.
+
+ On reading, serde.deserialize + _convert must reconstruct the original
values.
+ """
+ import datetime
+
+ class TestOperator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ trigger_kwargs={"delta": datetime.timedelta(seconds=2)},
+ next_method="execute_complete",
+ next_kwargs=None,
+ timeout=None,
+ )
+ start_from_trigger = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def execute_complete(self):
+ pass
+
+ with dag_maker(session=session):
+ TestOperator(task_id="test_task")
+
+ dr: DagRun = dag_maker.create_dagrun()
+ ti = dr.get_task_instance("test_task")
+ ti.task = dr.dag.get_task("test_task")
+ dr.schedule_tis((ti,), session=session)
+
+ assert ti.state == TaskInstanceState.DEFERRED
+
+ trigger_row = session.get(Trigger, ti.trigger_id)
+ assert trigger_row is not None
+ # trigger_kwargs must round-trip correctly through encrypt_kwargs →
_decrypt_kwargs
+ assert trigger_row.kwargs == {"delta": datetime.timedelta(seconds=2)}
+
+
def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session):
"""
When empty operator is not actually run, then we need to increment the
try_number,
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py
b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 33db3d8d906..fd333d5d799 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -817,6 +817,8 @@ class TestStringifiedDAGs:
"on_failure_fail_dagrun",
"_needs_expansion",
"_is_sensor",
+ # trigger_kwargs is kept as raw JSON after deserialization;
checked separately
+ "start_trigger_args",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, SerializedMappedOperator)
@@ -857,6 +859,20 @@ class TestStringifiedDAGs:
else:
assert serialized_task.resources == task.resources
+ # start_trigger_args: trigger_kwargs is kept as raw
BaseSerialization-encoded form
+ # after deserialization. Compare the encoded forms directly —
s.trigger_kwargs is
+ # exactly BaseSerialization.serialize(o.trigger_kwargs) since
_encode_start_trigger_args
+ # serializes it and _decode_start_trigger_args keeps it raw.
+ if task.start_trigger_args is not None:
+ from airflow.serialization.serialized_objects import
BaseSerialization
+
+ s = serialized_task.start_trigger_args
+ o = task.start_trigger_args
+ assert s.trigger_cls == o.trigger_cls
+ assert s.next_method == o.next_method
+ assert s.timeout == o.timeout
+ assert s.trigger_kwargs ==
BaseSerialization.serialize(o.trigger_kwargs or {})
+
assert [ensure_serialized_asset(i) for i in task.inlets] ==
serialized_task.inlets
assert [ensure_serialized_asset(o) for o in task.outlets] ==
serialized_task.outlets
@@ -2630,6 +2646,42 @@ class TestStringifiedDAGs:
}
assert tasks[1]["__var"]["start_from_trigger"] is True
+ def test_trigger_kwargs_not_deserialised_through_serdag(self):
+ """trigger_kwargs and next_kwargs are kept as raw BaseSerialization
JSON when loading a serialized DAG."""
+
+ class TestOperator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+
trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
+ trigger_kwargs={"delta": timedelta(seconds=2)},
+ next_method="execute_complete",
+ next_kwargs={"resume_after": timedelta(seconds=5)},
+ timeout=None,
+ )
+ start_from_trigger = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def execute_complete(self):
+ pass
+
+ dag = DAG(dag_id="test_dag_kwargs_raw", schedule=None,
start_date=datetime(2023, 11, 9))
+ with dag:
+ TestOperator(task_id="test_task")
+
+ serialized = DagSerialization.to_dict(dag)
+ deserialized_dag = DagSerialization.from_dict(serialized)
+
+ task = deserialized_dag.get_task("test_task")
+ assert task.start_trigger_args.trigger_kwargs == {
+ "__type": "dict",
+ "__var": {"delta": {"__type": "timedelta", "__var": 2.0}},
+ }
+ assert task.start_trigger_args.next_kwargs == {
+ "__type": "dict",
+ "__var": {"resume_after": {"__type": "timedelta", "__var": 5.0}},
+ }
+
def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but
deserialization of PODs requires it"""