This is an automated email from the ASF dual-hosted git repository.

dabla pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-1-test by this push:
     new b7d1c41e61a Prevent Triggerer from crashing when a trigger event isn't serializable (#60152)
b7d1c41e61a is described below

commit b7d1c41e61a41fa6919223f31a73bb817c64317f
Author: David Blain <[email protected]>
AuthorDate: Thu Jan 22 19:18:31 2026 +0100

    Prevent Triggerer from crashing when a trigger event isn't serializable (#60152)
    
    * refactor: If asend in the TriggerRunner comms decoder crashes with NotImplementedError because a trigger event is not serializable, retry without that event and cancel the associated trigger
    
    * refactor: Applied some reformatting
    
    * refactor: Fixed some mypy issues
    
    * refactor: Fixed return type of send_changes method
    
    * refactor: Changed imports of trigger events
    
    * refactor: Reformatted trigger job runner
    
    * refactor: Fixed mocking of comms decoder
    
    * refactor: Forgot to add the patched supervisor_builder
    
    * refactor: Changed asserts in test_sync_state_to_supervisor
    
    * refactor: Refactored how state changes are validated
    
    * refactor: Validate events while creating the TriggerStateChanges message
    
    * refactor: Refactored try/except in validate_state_changes to keep mypy happy
    
    * Update airflow-core/src/airflow/jobs/triggerer_job_runner.py
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
    
    * refactor: Renamed validate_state_changes method to process_trigger_events
    
    * refactor: Only sanitize invalid trigger events if first attempt fails
    
    * refactor: Should check if msg.events is not None
    
    * refactor: Adapted test_sync_state_to_supervisor so both the initial and retry calls are correctly asserted and bad events are sanitized
    
    ---------
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
    
    (cherry picked from commit 0238244413e1b802d2cc6df799c0f2cc16b5f7a7)
---
 .../src/airflow/jobs/triggerer_job_runner.py       | 91 +++++++++++++++-------
 airflow-core/tests/unit/jobs/test_triggerer_job.py | 28 ++++++-
 2 files changed, 88 insertions(+), 31 deletions(-)
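
At its core the fix is a send-then-retry pattern: build the TriggerStateChanges message, try to send it, and if encoding a trigger event fails, drop the offending events, record their triggers as failed, and resend the sanitized message once. The snippet below is a minimal, self-contained sketch of that flow, not the Airflow code itself: json stands in for the SDK encoder, and StateChanges, send_to_supervisor and sync_state are illustrative names.

    # Minimal sketch of the "retry without bad events" pattern (assumptions:
    # json stands in for the SDK encoder; all names below are illustrative).
    import json
    from dataclasses import dataclass, field


    @dataclass
    class StateChanges:
        events: list[tuple[int, dict]] = field(default_factory=list)
        failures: list[tuple[int, str]] = field(default_factory=list)


    def send_to_supervisor(msg: StateChanges) -> None:
        # Stand-in for comms_decoder.asend(): raises if a payload can't be encoded.
        json.dumps(msg.events)


    def sanitize(msg: StateChanges) -> StateChanges:
        good: list[tuple[int, dict]] = []
        failures = list(msg.failures)
        for trigger_id, payload in msg.events:
            try:
                json.dumps(payload)
            except TypeError:
                # Drop the unserializable event and fail its trigger instead.
                failures.append((trigger_id, f"non-serializable result {payload!r}"))
            else:
                good.append((trigger_id, payload))
        return StateChanges(events=good, failures=failures)


    def sync_state(msg: StateChanges) -> None:
        try:
            send_to_supervisor(msg)
        except TypeError:
            # Retry once with the offending events removed, mirroring the
            # NotImplementedError handling in sync_state_to_supervisor.
            send_to_supervisor(sanitize(msg))


    sync_state(StateChanges(events=[(1, {"status": "SUCCESS"}), (2, {"data": object()})]))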

diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index c229dddc7fd..cc78e21245a 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -70,12 +70,13 @@ from airflow.sdk.execution_time.comms import (
     UpdateHITLDetail,
     VariableResult,
     XComResult,
+    _new_encoder,
     _RequestFrame,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader
+from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent
 from airflow.stats import Stats
 from airflow.traces.tracer import DebugTrace, Trace, add_debug_span
-from airflow.triggers import base as events
 from airflow.utils.helpers import log_filename_template_renderer
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
@@ -203,7 +204,7 @@ class messages:
 
         type: Literal["TriggerStateChanges"] = "TriggerStateChanges"
         events: Annotated[
-            list[tuple[int, events.DiscrimatedTriggerEvent]] | None,
+            list[tuple[int, DiscrimatedTriggerEvent]] | None,
             # We have to specify a default here, as otherwise Pydantic struggles to deal with the discriminated
             # union :shrug:
             Field(default=None),
@@ -355,7 +356,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
     creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False)
 
     # Outbound queue of events
-    events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False)
+    events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque, init=False)
 
     # Outbound queue of failed triggers
     failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False)
@@ -758,8 +759,6 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
         factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
     )
 
-    _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
-
     def _read_frame(self):
         from asgiref.sync import async_to_sync
 
@@ -794,7 +793,7 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
         frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
         bytes = frame.as_bytes()
 
-        async with self._lock:
+        async with self._async_lock:
             self._async_writer.write(bytes)
 
             return await self._aget_response(frame.id)
@@ -821,7 +820,7 @@ class TriggerRunner:
     to_cancel: deque[int]
 
     # Outbound queue of events
-    events: deque[tuple[int, events.TriggerEvent]]
+    events: deque[tuple[int, TriggerEvent]]
 
     # Outbound queue of failed triggers
     failed_triggers: deque[tuple[int, BaseException | None]]
@@ -971,7 +970,7 @@ class TriggerRunner:
                 "task": asyncio.create_task(
                     self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name
                 ),
-                "is_watcher": isinstance(trigger_instance, events.BaseEventTrigger),
+                "is_watcher": isinstance(trigger_instance, BaseEventTrigger),
                 "name": trigger_name,
                 "events": 0,
             }
@@ -1017,7 +1016,7 @@ class TriggerRunner:
                     saved_exc = e
                 else:
                     # See if they foolishly returned a TriggerEvent
-                    if isinstance(result, events.TriggerEvent):
+                    if isinstance(result, TriggerEvent):
                         self.log.error(
                             "Trigger returned a TriggerEvent rather than 
yielding it",
                             trigger=details["name"],
@@ -1037,46 +1036,78 @@ class TriggerRunner:
             await asyncio.sleep(0)
         return finished_ids
 
-    async def sync_state_to_supervisor(self, finished_ids: list[int]):
+    def process_trigger_events(self, finished_ids: list[int]) -> messages.TriggerStateChanges:
         # Copy out of our deques in threadsafe manner to sync state with parent
-        events_to_send = []
+        events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = []
+        failures_to_send: list[tuple[int, list[str] | None]] = []
+
         while self.events:
-            data = self.events.popleft()
-            events_to_send.append(data)
+            trigger_id, trigger_event = self.events.popleft()
+            events_to_send.append((trigger_id, trigger_event))
 
-        failures_to_send = []
         while self.failed_triggers:
-            id, exc = self.failed_triggers.popleft()
+            trigger_id, exc = self.failed_triggers.popleft()
             tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None
-            failures_to_send.append((id, tb))
+            failures_to_send.append((trigger_id, tb))
 
-        msg = messages.TriggerStateChanges(
-            events=events_to_send, finished=finished_ids, failures=failures_to_send
+        return messages.TriggerStateChanges(
+            events=events_to_send if events_to_send else None,
+            finished=finished_ids if finished_ids else None,
+            failures=failures_to_send if failures_to_send else None,
         )
 
-        if not events_to_send:
-            msg.events = None
+    def sanitize_trigger_events(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateChanges:
+        req_encoder = _new_encoder()
+        events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = []
 
-        if not failures_to_send:
-            msg.failures = None
+        if msg.events:
+            for trigger_id, trigger_event in msg.events:
+                try:
+                    req_encoder.encode(trigger_event)
+                except Exception as e:
+                    logger.error(
+                        "Trigger %s returned non-serializable result %r. Cancelling trigger.",
+                        trigger_id,
+                        trigger_event,
+                    )
+                    self.failed_triggers.append((trigger_id, e))
+                else:
+                    events_to_send.append((trigger_id, trigger_event))
+
+        return messages.TriggerStateChanges(
+            events=events_to_send if events_to_send else None,
+            finished=msg.finished,
+            failures=msg.failures,
+        )
 
-        if not finished_ids:
-            msg.finished = None
+    async def sync_state_to_supervisor(self, finished_ids: list[int]) -> None:
+        msg = self.process_trigger_events(finished_ids=finished_ids)
 
         # Tell the monitor that we've finished triggers so it can update things
         try:
-            resp = await self.comms_decoder.asend(msg)
+            resp = await self.asend(msg)
+        except NotImplementedError:
+            # A non-serializable trigger event was detected, remove it and fail associated trigger
+            resp = await self.asend(self.sanitize_trigger_events(msg))
+
+        if resp:
+            self.to_create.extend(resp.to_create)
+            self.to_cancel.extend(resp.to_cancel)
+
+    async def asend(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateSync | None:
+        try:
+            response = await self.comms_decoder.asend(msg)
+
+            if not isinstance(response, messages.TriggerStateSync):
+                raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}")
+
+            return response
         except asyncio.IncompleteReadError:
             if task := asyncio.current_task():
                 task.cancel("EOF - shutting down")
                 return
             raise
 
-        if not isinstance(resp, messages.TriggerStateSync):
-            raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}")
-        self.to_create.extend(resp.to_create)
-        self.to_cancel.extend(resp.to_cancel)
-
     async def block_watchdog(self):
         """
         Watchdog loop that detects blocking (badly-written) triggers.
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3181afa540c..116b3c115df 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -371,7 +371,7 @@ class TestTriggerRunner:
         trigger_runner = TriggerRunner()
         trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
         trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync(
-            to_create=[], to_cancel=[]
+            to_create=[], to_cancel=set()
         )
 
         trigger_runner.to_create.append(workload)
@@ -438,6 +438,32 @@ class TestTriggerRunner:
         trigger_instance.cancel()
         await runner.cleanup_finished_triggers()
 
+    @pytest.mark.asyncio
+    @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True)
+    async def test_sync_state_to_supervisor(self, supervisor_builder):
+        trigger_runner = TriggerRunner()
+        trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
+        trigger_runner.events.append((1, TriggerEvent(payload={"status": "SUCCESS"})))
+        trigger_runner.events.append((2, TriggerEvent(payload={"status": "FAILED"})))
+        trigger_runner.events.append((3, TriggerEvent(payload={"status": "SUCCESS", "data": object()})))
+
+        async def asend_side_effect(msg):
+            if msg.events and len(msg.events) == 3:
+                raise NotImplementedError("Simulate non-serializable event")
+            return messages.TriggerStateSync(to_create=[], to_cancel=set())
+
+        trigger_runner.comms_decoder.asend.side_effect = asend_side_effect
+
+        await trigger_runner.sync_state_to_supervisor(finished_ids=[])
+
+        assert trigger_runner.comms_decoder.asend.call_count == 2
+
+        first_call = trigger_runner.comms_decoder.asend.call_args_list[0].args[0]
+        second_call = trigger_runner.comms_decoder.asend.call_args_list[1].args[0]
+
+        assert len(first_call.events) == 3
+        assert len(second_call.events) == 2
+
 
 @pytest.mark.asyncio
 async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
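
The new test exercises the retry by giving the mocked comms decoder an async side_effect that fails the first asend call and succeeds on the second. A stripped-down, standalone version of that mocking pattern (plain asyncio and unittest.mock only; sender and the event payloads here are illustrative) looks like this:

    # Standalone sketch of the "fail first call, succeed on retry" mock used in
    # test_sync_state_to_supervisor; names other than AsyncMock are illustrative.
    import asyncio
    from unittest.mock import AsyncMock


    async def main() -> None:
        sender = AsyncMock()

        async def side_effect(events):
            # Simulate a non-serializable event on the first (3-event) attempt.
            if len(events) == 3:
                raise NotImplementedError("Simulate non-serializable event")
            return "TriggerStateSync"

        sender.side_effect = side_effect

        events = [(1, {"status": "SUCCESS"}), (2, {"status": "FAILED"}), (3, {"status": "SUCCESS"})]
        try:
            await sender(events)
        except NotImplementedError:
            # Retry without the offending event, as sync_state_to_supervisor does.
            await sender(events[:2])

        assert sender.await_count == 2
        assert len(sender.await_args_list[0].args[0]) == 3
        assert len(sender.await_args_list[1].args[0]) == 2


    asyncio.run(main())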
