jedcunningham commented on code in PR #66412:
URL: https://github.com/apache/airflow/pull/66412#discussion_r3191440581


##########
airflow-core/tests/unit/jobs/test_triggerer_job.py:
##########
@@ -1853,3 +1906,175 @@ def test_make_trigger_span_sets_only_trigger_name_without_ti(self):
         assert attrs["airflow.trigger.name"] == "OnlyTrigger"
         assert "airflow.dag_id" not in attrs
         assert "airflow.task_id" not in attrs
+
+
+def _read_frame_sync(sock) -> _RequestFrame | None:
+    """Read a length-prefixed msgpack frame from a blocking socket."""
+    lb = b""
+    while len(lb) < 4:
+        chunk = sock.recv(4 - len(lb))
+        if not chunk:
+            return None
+        lb += chunk
+    n = int.from_bytes(lb, "big")
+    data = b""
+    while len(data) < n:
+        chunk = sock.recv(n - len(data))
+        if not chunk:
+            return None
+        data += chunk
+    return msgspec.msgpack.decode(data, type=_RequestFrame)
+
+
+@pytest_asyncio.fixture
+async def decoder_pair():
+    """Yield (decoder, server_sock). Caller owns closing."""
+    server_sock, client_sock = socketpair()
+    reader, writer = await asyncio.open_connection(sock=client_sock)
+    decoder = TriggerCommsDecoder(async_writer=writer, async_reader=reader, socket=client_sock)
+    await decoder.start_reader()
+    yield decoder, server_sock
+    if decoder._reader_task:
+        if not decoder._reader_task.done():
+            decoder._reader_task.cancel()
+        with contextlib.suppress(asyncio.CancelledError, Exception):
+            await decoder._reader_task
+    writer.close()
+    server_sock.close()
+
+
+@pytest.mark.asyncio
+@pytest.mark.execution_timeout(15)
+async def test_all_send_paths_concurrent(decoder_pair):
+    """
+    All four send() paths running concurrently with responses returned out of order:
+
+      1. asend() directly from async code           — pure-async path
+      2. send() via asyncio.to_thread()              — mirrors apache/airflow#63913:
+                                                       sync_to_async(hook_class)() → get_connection()
+                                                       → SUPERVISOR_COMMS.send() from a thread pool thread
+      3. send() from the event-loop thread           — mirrors apache/airflow#63760:
+         via greenback                                 async_to_sync raised RuntimeError in same thread
+      4. async_to_sync(asend)() from a thread        — trigger code that wraps an async fn which
+                                                       internally calls asend; bridges via wrap_future
+
+    The concurrent mix with shuffled responses also covers apache/airflow#65286: the
+    _thread_lock + async_to_sync approach stalled the triggerer under this exact load pattern.
+    """
+    decoder, server_sock = decoder_pair
+    N = 5
+    N_TOTAL = N * 4
+
+    def supervisor():
+        frames = []
+        for _ in range(N_TOTAL):
+            f = _read_frame_sync(server_sock)
+            if f is None:
+                break
+            frames.append(f)
+        random.shuffle(frames)
+        for f in frames:
+            server_sock.sendall(
+                _ResponseFrame(
+                    id=f.id,
+                    body={"type": "TriggerStateSync", "to_create": [], "to_cancel": []},
+                ).as_bytes()
+            )
+
+    sup = threading.Thread(target=supervisor, daemon=True)
+    sup.start()
+
+    async def async_send(idx):
+        return await decoder.asend(messages.TriggerStateChanges(events=None, finished=[idx], failures=None))
+
+    async def from_thread_send(idx):
+        # In production this path is taken by asgiref's own thread pool (sync_to_async),
+        # which is invisible to asyncio's default executor.  We avoid asyncio.to_thread()
+        # here because on Python < 3.12 loop.shutdown_default_executor() has no timeout
+        # and hangs if any executor threads are still alive at loop teardown.
+        # TODO: simplify with asyncio.to_thread() when Python 3.12 is the minimum.
+        loop = asyncio.get_running_loop()
+        fut: asyncio.Future[messages.TriggerStateSync] = loop.create_future()
+
+        def sync_send():
+            try:
+                result = decoder.send(
+                    messages.TriggerStateChanges(events=None, finished=[N + idx], failures=None)
+                )
+                loop.call_soon_threadsafe(fut.set_result, result)
+            except Exception as exc:
+                loop.call_soon_threadsafe(fut.set_exception, exc)
+
+        threading.Thread(target=sync_send, daemon=True).start()
+        return await fut
+
+    async def greenback_send(idx):
+        await greenback.ensure_portal()
+        return decoder.send(messages.TriggerStateChanges(events=None, finished=[2 * N + idx], failures=None))
+
+    async def async_to_sync_send(idx):
+        # Same executor-avoidance reason as from_thread_send above.
+        # TODO: simplify with asyncio.to_thread() when Python 3.12 is the minimum.
+        loop = asyncio.get_running_loop()
+        fut: asyncio.Future[messages.TriggerStateSync] = loop.create_future()
+
+        def thread_fn():
+            try:
+                result = async_to_sync(decoder.asend)(
+                    messages.TriggerStateChanges(events=None, finished=[3 * N + idx], failures=None)
+                )
+                loop.call_soon_threadsafe(fut.set_result, result)
+            except Exception as exc:
+                loop.call_soon_threadsafe(fut.set_exception, exc)
+
+        threading.Thread(target=thread_fn, daemon=True).start()
+        return await fut
+
+    results = await asyncio.gather(
+        *[asyncio.create_task(async_send(i)) for i in range(N)],
+        *[asyncio.create_task(from_thread_send(i)) for i in range(N)],
+        *[asyncio.create_task(greenback_send(i)) for i in range(N)],
+        *[asyncio.create_task(async_to_sync_send(i)) for i in range(N)],
+        return_exceptions=True,
+    )
+
+    sup.join(timeout=5)
+
+    errors = [r for r in results if isinstance(r, Exception)]
+    assert not errors, f"errors: {errors}"
+    assert len(results) == N_TOTAL
+    assert all(isinstance(r, messages.TriggerStateSync) for r in results)
+
+
+@pytest.mark.asyncio
+async def test_connection_close_cancels_pending(decoder_pair):
+    """When the connection closes while asend() is awaiting, the future is cancelled."""
+    decoder, server_sock = decoder_pair
+
+    task = asyncio.create_task(
+        decoder.asend(messages.TriggerStateChanges(events=None, finished=[1], failures=None))
+    )
+    await asyncio.sleep(0)
+
+    server_sock.close()
+
+    with pytest.raises((asyncio.CancelledError, Exception)):
+        await asyncio.wait_for(task, timeout=5)

Review Comment:
   @parkhojeong want to open a follow up to change this? Thanks :)



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -951,46 +977,80 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
         factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
     )
 
-    def _read_frame(self):
-        from asgiref.sync import async_to_sync
-
-        with self._thread_lock:
-            return async_to_sync(self._aread_frame)()
-
-    def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
-        from asgiref.sync import async_to_sync
-
-        with self._thread_lock:
-            return async_to_sync(self.asend)(msg)
+    _pending: dict[int, asyncio.Future] = attrs.field(factory=dict, repr=False)
+    _loop: asyncio.AbstractEventLoop | None = attrs.field(default=None, repr=False)
+    _loop_thread_id: int | None = attrs.field(default=None, repr=False)
+    _reader_task: asyncio.Task | None = attrs.field(default=None, repr=False)
 
     async def _aread_frame(self):
         try:
             len_bytes = await self._async_reader.readexactly(4)
         except ConnectionResetError:
             asyncio.current_task().cancel("Supervisor closed")
+            raise
         length = int.from_bytes(len_bytes, byteorder="big")
         if length >= 2**32:
             raise OverflowError(f"Refusing to receive messages larger than 4GiB {length=}")
-
         buffer = await self._async_reader.readexactly(length)
         return self.resp_decoder.decode(buffer)
 
-    async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None:
-        frame = await self._aread_frame()
-        if frame.id != expect_id:
-            # Given the lock we take out in `asend`, this _shouldn't_ be possible, but I'd rather fail with
-            # this explicit error return the wrong type of message back to a Trigger
-            raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}")
-        return self._from_frame(frame)
+    async def _reader_loop(self) -> None:
+        try:
+            while True:
+                frame = await self._aread_frame()
+                future = self._pending.pop(frame.id, None)
+                if future is not None and not future.done():
+                    future.set_result(frame)
+                else:
+                    self.log.warning("Got response for unknown request frame", frame_id=frame.id)
+        finally:
+            for fut in self._pending.values():
+                if not fut.done():
+                    fut.cancel("Reader loop exited")
+            self._pending.clear()
 
-    async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
-        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
-        bytes = frame.as_bytes()
+    async def start_reader(self) -> None:
+        self._loop = asyncio.get_running_loop()
+        self._loop_thread_id = threading.get_ident()
+        self._reader_task = asyncio.create_task(self._reader_loop(), name="trigger-comms-reader")
 
-        async with self._async_lock:
-            self._async_writer.write(bytes)
+    def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+        if self._loop is None:
+            raise RuntimeError("start_reader() must be called before send()")
+        if threading.get_ident() == self._loop_thread_id:
+            # Called from the event loop thread itself (e.g. a trigger calling a sync SDK method
+            # directly from async def run()). run_coroutine_threadsafe(...).result() would deadlock
+            # here because .result() blocks the thread the event loop runs on.
+            # greenback.await_() teleports the coroutine back into the running loop instead.
+            if not greenback.has_portal():
+                raise RuntimeError(
+                    "Sync SDK methods (e.g. get_connection(), get_variable()) cannot be called "
+                    "from a trigger's async def run() when AIRFLOW_DISABLE_GREENBACK_PORTAL is "
+                    "set. Either remove that environment variable, or use the async equivalent "
+                    "(e.g. aget_connection(), aget_variable())."
+                )
+            return greenback.await_(self.asend(msg))
+        return asyncio.run_coroutine_threadsafe(self.asend(msg), self._loop).result()
 
-            return await self._aget_response(frame.id)
+    async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+        if self._loop is None:
+            raise RuntimeError("start_reader() must be called before asend()")
+        current_loop = asyncio.get_running_loop()
+        if self._loop is not None and current_loop is not self._loop:

Review Comment:
   @parkhojeong want to open a follow up to change this? Thanks :)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to