This is an automated email from the ASF dual-hosted git repository. dabla pushed a commit to branch backport-60976-v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3cbc541e5c48fe1b2dd3741ddb2dcd2c709fffcb Author: David Blain <[email protected]> AuthorDate: Thu Jan 22 19:18:31 2026 +0100 Prevent Triggerer from crashing when a trigger event isn't serializable (#60152) * refactor: If asend in TriggerRunner comms decoder crashes due to NotImplementedError as a trigger event is not serializable, then retry without that event and cancel associated trigger * refactor: Applied some reformatting * refactor: Fixed some mypy issues * refactor: Fixed return type of send_changes method * refactor: Changed imports of trigger events * refactor: Reformatted trigger job runner * refactor: Fixed mocking of comms decoder * refactor: Forgot to add the patched supervisor_builder * refactor: Changed asserts in test_sync_state_to_supervisor * refactor: Refactored how state changes are validated * refactor: Validate events while creating the TriggerStateChanges message * refactor: Refactored try/except in validate_state_changes to keep mypy happy * Update airflow-core/src/airflow/jobs/triggerer_job_runner.py Co-authored-by: Ash Berlin-Taylor <[email protected]> * refactor: Renamed validate_state_changes method to process_trigger_events * refactor: Only sanitize invalid trigger events if first attempt fails * refactor: Should check if msg.events is not None * refactor: Adapted test_sync_state_to_supervisor so both initial and retry call are correctly asserted and bad events are being sanitized --------- Co-authored-by: Ash Berlin-Taylor <[email protected]> (cherry picked from commit 0238244413e1b802d2cc6df799c0f2cc16b5f7a7) --- .../src/airflow/jobs/triggerer_job_runner.py | 91 +++++++++++++++------- airflow-core/tests/unit/jobs/test_triggerer_job.py | 28 ++++++- 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index c229dddc7fd..cc78e21245a 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -70,12 +70,13 @@ from airflow.sdk.execution_time.comms import ( UpdateHITLDetail, VariableResult, XComResult, + _new_encoder, _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader +from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent from airflow.stats import Stats from airflow.traces.tracer import DebugTrace, Trace, add_debug_span -from airflow.triggers import base as events from airflow.utils.helpers import log_filename_template_renderer from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string @@ -203,7 +204,7 @@ class messages: type: Literal["TriggerStateChanges"] = "TriggerStateChanges" events: Annotated[ - list[tuple[int, events.DiscrimatedTriggerEvent]] | None, + list[tuple[int, DiscrimatedTriggerEvent]] | None, # We have to specify a default here, as otherwise Pydantic struggles to deal with the discriminated # union :shrug: Field(default=None), @@ -355,7 +356,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False) # Outbound queue of events - events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False) + events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque, init=False) # Outbound queue of failed triggers failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False) @@ -758,8 +759,6 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): factory=lambda: TypeAdapter(ToTriggerRunner), repr=False ) - _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) - def _read_frame(self): from asgiref.sync import async_to_sync @@ -794,7 +793,7 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) bytes = frame.as_bytes() - async with self._lock: + async with self._async_lock: self._async_writer.write(bytes) return await self._aget_response(frame.id) @@ -821,7 +820,7 @@ class TriggerRunner: to_cancel: deque[int] # Outbound queue of events - events: deque[tuple[int, events.TriggerEvent]] + events: deque[tuple[int, TriggerEvent]] # Outbound queue of failed triggers failed_triggers: deque[tuple[int, BaseException | None]] @@ -971,7 +970,7 @@ class TriggerRunner: "task": asyncio.create_task( self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name ), - "is_watcher": isinstance(trigger_instance, events.BaseEventTrigger), + "is_watcher": isinstance(trigger_instance, BaseEventTrigger), "name": trigger_name, "events": 0, } @@ -1017,7 +1016,7 @@ class TriggerRunner: saved_exc = e else: # See if they foolishly returned a TriggerEvent - if isinstance(result, events.TriggerEvent): + if isinstance(result, TriggerEvent): self.log.error( "Trigger returned a TriggerEvent rather than yielding it", trigger=details["name"], @@ -1037,46 +1036,78 @@ class TriggerRunner: await asyncio.sleep(0) return finished_ids - async def sync_state_to_supervisor(self, finished_ids: list[int]): + def process_trigger_events(self, finished_ids: list[int]) -> messages.TriggerStateChanges: # Copy out of our deques in threadsafe manner to sync state with parent - events_to_send = [] + events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = [] + failures_to_send: list[tuple[int, list[str] | None]] = [] + while self.events: - data = self.events.popleft() - events_to_send.append(data) + trigger_id, trigger_event = self.events.popleft() + events_to_send.append((trigger_id, trigger_event)) - failures_to_send = [] while self.failed_triggers: - id, exc = self.failed_triggers.popleft() + trigger_id, exc = self.failed_triggers.popleft() tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None - failures_to_send.append((id, tb)) + failures_to_send.append((trigger_id, tb)) - msg = messages.TriggerStateChanges( - events=events_to_send, finished=finished_ids, failures=failures_to_send + return messages.TriggerStateChanges( + events=events_to_send if events_to_send else None, + finished=finished_ids if finished_ids else None, + failures=failures_to_send if failures_to_send else None, ) - if not events_to_send: - msg.events = None + def sanitize_trigger_events(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateChanges: + req_encoder = _new_encoder() + events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = [] - if not failures_to_send: - msg.failures = None + if msg.events: + for trigger_id, trigger_event in msg.events: + try: + req_encoder.encode(trigger_event) + except Exception as e: + logger.error( + "Trigger %s returned non-serializable result %r. Cancelling trigger.", + trigger_id, + trigger_event, + ) + self.failed_triggers.append((trigger_id, e)) + else: + events_to_send.append((trigger_id, trigger_event)) + + return messages.TriggerStateChanges( + events=events_to_send if events_to_send else None, + finished=msg.finished, + failures=msg.failures, + ) - if not finished_ids: - msg.finished = None + async def sync_state_to_supervisor(self, finished_ids: list[int]) -> None: + msg = self.process_trigger_events(finished_ids=finished_ids) # Tell the monitor that we've finished triggers so it can update things try: - resp = await self.comms_decoder.asend(msg) + resp = await self.asend(msg) + except NotImplementedError: + # A non-serializable trigger event was detected, remove it and fail associated trigger + resp = await self.asend(self.sanitize_trigger_events(msg)) + + if resp: + self.to_create.extend(resp.to_create) + self.to_cancel.extend(resp.to_cancel) + + async def asend(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateSync | None: + try: + response = await self.comms_decoder.asend(msg) + + if not isinstance(response, messages.TriggerStateSync): + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") + + return response except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") return raise - if not isinstance(resp, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") - self.to_create.extend(resp.to_create) - self.to_cancel.extend(resp.to_cancel) - async def block_watchdog(self): """ Watchdog loop that detects blocking (badly-written) triggers. diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 3181afa540c..116b3c115df 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -371,7 +371,7 @@ class TestTriggerRunner: trigger_runner = TriggerRunner() trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( - to_create=[], to_cancel=[] + to_create=[], to_cancel=set() ) trigger_runner.to_create.append(workload) @@ -438,6 +438,32 @@ class TestTriggerRunner: trigger_instance.cancel() await runner.cleanup_finished_triggers() + @pytest.mark.asyncio + @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) + async def test_sync_state_to_supervisor(self, supervisor_builder): + trigger_runner = TriggerRunner() + trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) + trigger_runner.events.append((1, TriggerEvent(payload={"status": "SUCCESS"}))) + trigger_runner.events.append((2, TriggerEvent(payload={"status": "FAILED"}))) + trigger_runner.events.append((3, TriggerEvent(payload={"status": "SUCCESS", "data": object()}))) + + async def asend_side_effect(msg): + if msg.events and len(msg.events) == 3: + raise NotImplementedError("Simulate non-serializable event") + return messages.TriggerStateSync(to_create=[], to_cancel=set()) + + trigger_runner.comms_decoder.asend.side_effect = asend_side_effect + + await trigger_runner.sync_state_to_supervisor(finished_ids=[]) + + assert trigger_runner.comms_decoder.asend.call_count == 2 + + first_call = trigger_runner.comms_decoder.asend.call_args_list[0].args[0] + second_call = trigger_runner.comms_decoder.asend.call_args_list[1].args[0] + + assert len(first_call.events) == 3 + assert len(second_call.events) == 2 + @pytest.mark.asyncio async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
