This is an automated email from the ASF dual-hosted git repository.
dabla pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new 0167f876049 Prevent Triggerer from crashing when a trigger event isn't
serializable (#60152) (#60981)
0167f876049 is described below
commit 0167f8760494a1bef286e8291382acf7141d14db
Author: David Blain <[email protected]>
AuthorDate: Sat Jan 24 20:08:50 2026 +0100
Prevent Triggerer from crashing when a trigger event isn't serializable
(#60152) (#60981)
* refactor: If asend in the TriggerRunner comms decoder raises
NotImplementedError because a trigger event is not serializable, then retry
without that event and cancel the associated trigger
* refactor: Applied some reformatting
* refactor: Fixed some mypy issues
* refactor: Fixed return type of send_changes method
* refactor: Changed imports of trigger events
* refactor: Reformatted trigger job runner
* refactor: Fixed mocking of comms decoder
* refactor: Forgot to add the patched supervisor_builder
* refactor: Changed asserts in test_sync_state_to_supervisor
* refactor: Refactored how state changes are validated
* refactor: Validate events while creating the TriggerStateChanges message
* refactor: Refactored try/except in validate_state_changes to keep mypy
happy
* Update airflow-core/src/airflow/jobs/triggerer_job_runner.py
Co-authored-by: Ash Berlin-Taylor <[email protected]>
* refactor: Renamed validate_state_changes method to process_trigger_events
* refactor: Only sanitize invalid trigger events if first attempt fails
* refactor: Should check if msg.events is not None
* refactor: Adapted test_sync_state_to_supervisor so both the initial and retry
calls are correctly asserted and bad events are being sanitized
---------
Co-authored-by: Ash Berlin-Taylor <[email protected]>
(cherry picked from commit 0238244413e1b802d2cc6df799c0f2cc16b5f7a7)
---
.../src/airflow/jobs/triggerer_job_runner.py | 89 +++++++++++++++-------
airflow-core/tests/unit/jobs/test_triggerer_job.py | 28 ++++++-
2 files changed, 88 insertions(+), 29 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index c229dddc7fd..1d009657b42 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -70,12 +70,13 @@ from airflow.sdk.execution_time.comms import (
UpdateHITLDetail,
VariableResult,
XComResult,
+ _new_encoder,
_RequestFrame,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess,
make_buffered_socket_reader
from airflow.stats import Stats
from airflow.traces.tracer import DebugTrace, Trace, add_debug_span
-from airflow.triggers import base as events
+from airflow.triggers.base import BaseEventTrigger, BaseTrigger,
DiscrimatedTriggerEvent, TriggerEvent
from airflow.utils.helpers import log_filename_template_renderer
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
@@ -203,7 +204,7 @@ class messages:
type: Literal["TriggerStateChanges"] = "TriggerStateChanges"
events: Annotated[
- list[tuple[int, events.DiscrimatedTriggerEvent]] | None,
+ list[tuple[int, DiscrimatedTriggerEvent]] | None,
# We have to specify a default here, as otherwise Pydantic
struggles to deal with the discriminated
# union :shrug:
Field(default=None),
@@ -355,7 +356,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
creating_triggers: deque[workloads.RunTrigger] =
attrs.field(factory=deque, init=False)
# Outbound queue of events
- events: deque[tuple[int, events.TriggerEvent]] =
attrs.field(factory=deque, init=False)
+ events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque,
init=False)
# Outbound queue of failed triggers
failed_triggers: deque[tuple[int, list[str] | None]] =
attrs.field(factory=deque, init=False)
@@ -821,7 +822,7 @@ class TriggerRunner:
to_cancel: deque[int]
# Outbound queue of events
- events: deque[tuple[int, events.TriggerEvent]]
+ events: deque[tuple[int, TriggerEvent]]
# Outbound queue of failed triggers
failed_triggers: deque[tuple[int, BaseException | None]]
@@ -971,7 +972,7 @@ class TriggerRunner:
"task": asyncio.create_task(
self.run_trigger(trigger_id, trigger_instance,
workload.timeout_after), name=trigger_name
),
- "is_watcher": isinstance(trigger_instance,
events.BaseEventTrigger),
+ "is_watcher": isinstance(trigger_instance, BaseEventTrigger),
"name": trigger_name,
"events": 0,
}
@@ -1017,7 +1018,7 @@ class TriggerRunner:
saved_exc = e
else:
# See if they foolishly returned a TriggerEvent
- if isinstance(result, events.TriggerEvent):
+ if isinstance(result, TriggerEvent):
self.log.error(
"Trigger returned a TriggerEvent rather than
yielding it",
trigger=details["name"],
@@ -1037,46 +1038,78 @@ class TriggerRunner:
await asyncio.sleep(0)
return finished_ids
- async def sync_state_to_supervisor(self, finished_ids: list[int]):
+ def process_trigger_events(self, finished_ids: list[int]) ->
messages.TriggerStateChanges:
# Copy out of our deques in threadsafe manner to sync state with parent
- events_to_send = []
+ events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = []
+ failures_to_send: list[tuple[int, list[str] | None]] = []
+
while self.events:
- data = self.events.popleft()
- events_to_send.append(data)
+ trigger_id, trigger_event = self.events.popleft()
+ events_to_send.append((trigger_id, trigger_event))
- failures_to_send = []
while self.failed_triggers:
- id, exc = self.failed_triggers.popleft()
+ trigger_id, exc = self.failed_triggers.popleft()
tb = format_exception(type(exc), exc, exc.__traceback__) if exc
else None
- failures_to_send.append((id, tb))
+ failures_to_send.append((trigger_id, tb))
- msg = messages.TriggerStateChanges(
- events=events_to_send, finished=finished_ids,
failures=failures_to_send
+ return messages.TriggerStateChanges(
+ events=events_to_send if events_to_send else None,
+ finished=finished_ids if finished_ids else None,
+ failures=failures_to_send if failures_to_send else None,
)
- if not events_to_send:
- msg.events = None
+ def sanitize_trigger_events(self, msg: messages.TriggerStateChanges) ->
messages.TriggerStateChanges:
+ req_encoder = _new_encoder()
+ events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = []
+
+ if msg.events:
+ for trigger_id, trigger_event in msg.events:
+ try:
+ req_encoder.encode(trigger_event)
+ except Exception as e:
+ logger.error(
+ "Trigger %s returned non-serializable result %r.
Cancelling trigger.",
+ trigger_id,
+ trigger_event,
+ )
+ self.failed_triggers.append((trigger_id, e))
+ else:
+ events_to_send.append((trigger_id, trigger_event))
- if not failures_to_send:
- msg.failures = None
+ return messages.TriggerStateChanges(
+ events=events_to_send if events_to_send else None,
+ finished=msg.finished,
+ failures=msg.failures,
+ )
- if not finished_ids:
- msg.finished = None
+ async def sync_state_to_supervisor(self, finished_ids: list[int]) -> None:
+ msg = self.process_trigger_events(finished_ids=finished_ids)
# Tell the monitor that we've finished triggers so it can update things
try:
- resp = await self.comms_decoder.asend(msg)
+ resp = await self.asend(msg)
+ except NotImplementedError:
+ # A non-serializable trigger event was detected, remove it and
fail associated trigger
+ resp = await self.asend(self.sanitize_trigger_events(msg))
+
+ if resp:
+ self.to_create.extend(resp.to_create)
+ self.to_cancel.extend(resp.to_cancel)
+
+ async def asend(self, msg: messages.TriggerStateChanges) ->
messages.TriggerStateSync | None:
+ try:
+ response = await self.comms_decoder.asend(msg)
+
+ if not isinstance(response, messages.TriggerStateSync):
+ raise RuntimeError(f"Expected to get a TriggerStateSync
message, instead we got {type(msg)}")
+
+ return response
except asyncio.IncompleteReadError:
if task := asyncio.current_task():
task.cancel("EOF - shutting down")
- return
+ return None
raise
- if not isinstance(resp, messages.TriggerStateSync):
- raise RuntimeError(f"Expected to get a TriggerStateSync message,
instead we got {type(msg)}")
- self.to_create.extend(resp.to_create)
- self.to_cancel.extend(resp.to_cancel)
-
async def block_watchdog(self):
"""
Watchdog loop that detects blocking (badly-written) triggers.
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py
b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3181afa540c..116b3c115df 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -371,7 +371,7 @@ class TestTriggerRunner:
trigger_runner = TriggerRunner()
trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
trigger_runner.comms_decoder.asend.return_value =
messages.TriggerStateSync(
- to_create=[], to_cancel=[]
+ to_create=[], to_cancel=set()
)
trigger_runner.to_create.append(workload)
@@ -438,6 +438,32 @@ class TestTriggerRunner:
trigger_instance.cancel()
await runner.cleanup_finished_triggers()
+ @pytest.mark.asyncio
+ @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True)
+ async def test_sync_state_to_supervisor(self, supervisor_builder):
+ trigger_runner = TriggerRunner()
+ trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
+ trigger_runner.events.append((1, TriggerEvent(payload={"status":
"SUCCESS"})))
+ trigger_runner.events.append((2, TriggerEvent(payload={"status":
"FAILED"})))
+ trigger_runner.events.append((3, TriggerEvent(payload={"status":
"SUCCESS", "data": object()})))
+
+ async def asend_side_effect(msg):
+ if msg.events and len(msg.events) == 3:
+ raise NotImplementedError("Simulate non-serializable event")
+ return messages.TriggerStateSync(to_create=[], to_cancel=set())
+
+ trigger_runner.comms_decoder.asend.side_effect = asend_side_effect
+
+ await trigger_runner.sync_state_to_supervisor(finished_ids=[])
+
+ assert trigger_runner.comms_decoder.asend.call_count == 2
+
+ first_call =
trigger_runner.comms_decoder.asend.call_args_list[0].args[0]
+ second_call =
trigger_runner.comms_decoder.asend.call_args_list[1].args[0]
+
+ assert len(first_call.events) == 3
+ assert len(second_call.events) == 2
+
@pytest.mark.asyncio
async def test_trigger_create_race_condition_38599(session,
supervisor_builder, testing_dag_bundle):