This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch rework-tasksdk-supervisor-comms-protocol
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 8ab1de2f5a444cafaa9239b81d4febd49a1a3467
Author:     Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Jun 11 13:58:39 2025 +0100

    Convert the Dag Processor Manager+Procs to the frame-based protocol
    
    Supporting this (and the fact that the proc manager uses a single
    selector for multiple processes) means I have moved the `on_close`
    callback to be part of the object we store in the selector. Previously
    we stored the "on_read" callback; now we store a tuple of
    `(on_read, on_close)`, and on_close is called once universally.
---
 airflow-core/src/airflow/dag_processing/manager.py |  9 ++-
 .../src/airflow/dag_processing/processor.py        | 17 +++--
 .../src/airflow/jobs/triggerer_job_runner.py       |  7 +-
 .../tests/unit/dag_processing/test_processor.py    | 39 +++++++---
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 89 ++++++++++------
 .../src/airflow/sdk/execution_time/supervisor.py   | 57 +++++---------
 6 files changed, 109 insertions(+), 109 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py
index 8b0cab74b6e..2ffb471f703 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -77,6 +77,8 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
 
 if TYPE_CHECKING:
+    from socket import socket
+
     from sqlalchemy.orm import Session
 
     from airflow.callbacks.callback_requests import CallbackRequest
@@ -388,7 +390,7 @@ class DagFileProcessorManager(LoggingMixin):
         """
         events = self.selector.select(timeout=timeout)
         for key, _ in events:
-            socket_handler = key.data
+            socket_handler, on_close = key.data
 
             # BrokenPipeError should be caught and treated as if the handler returned false, similar
             # to EOF case
@@ -397,8 +399,9 @@ class DagFileProcessorManager(LoggingMixin):
             except BrokenPipeError:
                 need_more = False
             if not need_more:
-                self.selector.unregister(key.fileobj)
-                key.fileobj.close()  # type: ignore[union-attr]
+                sock: socket = key.fileobj  # type: ignore[assignment]
+                on_close(sock)
+                sock.close()
 
     def _queue_requested_files_for_parsing(self) -> None:
         """Queue any files requested for parsing as requested by users via UI/API."""
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index 1d1ad25f00b..4341d6d9139 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -267,15 +267,14 @@ class DagFileProcessorProcess(WatchedSubprocess):
         )
         self.send_msg(msg, in_response_to=0)
 
-    def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None:  # type: ignore[override]
+    def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:  # type: ignore[override]
         from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse
 
         resp: BaseModel | None = None
         dump_opts = {}
         if isinstance(msg, DagFileParsingResult):
             self.parsing_result = msg
-            return
-        if isinstance(msg, GetConnection):
+        elif isinstance(msg, GetConnection):
             conn = self.client.connections.get(msg.conn_id)
             if isinstance(conn, ConnectionResponse):
                 conn_result = ConnectionResult.from_conn_response(conn)
@@ -297,10 +296,16 @@ class DagFileProcessorProcess(WatchedSubprocess):
             resp = self.client.variables.delete(msg.key)
         else:
             log.error("Unhandled request", msg=msg)
+            self.send_msg(
+                None,
+                in_response_to=req_id,
+                error=ErrorResponse(
+                    detail={"status_code": 400, "message": "Unhandled request"},
+                ),
+            )
             return
 
-        if resp:
-            self.send_msg(resp, **dump_opts)
+        self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
 
     @property
     def is_ready(self) -> bool:
@@ -308,7 +313,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
             # Process still alive, def can't be finished yet
             return False
 
-        return self._num_open_sockets == 0
+        return not self._open_sockets
 
     def wait(self) -> int:
         raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects")
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 0a516abb1a6..258d6172c4e 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -793,13 +793,12 @@ class TriggerRunner:
         This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work as expected too (but that
         will need to be wrapped in an ``sync_to_async()`` call)
         """
-        from airflow.sdk.execution_time import task_runner
+        from airflow.sdk.execution_time import comms, task_runner
 
         loop = asyncio.get_event_loop()
 
-        comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor](
-            input=sys.stdin,
-            decoder=self.decoder,
+        comms_decoder = comms.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor](
+            body_decoder=self.decoder,
         )
 
         task_runner.SUPERVISOR_COMMS = comms_decoder
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index e673848a915..eb6a58dfcfb 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -42,7 +42,7 @@ from airflow.models import DagBag, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.sdk.api.client import Client
-from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.sdk.execution_time import comms
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import DagRunState, TaskInstanceState
@@ -392,23 +392,38 @@ def disable_capturing():
 
 @pytest.mark.usefixtures("testing_dag_bundle")
 @pytest.mark.usefixtures("disable_capturing")
-def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency):
+def test_parse_file_entrypoint_parses_dag_callbacks(mocker):
     r, w = socketpair()
-    w.makefile("wb").write(
-        b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags",'
-        b'"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, '
-        b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": '
-        b'"manual__2024-12-30T21:02:55.203691+00:00", '
-        b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n'
+    frame = comms._ResponseFrame(
+        id=1,
+        body={
+            "file": "/files/dags/wait.py",
+            "bundle_path": "/files/dags",
+            "callback_requests": [
+                {
+                    "filepath": "wait.py",
+                    "bundle_name": "testing",
+                    "bundle_version": None,
+                    "msg": "task_failure",
+                    "dag_id": "wait_to_fail",
+                    "run_id": "manual__2024-12-30T21:02:55.203691+00:00",
+                    "is_failure_callback": True,
+                    "type": "DagCallbackRequest",
+                }
+            ],
+            "type": "DagFileParseRequest",
+        },
     )
+    bytes = frame.as_bytes()
+    w.sendall(bytes)
 
-    decoder = CommsDecoder[DagFileParseRequest, DagFileParsingResult](
-        input=r.makefile("r"),
-        decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
+    decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult](
+        request_socket=r,
+        body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
     )
 
-    msg = decoder.get_message()
+    msg = decoder._get_response()
     assert isinstance(msg, DagFileParseRequest)
     assert msg.file == "/files/dags/wait.py"
     assert msg.callback_requests == [
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 4041323540a..dd6f7aa7ea3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,8 +47,9 @@ import itertools
 from collections.abc import Iterator
 from datetime import datetime
 from functools import cached_property
+from pathlib import Path
 from socket import socket
-from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar, Union
+from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
 from uuid import UUID
 
 import aiologic
@@ -92,26 +93,6 @@ SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
 ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
 
 
-class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
-    id: int
-    """
-    The request id, set by the sender.
-
-    This is used to allow "pipeling" of requests and to be able to tie response to requests, which is
-    particularly useful in the Triggerer where multiple async tasks can send a requests concurrently.
-    """
-    body: dict[str, Any]
-
-
-class _ResponseFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
-    id: int
-    """
-    The id of the request this is a response to
-    """
-    body: dict[str, Any] | None = None
-    error: dict[str, Any] | None = None
-
-
 def _msgpack_enc_hook(obj: Any) -> Any:
     import pendulum
 
@@ -120,6 +101,8 @@ def _msgpack_enc_hook(obj: Any) -> Any:
         return datetime(
             obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo
         )
+    if isinstance(obj, Path):
+        return str(obj)
     if isinstance(obj, BaseModel):
         return obj.model_dump(exclude_unset=True)
 
@@ -131,6 +114,41 @@ def _new_encoder() -> msgspec.msgpack.Encoder:
     return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
 
 
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
+    id: int
+    """
+    The request id, set by the sender.
+
+    This is used to allow "pipeling" of requests and to be able to tie response to requests, which is
+    particularly useful in the Triggerer where multiple async tasks can send a requests concurrently.
+    """
+    body: dict[str, Any] | None
+
+    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()
+
+    def as_bytes(self) -> bytearray:
+        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
+        buffer = bytearray(256)
+
+        self.req_encoder.encode_into(self, buffer, 4)
+
+        n = len(buffer) - 4
+        if n > 2**32:
+            raise OverflowError("Cannot send messages larger than 4GiB")
+        buffer[:4] = n.to_bytes(4, byteorder="big")
+
+        return buffer
+
+
+class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
+    id: int
+    """
+    The id of the request this is a response to
+    """
+    body: dict[str, Any] | None = None
+    error: dict[str, Any] | None = None
+
+
 @attrs.define()
 class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     """Handle communication between the task in this process and the supervisor parent process."""
@@ -144,7 +162,6 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
         factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
     )
-    req_encoder: msgspec.msgpack.Encoder = attrs.field(factory=_new_encoder, repr=False)
 
     id_counter: Iterator[int] = attrs.field(factory=itertools.count)
 
@@ -157,30 +174,12 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
 
     def send(self, msg: SendMsgType) -> ReceiveMsgType:
         """Send a request to the parent and block until the response is received."""
-        bytes = self._encode(msg)
-
-        # print(
-        #     f"Subp: sending {type(msg)} request on {self.request_socket.fileno()}, total len={len(bytes)}",
-        #     file=__import__("sys")._ash_out,
-        # )
-        nsent = self.request_socket.send(bytes)
-        # print(f"Subp: {nsent=}", file=__import__("sys")._ash_out)
-
-        return self._get_response()
-
-    def _encode(self, msg: SendMsgType) -> bytearray:
-        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
-        buffer = bytearray(256)
-
         frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
-        self.req_encoder.encode_into(frame, buffer, 4)
+        bytes = frame.as_bytes()
 
-        n = len(buffer) - 4
-        if n > 2**32:
-            raise OverflowError("Cannot send messages larger than 4GiB")
-        buffer[:4] = n.to_bytes(4, byteorder="big")
+        self.request_socket.sendall(bytes)
 
-        return buffer
+        return self._get_response()
 
     def _read_frame(self):
         """
         ...
         """
         if self.request_socket:
             self.request_socket.setblocking(True)
-            # print("Subp: reading length prefix", file=__import__("sys")._ash_out)
 
             len_bytes = self.request_socket.recv(4)
             if len_bytes == b"":
                 raise EOFError("Request socket closed before length")
 
             len = int.from_bytes(len_bytes, byteorder="big")
-            # print(f"Subp: frame {len=} ({len_bytes=})", file=__import__("sys")._ash_out)
 
         buffer = bytearray(len)
         nread = self.request_socket.recv_into(buffer)
                 f"unable to read full response in child. (We read {nread}, but expected {len})"
             )
         if nread == 0:
-            # print("Subp: EOF when trying to read frame", file=__import__("sys")._ash_out)
             raise EOFError("Request socket closed before response was complete")
 
         try:
@@ -217,7 +213,6 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     def _from_frame(self, frame):
         from airflow.sdk.exceptions import AirflowRuntimeError
 
-        # print(f"Subp: {frame.body=}", file=__import__("sys")._ash_out)
         if frame.error is not None:
             err = self.err_decoder.validate_python(frame.error)
             raise AirflowRuntimeError(error=err)
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index a3c9bcbb3df..b31f2da0360 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -574,27 +574,22 @@ class WatchedSubprocess:
 
     def _on_socket_closed(self, sock: socket):
         # We want to keep servicing this process until we've read up to EOF from all the sockets.
-        self._open_sockets.pop(sock, None)
+
+        with suppress(KeyError):
+            self.selector.unregister(sock)
+            del self._open_sockets[sock]
 
     def send_msg(
         self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts
     ):
         """Send the msg as a length-prefixed response frame."""
-        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
         if msg:
             frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts))
         else:
             err_resp = error.model_dump() if error else None
             frame = _ResponseFrame(id=in_response_to, error=err_resp)
 
-        buffer = bytearray(256)
-
-        self._frame_encoder.encode_into(frame, buffer, 4)
-        n = len(buffer) - 4
-        if n > 2**32:
-            raise OverflowError("Cannot send messages larger than 4GiB")
-        buffer[:4] = n.to_bytes(4, byteorder="big")
-        self.stdin.sendall(buffer)
+        self.stdin.sendall(frame.as_bytes())
 
     def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]:
         """Handle incoming requests from the task process, respond with the appropriate data."""
@@ -631,8 +626,9 @@ class WatchedSubprocess:
                     ),
                     in_response_to=request.id,
                 )
+                return
 
-    def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None:
+    def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> BaseModel | None:
         raise NotImplementedError()
 
     @staticmethod
@@ -758,7 +754,7 @@ class WatchedSubprocess:
         events = self.selector.select(timeout=timeout)
         for key, _ in events:
             # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
-            socket_handler = key.data
+            socket_handler, on_close = key.data
 
             # Example of handler behavior:
             # If the subprocess writes "Hello, World!" to stdout:
@@ -775,10 +771,9 @@ class WatchedSubprocess:
             # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors
             # are removed.
             if not need_more:
-                self.selector.unregister(key.fileobj)
                 sock: socket = key.fileobj  # type: ignore[assignment]
+                on_close(sock)
                 sock.close()
-                self._on_socket_closed(sock)
 
         # Check if the subprocess has exited
         return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal)
@@ -797,16 +792,16 @@ class WatchedSubprocess:
             raise
         else:
             self._process_exit_monotonic = time.monotonic()
-        self._close_unused_sockets(self.stdin)
 
-        # Put a message in the viewable task logs
         if expect_signal is not None and self._exit_code == -expect_signal:
             # Bypass logging, the caller expected us to exit with this
             return self._exit_code
 
-        # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer)
-        if exit_code and (name := getattr(exit_code, "name")):
+        # Put a message in the viewable task logs
+        if self._exit_code == -signal.SIGSEGV:
             self.process_log.critical(SIGSEGV_MESSAGE)
+        # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer)
+        if exit_code and (name := getattr(exit_code, "name")):
         elif name := getattr(self._exit_code, "name", None):
             message = "Process terminated by signal"
             level = logging.ERROR
@@ -1252,11 +1247,7 @@ class ActivitySubprocess(WatchedSubprocess):
             )
             return
 
-        if resp:
-            self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
-        else:
-            # Send an empty response frame (which signifies no error) if we dont have anything else to say
-            self.send_msg(None, in_response_to=req_id)
+        self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
 
 
 def in_process_api_server():
@@ -1454,7 +1445,7 @@ def make_buffered_socket_reader(
     gen: Generator[None, bytes | bytearray, None],
    on_close: Callable[[socket], None],
     buffer_size: int = 4096,
-) -> Callable[[socket], bool]:
+):
     buffer = bytearray()  # This will hold our accumulated binary data
     read_buffer = bytearray(buffer_size)  # Temporary buffer for each read
 
@@ -1472,7 +1463,6 @@ def make_buffered_socket_reader(
             with suppress(StopIteration):
                 gen.send(buffer)
             # Tell loop to close this selector
-            on_close(sock)
             return False
 
         buffer.extend(read_buffer[:n_received])
@@ -1483,13 +1473,12 @@ def make_buffered_socket_reader(
             try:
                 gen.send(line)
             except StopIteration:
-                on_close(sock)
                 return False
             buffer = buffer[newline_pos + 1 :]  # Update the buffer with remaining data
 
         return True
 
-    return cb
+    return cb, on_close
 
 
 def length_prefixed_frame_reader(
     ...
     next(gen)
 
     def cb(sock: socket):
-        print("Main: length_prefixed_frame_reader.cb fired")
         nonlocal buffer, length_needed, pos
-        # Read up to `buffer_size` bytes of data from the socket
 
         if length_needed is None:
             # Read the 32bit length of the frame
             bytes = sock.recv(4)
             if bytes == b"":
-                on_close(sock)
                 return False
 
             length_needed = int.from_bytes(bytes, byteorder="big")
@@ -1523,11 +1509,10 @@ def length_prefixed_frame_reader(
             n = sock.recv_into(buffer[pos:])
             if n == 0:
                 # EOF
-                on_close(sock)
                 return False
             pos += n
 
-        if len(buffer) >= length_needed:
+        if pos >= length_needed:
             request = decoder.decode(buffer)
             buffer = None
             pos = 0
@@ -1535,16 +1520,15 @@ def length_prefixed_frame_reader(
             try:
                 gen.send(request)
             except StopIteration:
-                on_close(sock)
                 return False
 
         return True
 
-    return cb
+    return cb, on_close
 
 
 def process_log_messages_from_subprocess(
     loggers: tuple[FilteringBoundLogger, ...],
-) -> Generator[None, bytes, None]:
+) -> Generator[None, bytes | bytearray, None]:
     from structlog.stdlib import NAME_TO_LEVEL
 
     while True:
@@ -1580,10 +1564,9 @@ def process_log_messages_from_subprocess(
 def forward_to_log(
     target_loggers: tuple[FilteringBoundLogger, ...], chan: str, level: int
-) -> Generator[None, bytes, None]:
+) -> Generator[None, bytes | bytearray, None]:
     while True:
-        buf = yield
-        line = bytes(buf)
+        line = yield
 
         # Strip off new line
         line = line.rstrip()
         try:
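
For readers following along, here is a minimal standalone sketch of the length-prefix framing
that `_RequestFrame.as_bytes()` on the sender side and `_read_frame()` on the receiver side
implement above. The names here (`Frame`, `write_frame`, `read_frame`, `_read_exact`) are
simplified stand-ins for illustration only, not the actual `airflow.sdk.execution_time.comms` API:

    from typing import Any, Optional

    import msgspec


    class Frame(msgspec.Struct, array_like=True):
        id: int
        body: Optional[dict[str, Any]] = None


    encoder = msgspec.msgpack.Encoder()
    decoder = msgspec.msgpack.Decoder(Frame)


    def write_frame(sock, frame: Frame) -> None:
        # Encode the frame, then prefix it with the payload length as 4 big-endian bytes.
        payload = encoder.encode(frame)
        if len(payload) > 2**32:
            raise OverflowError("Cannot send messages larger than 4GiB")
        sock.sendall(len(payload).to_bytes(4, byteorder="big") + payload)


    def _read_exact(sock, n: int) -> bytes:
        # Keep calling recv() until exactly n bytes arrive, or the peer closes the socket.
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise EOFError("Socket closed before frame was complete")
            buf.extend(chunk)
        return bytes(buf)


    def read_frame(sock) -> Frame:
        # Read the 4-byte length prefix, then exactly that many payload bytes.
        length = int.from_bytes(_read_exact(sock, 4), byteorder="big")
        return decoder.decode(_read_exact(sock, length))

The frame id is what lets the supervisor tie a response back to the request that produced it,
which is why each response frame above is sent with `in_response_to=request.id`.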
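And a sketch of the selector bookkeeping change described in the commit message: `key.data` now
carries an `(on_read, on_close)` pair instead of just the read callback, and `on_close` runs
exactly once when a reader reports EOF (or a broken pipe). The `register`/`service_once` helpers
below are illustrative only and assume nothing beyond the standard library:

    import selectors
    from socket import socket
    from typing import Callable

    selector = selectors.DefaultSelector()


    def register(sock: socket, on_read: Callable[[socket], bool], on_close: Callable[[socket], None]) -> None:
        # Store both callbacks in the selector key's data slot.
        selector.register(sock, selectors.EVENT_READ, data=(on_read, on_close))


    def service_once(timeout: float | None = None) -> None:
        for key, _ in selector.select(timeout=timeout):
            on_read, on_close = key.data
            sock: socket = key.fileobj  # type: ignore[assignment]
            try:
                need_more = on_read(sock)
            except BrokenPipeError:
                # Treat a broken pipe like EOF: stop servicing this socket.
                need_more = False
            if not need_more:
                # on_close handles unregistering and any per-socket bookkeeping;
                # the caller then closes the socket itself.
                on_close(sock)
                sock.close()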
