kaxil commented on code in PR #43893:
URL: https://github.com/apache/airflow/pull/43893#discussion_r1840537318
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -0,0 +1,589 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervise and run Tasks in a subprocess."""
+
+from __future__ import annotations
+
+import atexit
+import io
+import logging
+import os
+import selectors
+import signal
+import sys
+import time
+import weakref
+from collections.abc import Generator
+from contextlib import suppress
+from datetime import datetime, timezone
+from socket import socket, socketpair
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload
+from uuid import UUID
+
+import attrs
+import httpx
+import msgspec
+import psutil
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import TaskInstanceState
+from airflow.sdk.execution_time.comms import ConnectionResponse, GetConnection, StartupDetails, ToSupervisor
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger
+
+    from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
+    from airflow.sdk.api.datamodels.ti import TaskInstance
+
+
+__all__ = ["WatchedSubprocess", "supervise"]
+
+log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
+
+# TODO: Pull this from config
+SLOWEST_HEARTBEAT_INTERVAL: int = 30
+# Don't heartbeat more often than this
+FASTEST_HEARTBEAT_INTERVAL: int = 5
+
+
+@overload
+def mkpipe() -> tuple[socket, socket]: ...
+
+
+@overload
+def mkpipe(remote_read: Literal[True]) -> tuple[socket, BinaryIO]: ...
+
+
+def mkpipe(
+    remote_read: bool = False,
+) -> tuple[socket, socket | BinaryIO]:
+    """
+    Create a pair of connected sockets.
+
+    The inheritable flag will be set correctly so that the end destined for the subprocess is kept open but
+    the end for this process is closed automatically by the OS.
+    """
+    rsock, wsock = socketpair()
+    local, remote = (wsock, rsock) if remote_read else (rsock, wsock)
+
+    remote.set_inheritable(True)
+    local.setblocking(False)
+
+    io: BinaryIO | socket
+    if remote_read:
+        # If _we_ are writing, we don't want to buffer
+        io = cast(BinaryIO, local.makefile("wb", buffering=0))
+    else:
+        io = local
+
+    return remote, io
+
+
+def _subprocess_main():
+    from airflow.sdk.execution_time.task_runner import main
+
+    main()
+
+
+def _reset_signals():
+    # Uninstall the rich etc. exception handler
+    sys.excepthook = sys.__excepthook__
+    signal.signal(signal.SIGINT, signal.SIG_DFL)
+    signal.signal(signal.SIGUSR2, signal.SIG_DFL)
+
+
+def _configure_logs_over_json_channel(log_fd: int):
+    # A channel that the task can send JSON-formatted logs over.
+    #
+    # JSON logs sent this way will be handled nicely
+    from airflow.sdk.log import configure_logging
+
+    log_io = os.fdopen(log_fd, "wb", buffering=0)
+    configure_logging(enable_pretty_log=False, output=log_io)
+
+
+def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
+    if "PYTEST_CURRENT_TEST" in os.environ:
+        # When we are running in pytest, its output capturing messes us up. This works around it
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+
+    # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
+    # pipes from the supervisor
+
+    for handle_name, sock, mode, close in (
+        ("stdin", child_stdin, "r", True),
+        ("stdout", child_stdout, "w", True),
+        ("stderr", child_stderr, "w", False),
+    ):
+        handle = getattr(sys, handle_name)
+        try:
+            fd = handle.fileno()
+            os.dup2(sock.fileno(), fd)
+            if close:
+                handle.close()
+        except io.UnsupportedOperation:
+            if "PYTEST_CURRENT_TEST" in os.environ:
+                # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need
+                # to handle that differently
+                fd = sock.fileno()
+            else:
+                raise
+
+        setattr(sys, handle_name, os.fdopen(fd, mode))
+
+
+def _fork_main(
+    child_stdin: socket,
+    child_stdout: socket,
+    child_stderr: socket,
+    log_fd: int,
+    target: Callable[[], None],
+) -> NoReturn:
+    """
+    "Entrypoint" of the child process.
+
+    Ultimately this process will be running the user's code in the operator's ``execute()`` function.
+
+    The responsibility of this function is to:
+
+    - Reset any signal handlers we inherited from the parent process (so they don't fire twice - once in
+      parent, and once in child)
+    - Set up the out/err handles to the streams created in the parent (to capture stdout and stderr for
+      logging)
+    - Configure the loggers in the child (both stdlib logging and Structlog) to send JSON logs back to the
+      supervisor for processing/output.
+    - Catch un-handled exceptions and attempt to show _something_ in case of error
+    - Finally, run the actual task runner code (``target`` argument, defaults to ``.task_runner:main``)
+    """
+    # TODO: Make this process a session leader
+
+    # Store original stderr for last-chance exception handling
+    last_chance_stderr = sys.__stderr__ or sys.stderr
+
+    _reset_signals()
+    if log_fd:
+        _configure_logs_over_json_channel(log_fd)
+    _reopen_std_io_handles(child_stdin, child_stdout, child_stderr)
+
+    def exit(n: int) -> NoReturn:
+        with suppress(ValueError, OSError):
+            sys.stdout.flush()
+        with suppress(ValueError, OSError):
+            sys.stderr.flush()
+        with suppress(ValueError, OSError):
+            last_chance_stderr.flush()
+        os._exit(n)
+
+    if hasattr(atexit, "_clear"):
+        # Since we're in a fork we want to try and clear them. If we can't do it cleanly, then we won't try
+        # and run new atexit handlers.
+        with suppress(Exception):
+            atexit._clear()
+        base_exit = exit
+
+        def exit(n: int) -> NoReturn:
+            # This will only run any atexit funcs registered after we've forked.
+            atexit._run_exitfuncs()
+            base_exit(n)
+
+    try:
+        target()
+        exit(0)
+    except SystemExit as e:
+        code = 1
+        if isinstance(e.code, int):
+            code = e.code
+        elif e.code:
+            print(e.code, file=sys.stderr)
+        exit(code)
+    except Exception:
+        # Last ditch log attempt
+        exc, v, tb = sys.exc_info()
+
+        import traceback
+
+        try:
+            last_chance_stderr.write("--- Last chance exception handler ---\n")
+            traceback.print_exception(exc, value=v, tb=tb, file=last_chance_stderr)
+            # Exit codes 126 and 125 don't have any "special" meaning, they are only meant to serve as an
+            # identifier that the task process died in a really odd way.
+            exit(126)
+        except Exception as e:
+            with suppress(Exception):
+                print(
+                    f"--- Last chance exception handler failed --- {repr(str(e))}\n", file=last_chance_stderr
+                )
+            exit(125)
+
+
+@attrs.define()
+class WatchedSubprocess:
+    ti_id: UUID
+    pid: int
+
+    stdin: BinaryIO
+    stdout: socket
+    stderr: socket
+
+    client: Client
+
+    _process: psutil.Process
+    _exit_code: int | None = None
+    _terminal_state: str | None = None
+
+    _last_heartbeat: float = 0
+
+    selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)
+
+    procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
+
+    def __attrs_post_init__(self):
+        self.procs[self.pid] = self
+
+    @classmethod
+    def start(
+        cls,
+        path: str | os.PathLike[str],
+        ti: TaskInstance,
+        client: Client,
+        target: Callable[[], None] = _subprocess_main,
+    ) -> WatchedSubprocess:
+        """Fork and start a new subprocess to execute the given task."""
+        # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess
+        child_stdin, feed_stdin = mkpipe(remote_read=True)
+        child_stdout, read_stdout = mkpipe()
+        child_stderr, read_stderr = mkpipe()
+
+        # Open these socketpairs before forking off the child, so that they are open when we fork.
+        child_comms, read_msgs = mkpipe()
+        child_logs, read_logs = mkpipe()
+
+        pid = os.fork()
+        if pid == 0:
+            # Parent ends of the sockets are closed by the OS as they are set as non-inheritable
+
+            # Run the child entrypoint
+            _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)
+
+        proc = cls(
+            ti_id=ti.id,
+            pid=pid,
+            stdin=feed_stdin,
+            stdout=read_stdout,
+            stderr=read_stderr,
+            process=psutil.Process(pid),
+            client=client,
+        )
+
+        # We've forked, but the task won't start until we send it the StartupDetails message. But before we do
+        # that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any
+        # reason)
+        try:
+            client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
+            proc._last_heartbeat = time.monotonic()
+        except Exception:
+            # On any error kill that subprocess!
+            proc.kill(signal.SIGKILL)
+            raise
+
+        # TODO: Use logging providers to handle the chunked upload for us
+        task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()
+
+        # proc.selector is a way of registering a handler/callback to be called when the given IO channel has
+        # activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
+        # alternatives are used automatically) -- this is a way of having "event-based" code, but without
+        # needing full async, to read and process output from each socket as it is received.
+
+        cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), level=logging.INFO))
+        proc.selector.register(read_stdout, selectors.EVENT_READ, cb)
+
+        cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), level=logging.ERROR))
+        proc.selector.register(read_stderr, selectors.EVENT_READ, cb)
+
+        proc.selector.register(
+            read_logs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)),
+        )
+        proc.selector.register(
+            read_msgs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(proc.handle_requests(log=log)),
+        )
+
+        # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
+        # other end of the pair open
+        child_stdout.close()
+        child_stdin.close()
+        child_comms.close()
+        child_logs.close()
+
+        # Tell the task process what it needs to do!
+        msg = StartupDetails(
+            ti=ti,
+            file=str(path),
+            requests_fd=child_comms.fileno(),
+        )
+
+        # Send the message to tell the process what it needs to execute
+        log.debug("Sending", msg=msg)
+        feed_stdin.write(msgspec.json.encode(msg))
+        feed_stdin.write(b"\n")
+
+        return proc
+
+    def kill(self, signal: signal.Signals = signal.SIGINT):
+        if self._exit_code is not None:
+            return
+
+        with suppress(ProcessLookupError):
+            os.kill(self.pid, signal)
+
+    def wait(self) -> int:
+        if self._exit_code is not None:
+            return self._exit_code
+
+        # Until we have a selector for the process, don't poll for more than 10s, just in case it exits but
+        # doesn't produce any output
+        max_poll_interval = 10
+
+        try:
+            while self._exit_code is None or len(self.selector.get_map()):
+                last_heartbeat_ago = time.monotonic() - self._last_heartbeat
+                # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
+                # so we notice the subprocess finishing as quickly as we can.
+                max_wait_time = max(
+                    0,  # Make sure this value is never negative
+                    min(
+                        # Ensure we heartbeat _at most_ 75% of the way through the zombie threshold time
+                        SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
+                        max_poll_interval,
+                    ),
+                )
+                events = self.selector.select(timeout=max_wait_time)
+                for key, _ in events:
+                    socket_handler = key.data
+                    need_more = socket_handler(key.fileobj)
+
+                    if not need_more:
+                        self.selector.unregister(key.fileobj)
+                        key.fileobj.close()  # type: ignore[union-attr]
+
+                if self._exit_code is None:
+                    try:
+                        self._exit_code = self._process.wait(timeout=0)
+                        log.debug("Task process exited", exit_code=self._exit_code)
+                    except psutil.TimeoutExpired:
+                        pass
+
+                if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL:
+                    # Avoid heartbeating too frequently
+                    continue
+
+                try:
+                    self.client.task_instances.heartbeat(self.ti_id)
+                    self._last_heartbeat = time.monotonic()
+                except Exception:
+                    log.warning("Couldn't heartbeat", exc_info=True)
+                    # TODO: If we couldn't heartbeat for X times the interval, kill ourselves
+                    pass
+        finally:
+            self.selector.close()
+
+        self.client.task_instances.finish(
+            id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
+        )
+        return self._exit_code
+
+    @property
+    def final_state(self):
+        """
+        The final state of the TaskInstance.
+
+        By default this will be derived from the exit code of the task
+        (0=success, failed otherwise) but can be changed by the subprocess
+        sending a TaskState message, as long as the process exits with 0.
+
+        Not valid before the process has finished.
+ """ + if self._exit_code == 0: + return self._terminal_state or TaskInstanceState.SUCCESS + return TaskInstanceState.FAILED + + def __rich_repr__(self): + yield "pid", self.pid + yield "exit_code", self._exit_code, None + + __rich_repr__.angular = True # type: ignore[attr-defined] + + def __repr__(self) -> str: + rep = f"<WatchedSubprocess pid={self.pid}" + if self._exit_code is not None: + rep += f" exit_code={self._exit_code}" + return rep + " >" + + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + decoder: msgspec.json.Decoder[ToSupervisor] = msgspec.json.Decoder(type=ToSupervisor) + encoder = msgspec.json.Encoder() + # Use a buffer to avoid small allocations + buffer = bytearray(64) + while True: + line = yield + + try: + msg = decoder.decode(line) + except Exception: + log.exception("Unable to decode message", line=line) + continue + + # if isinstnace(msg, TaskState): + # self._terminal_state = msg.state + # elif isinstance(msg, ReadXCom): + # resp = XComResponse(key="secret", value=True) + # encoder.encode_into(resp, buffer) + # self.stdin.write(buffer + b"\n") + if isinstance(msg, GetConnection): + conn = self.client.connections.get(msg.id) + resp = ConnectionResponse(conn=conn) + encoder.encode_into(resp, buffer) + else: + log.error("Unhandled request", msg=msg) + continue + + buffer.extend(b"\n") + self.stdin.write(buffer) + + # Ensure the buffer doesn't grow and stay large if a large payload is used. This won't grow it + # larger than it is, but it will shrink it + if len(buffer) > 1024: + buffer = buffer[:1024] + + +# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read +# and it doesn't contain a new line character, `.readline()` will just return the chunk as is. +# +# This returns a cb suitable for attaching to a `selector` that reads in to a buffer, and yields lines to a Review Comment: Yes, based on the code -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
