This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d292e0e9ec6 Fix double-serialization issue by unwrapping serialized kwargs in `encode_trigger` (#64626)
d292e0e9ec6 is described below

commit d292e0e9ec67b5d0e0e03acf99c17809fa58c24a
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Fri Apr 3 03:23:48 2026 +0800

    Fix double-serialization issue by unwrapping serialized kwargs in `encode_trigger` (#64626)
---
 airflow-core/src/airflow/models/trigger.py         |   5 +-
 airflow-core/src/airflow/serialization/encoders.py |   8 +
 .../tests/unit/dag_processing/test_collection.py   |  88 ++++++++++
 .../tests/unit/serialization/test_encoders.py      | 181 +++++++++++++++++++++
 4 files changed, 281 insertions(+), 1 deletion(-)

diff --git a/airflow-core/src/airflow/models/trigger.py 
b/airflow-core/src/airflow/models/trigger.py
index d2c0fde3c89..d17af8532e0 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -32,7 +32,6 @@ from sqlalchemy.sql.functions import coalesce
 from airflow._shared.timezones import timezone
 from airflow.assets.manager import AssetManager
 from airflow.configuration import conf
-from airflow.models import Callback
 from airflow.models.asset import AssetWatcherModel
 from airflow.models.base import Base
 from airflow.models.taskinstance import TaskInstance
@@ -210,6 +209,8 @@ class Trigger(Base):
     @provide_session
     def fetch_trigger_ids_with_non_task_associations(cls, session: Session = 
NEW_SESSION) -> set[str]:
         """Fetch all trigger IDs actively associated with non-task entities 
like assets and callbacks."""
+        from airflow.models.callback import Callback
+
         query = select(AssetWatcherModel.trigger_id).union_all(
             select(Callback.trigger_id).where(Callback.trigger_id.is_not(None))
         )
@@ -408,6 +409,8 @@ class Trigger(Base):
         :param queues: The optional set of trigger queues to filter triggers 
by.
         :param session: The database session.
         """
+        from airflow.models.callback import Callback
+
         result: list[Row[Any]] = []
 
         # Add triggers associated to callbacks first, then tasks, then assets
diff --git a/airflow-core/src/airflow/serialization/encoders.py 
b/airflow-core/src/airflow/serialization/encoders.py
index dcb064dcde0..2f30511a1e5 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -162,6 +162,14 @@ def encode_trigger(trigger: BaseEventTrigger | dict):
     if isinstance(trigger, dict):
         classpath = trigger["classpath"]
         kwargs = trigger["kwargs"]
+        # unwrap any kwargs that are themselves serialized objects, to avoid 
double-serialization in the trigger's own serialize() method.
+        unwrapped = {}
+        for k, v in kwargs.items():
+            if isinstance(v, dict) and Encoding.TYPE in v:
+                unwrapped[k] = BaseSerialization.deserialize(v)
+            else:
+                unwrapped[k] = v
+        kwargs = unwrapped
     else:
         classpath, kwargs = trigger.serialize()
     return {
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py 
b/airflow-core/tests/unit/dag_processing/test_collection.py
index 77dc1318b05..6a0aef00eaa 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -182,6 +182,94 @@ class TestAssetModelOperation:
         asset_model = session.scalars(select(AssetModel)).one()
         assert len(asset_model.triggers) == expected_num_triggers
 
+    @pytest.mark.usefixtures("testing_dag_bundle")
+    def test_add_asset_trigger_references_hash_consistency(self, dag_maker, 
session):
+        """Trigger hash from the DAG-parsed path must equal the hash computed
+        from the DB-stored Trigger row.  A mismatch causes the scheduler to
+        recreate trigger rows on every heartbeat.
+        """
+        from airflow.models.trigger import Trigger
+        from airflow.serialization.encoders import encode_trigger
+        from airflow.triggers.base import BaseEventTrigger
+
+        trigger = FileDeleteTrigger(filepath="/tmp/test.txt", 
poke_interval=5.0)
+        asset = Asset(
+            "test_hash_consistency_asset",
+            watchers=[AssetWatcher(name="file_watcher", trigger=trigger)],
+        )
+
+        with dag_maker(dag_id="test_hash_consistency_dag", schedule=[asset]) 
as dag:
+            EmptyOperator(task_id="mytask")
+
+        dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)}
+        orm_dags = DagModelOperation(dags, "testing", 
None).add_dags(session=session)
+        orm_dags[dag.dag_id].is_paused = False
+
+        asset_op = AssetModelOperation.collect(dags)
+        orm_assets = asset_op.sync_assets(session=session)
+        session.flush()
+
+        asset_op.add_dag_asset_references(orm_dags, orm_assets, 
session=session)
+        asset_op.activate_assets_if_possible(orm_assets.values(), 
session=session)
+        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        session.flush()
+
+        # DAG-side hash (same computation as add_asset_trigger_references line 
1025)
+        encoded = encode_trigger(trigger)
+        dag_hash = BaseEventTrigger.hash(encoded["classpath"], 
encoded["kwargs"])
+
+        # DB-side: expire and re-load the Trigger row to force a real DB read
+        asset_model = session.scalars(select(AssetModel)).one()
+        assert len(asset_model.triggers) == 1
+        orm_trigger = asset_model.triggers[0]
+        trigger_id = orm_trigger.id
+        session.expire(orm_trigger)
+        reloaded = session.get(Trigger, trigger_id)
+
+        # DB-side hash (same computation as add_asset_trigger_references line 
1033)
+        db_hash = BaseEventTrigger.hash(reloaded.classpath, reloaded.kwargs)
+
+        assert dag_hash == db_hash
+
+    @pytest.mark.usefixtures("testing_dag_bundle")
+    def test_add_asset_trigger_references_idempotent(self, dag_maker, session):
+        """Calling add_asset_trigger_references twice with the same trigger
+        must not create duplicate rows.
+        """
+        from airflow.models.trigger import Trigger
+
+        trigger = FileDeleteTrigger(filepath="/tmp/test.txt", 
poke_interval=5.0)
+        asset = Asset(
+            "test_idempotent_asset",
+            watchers=[AssetWatcher(name="file_watcher", trigger=trigger)],
+        )
+
+        with dag_maker(dag_id="test_idempotent_dag", schedule=[asset]) as dag:
+            EmptyOperator(task_id="mytask")
+
+        dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)}
+        orm_dags = DagModelOperation(dags, "testing", 
None).add_dags(session=session)
+        orm_dags[dag.dag_id].is_paused = False
+
+        asset_op = AssetModelOperation.collect(dags)
+        orm_assets = asset_op.sync_assets(session=session)
+        session.flush()
+
+        asset_op.add_dag_asset_references(orm_dags, orm_assets, 
session=session)
+        asset_op.activate_assets_if_possible(orm_assets.values(), 
session=session)
+
+        # First call — creates the trigger
+        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        session.flush()
+        count_after_first = session.scalar(select(func.count(Trigger.id)))
+
+        # Second call — should be a no-op (hashes match, no diff)
+        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        session.flush()
+        count_after_second = session.scalar(select(func.count(Trigger.id)))
+
+        assert count_after_first == count_after_second
+
     @pytest.mark.parametrize(
         ("schedule", "model", "columns", "expected"),
         [
diff --git a/airflow-core/tests/unit/serialization/test_encoders.py 
b/airflow-core/tests/unit/serialization/test_encoders.py
new file mode 100644
index 00000000000..479af95ae88
--- /dev/null
+++ b/airflow-core/tests/unit/serialization/test_encoders.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from sqlalchemy import delete
+
+from airflow.models.trigger import Trigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger
+from airflow.serialization.encoders import encode_trigger
+from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
+from airflow.triggers.base import BaseEventTrigger
+
+pytest.importorskip("airflow.providers.apache.kafka")
+from airflow.providers.apache.kafka.triggers.await_message import 
AwaitMessageTrigger
+
+# Trigger fixtures covering primitive-only kwargs (FileDeleteTrigger) and
+# non-primitive kwargs like tuple/dict (AwaitMessageTrigger).
+_TRIGGER_PARAMS = [
+    pytest.param(
+        FileDeleteTrigger(filepath="/tmp/test.txt", poke_interval=5.0),
+        id="primitive_kwargs_only",
+    ),
+    pytest.param(AwaitMessageTrigger(topics=()), id="empty_tuple"),
+    pytest.param(
+        AwaitMessageTrigger(topics=("fizz_buzz",), poll_timeout=1.0, 
commit_offset=True),
+        id="single_topic_tuple",
+    ),
+    pytest.param(
+        AwaitMessageTrigger(
+            topics=["t1", "t2"],
+            apply_function="my.module.func",
+            apply_function_args=["a", "b"],
+            apply_function_kwargs={"key": "value"},
+            kafka_config_id="my_kafka",
+            poll_interval=2,
+            poll_timeout=3,
+        ),
+        id="all_non_primitive_kwargs",
+    ),
+]
+
+
+class TestEncodeTrigger:
+    """Tests for encode_trigger round-trip correctness.
+
+    When a serialized DAG with asset-watcher triggers is re-serialized
+    (e.g. in ``add_asset_trigger_references``), ``encode_trigger`` receives
+    a dict whose kwargs already contain wrapped values like
+    ``{__type: tuple, __var: [...]}``.  The fix ensures these are unwrapped
+    before re-serialization to prevent double-wrapping.
+    """
+
+    def test_encode_from_trigger_object(self):
+        """Non-primitive kwargs are properly serialized from a trigger 
object."""
+        trigger = AwaitMessageTrigger(topics=())
+        result = encode_trigger(trigger)
+
+        assert (
+            result["classpath"] == 
"airflow.providers.apache.kafka.triggers.await_message.AwaitMessageTrigger"
+        )
+        # tuple kwarg is wrapped by BaseSerialization
+        assert result["kwargs"]["topics"] == {Encoding.TYPE: DAT.TUPLE, 
Encoding.VAR: []}
+        # Primitives pass through as-is
+        assert result["kwargs"]["poll_timeout"] == 1
+        assert result["kwargs"]["commit_offset"] is True
+
+    def test_encode_file_delete_trigger(self):
+        """Primitive-only kwargs pass through without wrapping."""
+        trigger = FileDeleteTrigger(filepath="/tmp/test.txt", 
poke_interval=10.0)
+        result = encode_trigger(trigger)
+
+        assert result["classpath"] == 
"airflow.providers.standard.triggers.file.FileDeleteTrigger"
+        assert result["kwargs"]["filepath"] == "/tmp/test.txt"
+        assert result["kwargs"]["poke_interval"] == 10.0
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_re_encode_is_idempotent(self, trigger):
+        """Encoding the output of encode_trigger again must not double-wrap 
kwargs."""
+        first = encode_trigger(trigger)
+        second = encode_trigger(first)
+
+        assert first == second
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_multiple_round_trips_are_stable(self, trigger):
+        """Encoding the same trigger dict many times remains idempotent."""
+        result = encode_trigger(trigger)
+        for _ in range(5):
+            result = encode_trigger(result)
+
+        assert result == encode_trigger(trigger)
+
+
[email protected]_test
+class TestTriggerHashConsistency:
+    """Verify ``BaseEventTrigger.hash`` produces the same value for kwargs
+    from the DAG-parsed path and kwargs read back from the database.
+
+    This mirrors the comparison in
+    ``AssetModelOperation.add_asset_trigger_references``
+    (``airflow-core/src/airflow/dag_processing/collection.py``), where:
+
+    * **DAG side** — ``BaseEventTrigger.hash(classpath, 
encode_trigger(watcher.trigger)["kwargs"])``
+    * **DB side** — ``BaseEventTrigger.hash(trigger.classpath, 
trigger.kwargs)``
+      where the ``Trigger`` row was persisted with ``encrypt_kwargs`` and
+      read back via ``_decrypt_kwargs``.
+
+    If the hashes diverge, the scheduler sees phantom diffs and keeps
+    recreating trigger rows on every heartbeat.
+    """
+
+    @pytest.fixture(autouse=True)
+    def _clean_triggers(self, session):
+        session.execute(delete(Trigger))
+        session.commit()
+        yield
+        session.execute(delete(Trigger))
+        session.commit()
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_hash_matches_after_db_round_trip(self, trigger, session):
+        """Hash from DAG-parsed kwargs equals hash from a DB-persisted 
Trigger."""
+        encoded = encode_trigger(trigger)
+        classpath = encoded["classpath"]
+        dag_kwargs = encoded["kwargs"]
+
+        # DAG side hash — what add_asset_trigger_references computes
+        dag_hash = BaseEventTrigger.hash(classpath, dag_kwargs)
+
+        # Persist to DB (same as add_asset_trigger_references lines 1073-1074)
+        trigger_row = Trigger(classpath=classpath, kwargs=dag_kwargs)
+        session.add(trigger_row)
+        session.flush()
+
+        # Force a real DB read — expire the instance and re-select
+        trigger_id = trigger_row.id
+        session.expire(trigger_row)
+        reloaded = session.get(Trigger, trigger_id)
+
+        # DB side hash — what add_asset_trigger_references computes from ORM
+        db_hash = BaseEventTrigger.hash(reloaded.classpath, reloaded.kwargs)
+
+        assert dag_hash == db_hash
+
+    @pytest.mark.parametrize("trigger", _TRIGGER_PARAMS)
+    def test_hash_matches_after_re_encode_and_db_round_trip(self, trigger, 
session):
+        """Hash stays consistent when encode_trigger output is re-encoded
+        (deserialized-DAG re-serialization path) before DB storage.
+        """
+        re_encoded = encode_trigger(encode_trigger(trigger))
+        classpath = re_encoded["classpath"]
+        dag_kwargs = re_encoded["kwargs"]
+
+        dag_hash = BaseEventTrigger.hash(classpath, dag_kwargs)
+
+        trigger_row = Trigger(classpath=classpath, kwargs=dag_kwargs)
+        session.add(trigger_row)
+        session.flush()
+
+        trigger_id = trigger_row.id
+        session.expire(trigger_row)
+        reloaded = session.get(Trigger, trigger_id)
+
+        db_hash = BaseEventTrigger.hash(reloaded.classpath, reloaded.kwargs)
+
+        assert dag_hash == db_hash

Reply via email to