This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch rework-tasksdk-supervisor-comms-protocol
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 8ab1de2f5a444cafaa9239b81d4febd49a1a3467
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Jun 11 13:58:39 2025 +0100

    Convert the Dag Processor Manager+Procs to the frame-based protocol
    
    To support this (and the fact that the proc manager uses a single
    selector for multiple processes), I have moved the `on_close` callback to
    be part of the object we store in the selector. Previously we stored only
    the "on_read" callback; now we store a tuple of `(on_read, on_close)`, and
    on_close is called exactly once, universally.
---
 airflow-core/src/airflow/dag_processing/manager.py |  9 ++-
 .../src/airflow/dag_processing/processor.py        | 17 +++--
 .../src/airflow/jobs/triggerer_job_runner.py       |  7 +-
 .../tests/unit/dag_processing/test_processor.py    | 39 +++++++---
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 89 ++++++++++------------
 .../src/airflow/sdk/execution_time/supervisor.py   | 57 +++++---------
 6 files changed, 109 insertions(+), 109 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/manager.py 
b/airflow-core/src/airflow/dag_processing/manager.py
index 8b0cab74b6e..2ffb471f703 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -77,6 +77,8 @@ from airflow.utils.session import NEW_SESSION, 
create_session, provide_session
 from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
 
 if TYPE_CHECKING:
+    from socket import socket
+
     from sqlalchemy.orm import Session
 
     from airflow.callbacks.callback_requests import CallbackRequest
@@ -388,7 +390,7 @@ class DagFileProcessorManager(LoggingMixin):
         """
         events = self.selector.select(timeout=timeout)
         for key, _ in events:
-            socket_handler = key.data
+            socket_handler, on_close = key.data
 
             # BrokenPipeError should be caught and treated as if the handler 
returned false, similar
             # to EOF case
@@ -397,8 +399,9 @@ class DagFileProcessorManager(LoggingMixin):
             except BrokenPipeError:
                 need_more = False
             if not need_more:
-                self.selector.unregister(key.fileobj)
-                key.fileobj.close()  # type: ignore[union-attr]
+                sock: socket = key.fileobj  # type: ignore[assignment]
+                on_close(sock)
+                sock.close()
 
     def _queue_requested_files_for_parsing(self) -> None:
         """Queue any files requested for parsing as requested by users via 
UI/API."""
diff --git a/airflow-core/src/airflow/dag_processing/processor.py 
b/airflow-core/src/airflow/dag_processing/processor.py
index 1d1ad25f00b..4341d6d9139 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -267,15 +267,14 @@ class DagFileProcessorProcess(WatchedSubprocess):
         )
         self.send_msg(msg, in_response_to=0)
 
-    def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> 
None:  # type: ignore[override]
+    def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, 
req_id: int) -> None:  # type: ignore[override]
         from airflow.sdk.api.datamodels._generated import ConnectionResponse, 
VariableResponse
 
         resp: BaseModel | None = None
         dump_opts = {}
         if isinstance(msg, DagFileParsingResult):
             self.parsing_result = msg
-            return
-        if isinstance(msg, GetConnection):
+        elif isinstance(msg, GetConnection):
             conn = self.client.connections.get(msg.conn_id)
             if isinstance(conn, ConnectionResponse):
                 conn_result = ConnectionResult.from_conn_response(conn)
@@ -297,10 +296,16 @@ class DagFileProcessorProcess(WatchedSubprocess):
             resp = self.client.variables.delete(msg.key)
         else:
             log.error("Unhandled request", msg=msg)
+            self.send_msg(
+                None,
+                in_response_to=req_id,
+                error=ErrorResponse(
+                    detail={"status_code": 400, "message": "Unhandled 
request"},
+                ),
+            )
             return
 
-        if resp:
-            self.send_msg(resp, **dump_opts)
+        self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
 
     @property
     def is_ready(self) -> bool:
@@ -308,7 +313,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
             # Process still alive, def can't be finished yet
             return False
 
-        return self._num_open_sockets == 0
+        return not self._open_sockets
 
     def wait(self) -> int:
         raise NotImplementedError(f"Don't call wait on {type(self).__name__} 
objects")
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py 
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 0a516abb1a6..258d6172c4e 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -793,13 +793,12 @@ class TriggerRunner:
         This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work 
as expected too (but that will
         need to be wrapped in an ``sync_to_async()`` call)
         """
-        from airflow.sdk.execution_time import task_runner
+        from airflow.sdk.execution_time import comms, task_runner
 
         loop = asyncio.get_event_loop()
 
-        comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, 
ToTriggerSupervisor](
-            input=sys.stdin,
-            decoder=self.decoder,
+        comms_decoder = comms.CommsDecoder[ToTriggerRunner, 
ToTriggerSupervisor](
+            body_decoder=self.decoder,
         )
 
         task_runner.SUPERVISOR_COMMS = comms_decoder
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py 
b/airflow-core/tests/unit/dag_processing/test_processor.py
index e673848a915..eb6a58dfcfb 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -42,7 +42,7 @@ from airflow.models import DagBag, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.sdk.api.client import Client
-from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.sdk.execution_time import comms
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import DagRunState, TaskInstanceState
@@ -392,23 +392,38 @@ def disable_capturing():
 
 @pytest.mark.usefixtures("testing_dag_bundle")
 @pytest.mark.usefixtures("disable_capturing")
-def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency):
+def test_parse_file_entrypoint_parses_dag_callbacks(mocker):
     r, w = socketpair()
 
-    w.makefile("wb").write(
-        b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags",'
-        b'"callback_requests": [{"filepath": "wait.py", "bundle_name": 
"testing", "bundle_version": null, '
-        b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": '
-        b'"manual__2024-12-30T21:02:55.203691+00:00", '
-        b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": 
"DagFileParseRequest"}\n'
+    frame = comms._ResponseFrame(
+        id=1,
+        body={
+            "file": "/files/dags/wait.py",
+            "bundle_path": "/files/dags",
+            "callback_requests": [
+                {
+                    "filepath": "wait.py",
+                    "bundle_name": "testing",
+                    "bundle_version": None,
+                    "msg": "task_failure",
+                    "dag_id": "wait_to_fail",
+                    "run_id": "manual__2024-12-30T21:02:55.203691+00:00",
+                    "is_failure_callback": True,
+                    "type": "DagCallbackRequest",
+                }
+            ],
+            "type": "DagFileParseRequest",
+        },
     )
+    bytes = frame.as_bytes()
+    w.sendall(bytes)
 
-    decoder = CommsDecoder[DagFileParseRequest, DagFileParsingResult](
-        input=r.makefile("r"),
-        decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
+    decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult](
+        request_socket=r,
+        body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
     )
 
-    msg = decoder.get_message()
+    msg = decoder._get_response()
     assert isinstance(msg, DagFileParseRequest)
     assert msg.file == "/files/dags/wait.py"
     assert msg.callback_requests == [
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 4041323540a..dd6f7aa7ea3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,8 +47,9 @@ import itertools
 from collections.abc import Iterator
 from datetime import datetime
 from functools import cached_property
+from pathlib import Path
 from socket import socket
-from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar, 
Union
+from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, 
TypeVar, Union
 from uuid import UUID
 
 import aiologic
@@ -92,26 +93,6 @@ SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
 ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
 
 
-class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, 
omit_defaults=True):
-    id: int
-    """
-    The request id, set by the sender.
-
-    This is used to allow "pipeling" of requests and to be able to tie 
response to requests, which is
-    particularly useful in the Triggerer where multiple async tasks can send a 
requests concurrently.
-    """
-    body: dict[str, Any]
-
-
-class _ResponseFrame(msgspec.Struct, array_like=True, frozen=True, 
omit_defaults=True):
-    id: int
-    """
-    The id of the request this is a response to
-    """
-    body: dict[str, Any] | None = None
-    error: dict[str, Any] | None = None
-
-
 def _msgpack_enc_hook(obj: Any) -> Any:
     import pendulum
 
@@ -120,6 +101,8 @@ def _msgpack_enc_hook(obj: Any) -> Any:
         return datetime(
             obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, 
obj.microsecond, tzinfo=obj.tzinfo
         )
+    if isinstance(obj, Path):
+        return str(obj)
     if isinstance(obj, BaseModel):
         return obj.model_dump(exclude_unset=True)
 
@@ -131,6 +114,41 @@ def _new_encoder() -> msgspec.msgpack.Encoder:
     return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
 
 
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, 
omit_defaults=True):
+    id: int
+    """
+    The request id, set by the sender.
+
+    This is used to allow "pipeling" of requests and to be able to tie 
response to requests, which is
+    particularly useful in the Triggerer where multiple async tasks can send a 
requests concurrently.
+    """
+    body: dict[str, Any] | None
+
+    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()
+
+    def as_bytes(self) -> bytearray:
+        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing 
for inspiration
+        buffer = bytearray(256)
+
+        self.req_encoder.encode_into(self, buffer, 4)
+
+        n = len(buffer) - 4
+        if n > 2**32:
+            raise OverflowError("Cannot send messages larger than 4GiB")
+        buffer[:4] = n.to_bytes(4, byteorder="big")
+
+        return buffer
+
+
+class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True, 
frozen=True, omit_defaults=True):
+    id: int
+    """
+    The id of the request this is a response to
+    """
+    body: dict[str, Any] | None = None
+    error: dict[str, Any] | None = None
+
+
 @attrs.define()
 class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     """Handle communication between the task in this process and the 
supervisor parent process."""
@@ -144,7 +162,6 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
         factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
     )
-    req_encoder: msgspec.msgpack.Encoder = attrs.field(factory=_new_encoder, 
repr=False)
 
     id_counter: Iterator[int] = attrs.field(factory=itertools.count)
 
@@ -157,30 +174,12 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
 
     def send(self, msg: SendMsgType) -> ReceiveMsgType:
         """Send a request to the parent and block until the response is 
received."""
-        bytes = self._encode(msg)
-
-        # print(
-        #     f"Subp: sending {type(msg)} request on 
{self.request_socket.fileno()}, total len={len(bytes)}",
-        #     file=__import__("sys")._ash_out,
-        # )
-        nsent = self.request_socket.send(bytes)
-        # print(f"Subp: {nsent=}", file=__import__("sys")._ash_out)
-
-        return self._get_response()
-
-    def _encode(self, msg: SendMsgType) -> bytearray:
-        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing 
for inspiration
-        buffer = bytearray(256)
-
         frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
-        self.req_encoder.encode_into(frame, buffer, 4)
+        bytes = frame.as_bytes()
 
-        n = len(buffer) - 4
-        if n > 2**32:
-            raise OverflowError("Cannot send messages larger than 4GiB")
-        buffer[:4] = n.to_bytes(4, byteorder="big")
+        self.request_socket.sendall(bytes)
 
-        return buffer
+        return self._get_response()
 
     def _read_frame(self):
         """
@@ -190,14 +189,12 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
         """
         if self.request_socket:
             self.request_socket.setblocking(True)
-        # print("Subp: reading length prefix", file=__import__("sys")._ash_out)
         len_bytes = self.request_socket.recv(4)
 
         if len_bytes == b"":
             raise EOFError("Request socket closed before length")
 
         len = int.from_bytes(len_bytes, byteorder="big")
-        # print(f"Subp: frame {len=} ({len_bytes=})", 
file=__import__("sys")._ash_out)
 
         buffer = bytearray(len)
         nread = self.request_socket.recv_into(buffer)
@@ -206,7 +203,6 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
                 f"unable to read full response in child. (We read {nread}, but 
expected {len})"
             )
         if nread == 0:
-            # print("Subp: EOF when trying to read frame", 
file=__import__("sys")._ash_out)
             raise EOFError("Request socket closed before response was 
complete")
 
         try:
@@ -217,7 +213,6 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     def _from_frame(self, frame):
         from airflow.sdk.exceptions import AirflowRuntimeError
 
-        # print(f"Subp: {frame.body=}", file=__import__("sys")._ash_out)
         if frame.error is not None:
             err = self.err_decoder.validate_python(frame.error)
             raise AirflowRuntimeError(error=err)
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index a3c9bcbb3df..b31f2da0360 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -574,27 +574,22 @@ class WatchedSubprocess:
 
     def _on_socket_closed(self, sock: socket):
         # We want to keep servicing this process until we've read up to EOF 
from all the sockets.
-        self._open_sockets.pop(sock, None)
+
+        with suppress(KeyError):
+            self.selector.unregister(sock)
+            del self._open_sockets[sock]
 
     def send_msg(
         self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse 
| None = None, **dump_opts
     ):
         """Send the msg as a length-prefixed response frame."""
-        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing 
for inspiration
         if msg:
             frame = _ResponseFrame(id=in_response_to, 
body=msg.model_dump(**dump_opts))
         else:
             err_resp = error.model_dump() if error else None
             frame = _ResponseFrame(id=in_response_to, error=err_resp)
-        buffer = bytearray(256)
-
-        self._frame_encoder.encode_into(frame, buffer, 4)
-        n = len(buffer) - 4
-        if n > 2**32:
-            raise OverflowError("Cannot send messages larger than 4GiB")
-        buffer[:4] = n.to_bytes(4, byteorder="big")
 
-        self.stdin.sendall(buffer)
+        self.stdin.sendall(frame.as_bytes())
 
     def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, 
_RequestFrame, None]:
         """Handle incoming requests from the task process, respond with the 
appropriate data."""
@@ -631,8 +626,9 @@ class WatchedSubprocess:
                     ),
                     in_response_to=request.id,
                 )
+                return
 
-    def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> 
None:
+    def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> 
BaseModel | None:
         raise NotImplementedError()
 
     @staticmethod
@@ -758,7 +754,7 @@ class WatchedSubprocess:
         events = self.selector.select(timeout=timeout)
         for key, _ in events:
             # Retrieve the handler responsible for processing this file object 
(e.g., stdout, stderr)
-            socket_handler = key.data
+            socket_handler, on_close = key.data
 
             # Example of handler behavior:
             # If the subprocess writes "Hello, World!" to stdout:
@@ -775,10 +771,9 @@ class WatchedSubprocess:
             # unregister it from the selector to stop monitoring; `wait()` 
blocks until all selectors
             # are removed.
             if not need_more:
-                self.selector.unregister(key.fileobj)
                 sock: socket = key.fileobj  # type: ignore[assignment]
+                on_close(sock)
                 sock.close()
-                self._on_socket_closed(sock)
 
         # Check if the subprocess has exited
         return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, 
expect_signal=expect_signal)
@@ -797,16 +792,16 @@ class WatchedSubprocess:
                 raise
         else:
             self._process_exit_monotonic = time.monotonic()
-            self._close_unused_sockets(self.stdin)
-            # Put a message in the viewable task logs
 
             if expect_signal is not None and self._exit_code == -expect_signal:
                 # Bypass logging, the caller expected us to exit with this
                 return self._exit_code
 
-            # psutil turns signal exit codes into an enum for us. Handy. 
(Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, 
"name")):
+            # Put a message in the viewable task logs
+
             if self._exit_code == -signal.SIGSEGV:
                 self.process_log.critical(SIGSEGV_MESSAGE)
+            # psutil turns signal exit codes into an enum for us. Handy. 
(Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, 
"name")):
             elif name := getattr(self._exit_code, "name", None):
                 message = "Process terminated by signal"
                 level = logging.ERROR
@@ -1252,11 +1247,7 @@ class ActivitySubprocess(WatchedSubprocess):
             )
             return
 
-        if resp:
-            self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
-        else:
-            # Send an empty response frame (which signifies no error) if we 
dont have anything else to say
-            self.send_msg(None, in_response_to=req_id)
+        self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
 
 
 def in_process_api_server():
@@ -1454,7 +1445,7 @@ def make_buffered_socket_reader(
     gen: Generator[None, bytes | bytearray, None],
     on_close: Callable[[socket], None],
     buffer_size: int = 4096,
-) -> Callable[[socket], bool]:
+):
     buffer = bytearray()  # This will hold our accumulated binary data
     read_buffer = bytearray(buffer_size)  # Temporary buffer for each read
 
@@ -1472,7 +1463,6 @@ def make_buffered_socket_reader(
                 with suppress(StopIteration):
                     gen.send(buffer)
             # Tell loop to close this selector
-            on_close(sock)
             return False
 
         buffer.extend(read_buffer[:n_received])
@@ -1483,13 +1473,12 @@ def make_buffered_socket_reader(
             try:
                 gen.send(line)
             except StopIteration:
-                on_close(sock)
                 return False
             buffer = buffer[newline_pos + 1 :]  # Update the buffer with 
remaining data
 
         return True
 
-    return cb
+    return cb, on_close
 
 
 def length_prefixed_frame_reader(
@@ -1506,15 +1495,12 @@ def length_prefixed_frame_reader(
     next(gen)
 
     def cb(sock: socket):
-        print("Main: length_prefixed_frame_reader.cb fired")
         nonlocal buffer, length_needed, pos
-        # Read up to `buffer_size` bytes of data from the socket
 
         if length_needed is None:
             # Read the 32bit length of the frame
             bytes = sock.recv(4)
             if bytes == b"":
-                on_close(sock)
                 return False
 
             length_needed = int.from_bytes(bytes, byteorder="big")
@@ -1523,11 +1509,10 @@ def length_prefixed_frame_reader(
             n = sock.recv_into(buffer[pos:])
             if n == 0:
                 # EOF
-                on_close(sock)
                 return False
             pos += n
 
-            if len(buffer) >= length_needed:
+            if pos >= length_needed:
                 request = decoder.decode(buffer)
                 buffer = None
                 pos = 0
@@ -1535,16 +1520,15 @@ def length_prefixed_frame_reader(
                 try:
                     gen.send(request)
                 except StopIteration:
-                    on_close(sock)
                     return False
         return True
 
-    return cb
+    return cb, on_close
 
 
 def process_log_messages_from_subprocess(
     loggers: tuple[FilteringBoundLogger, ...],
-) -> Generator[None, bytes, None]:
+) -> Generator[None, bytes | bytearray, None]:
     from structlog.stdlib import NAME_TO_LEVEL
 
     while True:
@@ -1580,10 +1564,9 @@ def process_log_messages_from_subprocess(
 
 def forward_to_log(
     target_loggers: tuple[FilteringBoundLogger, ...], chan: str, level: int
-) -> Generator[None, bytes, None]:
+) -> Generator[None, bytes | bytearray, None]:
     while True:
-        buf = yield
-        line = bytes(buf)
+        line = yield
         # Strip off new line
         line = line.rstrip()
         try:

Reply via email to