This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f585a8016bc Restructure WatchedSubprocess and CommsDecoder for reuse
in DagParsing (#44874)
f585a8016bc is described below
commit f585a8016bcf4622eb0e7076fde3c49e9ac50d18
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Dec 12 14:04:04 2024 +0000
Restructure WatchedSubprocess and CommsDecoder for reuse in DagParsing
(#44874)
The changes introduced here let these existing classes serve "double" duty
in the execution time of TaskSDK and also the Parse Time in the DAG Processor
(but the actual switch to use these will be a separate, bigger PR).
There are a few warts here, namely:
- The default on `CommsDecoder`'s decoder argument is incorrect for
subclasses; we might fix that later to be more dynamic about the default.
- The location of this code is not right for reuse in TaskSDK/execution time
and parse time. There is a bigger bit of work being planned to move this
all around before the release of Airflow 3.
- Some of the functions on the base WatchedSubprocess class are TI specific
and maybe should be on a separate subclass
---
.../src/airflow/sdk/execution_time/supervisor.py | 134 +++++++++++----------
.../src/airflow/sdk/execution_time/task_runner.py | 21 ++--
task_sdk/tests/execution_time/test_supervisor.py | 26 ++--
3 files changed, 96 insertions(+), 85 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 9b4933de808..cb5554681b3 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,7 +27,6 @@ import selectors
import signal
import sys
import time
-import weakref
from collections.abc import Generator
from contextlib import suppress
from datetime import datetime, timezone
@@ -64,6 +63,8 @@ from airflow.sdk.execution_time.comms import (
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger, WrappedLogger
+ from airflow.typing_compat import Self
+
__all__ = ["WatchedSubprocess", "supervise"]
@@ -263,7 +264,7 @@ def _fork_main(
@attrs.define()
class WatchedSubprocess:
- ti_id: UUID
+ id: UUID
pid: int
stdin: BinaryIO
@@ -292,20 +293,16 @@ class WatchedSubprocess:
selector: selectors.BaseSelector =
attrs.field(factory=selectors.DefaultSelector)
- procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] =
weakref.WeakValueDictionary()
-
- def __attrs_post_init__(self):
- self.procs[self.pid] = self
-
@classmethod
def start(
cls,
path: str | os.PathLike[str],
- ti: TaskInstance,
+ what: TaskInstance,
client: Client,
target: Callable[[], None] = _subprocess_main,
logger: FilteringBoundLogger | None = None,
- ) -> WatchedSubprocess:
+ **constructor_kwargs,
+ ) -> Self:
"""Fork and start a new subprocess to execute the given task."""
# Create socketpairs/"pipes" to connect to the stdin and out from the
subprocess
child_stdin, feed_stdin = mkpipe(remote_read=True)
@@ -324,31 +321,27 @@ class WatchedSubprocess:
# around in the forked processes, especially things that might
involve open files or sockets!
del path
del client
- del ti
+ del what
del logger
# Run the child entrypoint
_fork_main(child_stdin, child_stdout, child_stderr,
child_logs.fileno(), target)
+ requests_fd = child_comms.fileno()
+
+ # Close the remaining parent-end of the sockets we've passed to the
child via fork. We still have the
+ # other end of the pair open
+ cls._close_unused_sockets(child_stdin, child_stdout, child_stderr,
child_comms, child_logs)
+
proc = cls(
- ti_id=ti.id,
+ id=constructor_kwargs.get("id") or getattr(what, "id"),
pid=pid,
stdin=feed_stdin,
process=psutil.Process(pid),
client=client,
+ **constructor_kwargs,
)
- # We've forked, but the task won't start until we send it the
StartupDetails message. But before we do
- # that, we need to tell the server it's started (so it has the chance
to tell us "no, stop!" for any
- # reason)
- try:
- client.task_instances.start(ti.id, pid,
datetime.now(tz=timezone.utc))
- proc._last_successful_heartbeat = time.monotonic()
- except Exception:
- # On any error kill that subprocess!
- proc.kill(signal.SIGKILL)
- raise
-
logger = logger or cast("FilteringBoundLogger",
structlog.get_logger(logger_name="task").bind())
proc._register_pipe_readers(
logger=logger,
@@ -359,11 +352,8 @@ class WatchedSubprocess:
)
# Tell the task process what it needs to do!
- proc._send_startup_message(ti, path, child_comms)
+ proc._on_child_started(what, path, requests_fd)
- # Close the remaining parent-end of the sockets we've passed to the
child via fork. We still have the
- # other end of the pair open
- proc._close_unused_sockets(child_stdin, child_stdout, child_stderr,
child_comms, child_logs)
return proc
def _register_pipe_readers(
@@ -401,12 +391,23 @@ class WatchedSubprocess:
for sock in sockets:
sock.close()
- def _send_startup_message(self, ti: TaskInstance, path: str |
os.PathLike[str], child_comms: socket):
+ def _on_child_started(self, ti: TaskInstance, path: str |
os.PathLike[str], requests_fd: int):
"""Send startup message to the subprocess."""
+ try:
+ # We've forked, but the task won't start doing anything until we
send it the StartupDetails
+ # message. But before we do that, we need to tell the server it's
started (so it has the chance to
+ # tell us "no, stop!" for any reason)
+ self.client.task_instances.start(ti.id, self.pid,
datetime.now(tz=timezone.utc))
+ self._last_successful_heartbeat = time.monotonic()
+ except Exception:
+ # On any error kill that subprocess!
+ self.kill(signal.SIGKILL)
+ raise
+
msg = StartupDetails.model_construct(
ti=ti,
- file=str(path),
- requests_fd=child_comms.fileno(),
+ file=os.fspath(path),
+ requests_fd=requests_fd,
)
# Send the message to tell the process what it needs to execute
@@ -490,7 +491,7 @@ class WatchedSubprocess:
# by the subprocess in the `handle_requests` method.
if self.final_state in TerminalTIState:
self.client.task_instances.finish(
- id=self.ti_id, state=self.final_state,
when=datetime.now(tz=timezone.utc)
+ id=self.id, state=self.final_state,
when=datetime.now(tz=timezone.utc)
)
return self._exit_code
@@ -525,9 +526,9 @@ class WatchedSubprocess:
# logs
self._send_heartbeat_if_needed()
- self._handle_task_overtime_if_needed()
+ self._handle_process_overtime_if_needed()
- def _handle_task_overtime_if_needed(self):
+ def _handle_process_overtime_if_needed(self):
"""Handle termination of auxiliary processes if the task exceeds the
configured overtime."""
# If the task has reached a terminal state, we can start monitoring
the overtime
if not self._terminal_state:
@@ -537,7 +538,7 @@ class WatchedSubprocess:
self._task_end_time_monotonic
and (time.monotonic() - self._task_end_time_monotonic) >
self.TASK_OVERTIME_THRESHOLD
):
- log.warning("Task success overtime reached; terminating process",
ti_id=self.ti_id)
+ log.warning("Workload success overtime reached; terminating
process", ti_id=self.id)
self.kill(signal.SIGTERM, force=True)
def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool
= False):
@@ -579,7 +580,7 @@ class WatchedSubprocess:
if self._exit_code is None:
try:
self._exit_code = self._process.wait(timeout=0)
- log.debug("Task process exited", exit_code=self._exit_code)
+ log.debug("Workload process exited", exit_code=self._exit_code)
except psutil.TimeoutExpired:
if raise_on_timeout:
raise
@@ -593,7 +594,7 @@ class WatchedSubprocess:
self._last_heartbeat_attempt = time.monotonic()
try:
- self.client.task_instances.heartbeat(self.ti_id,
pid=self._process.pid)
+ self.client.task_instances.heartbeat(self.id,
pid=self._process.pid)
# Update the last heartbeat time on success
self._last_successful_heartbeat = time.monotonic()
@@ -619,7 +620,7 @@ class WatchedSubprocess:
log.warning(
"Failed to send heartbeat. Will be retried",
failed_heartbeats=self.failed_heartbeats,
- ti_id=self.ti_id,
+ ti_id=self.id,
max_retries=MAX_FAILED_HEARTBEATS,
exc_info=True,
)
@@ -646,7 +647,7 @@ class WatchedSubprocess:
return TerminalTIState.FAILED
def __rich_repr__(self):
- yield "ti_id", self.ti_id
+ yield "id", self.id
yield "pid", self.pid
# only include this if it's not the default (third argument)
yield "exit_code", self._exit_code, None
@@ -654,7 +655,7 @@ class WatchedSubprocess:
__rich_repr__.angular = True # type: ignore[attr-defined]
def __repr__(self) -> str:
- rep = f"<WatchedSubprocess ti_id={self.ti_id} pid={self.pid}"
+ rep = f"<WatchedSubprocess id={self.id} pid={self.pid}"
if self._exit_code is not None:
rep += f" exit_code={self._exit_code}"
return rep + " >"
@@ -672,35 +673,38 @@ class WatchedSubprocess:
log.exception("Unable to decode message", line=line)
continue
+ self._handle_request(msg, log)
+
+ def _handle_request(self, msg, log):
+ resp = None
+ if isinstance(msg, TaskState):
+ self._terminal_state = msg.state
+ self._task_end_time_monotonic = time.monotonic()
+ elif isinstance(msg, GetConnection):
+ conn = self.client.connections.get(msg.conn_id)
+ resp = conn.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, GetVariable):
+ var = self.client.variables.get(msg.key)
+ resp = var.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, GetXCom):
+ xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.map_index)
+ resp = xcom.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, DeferTask):
+ self._terminal_state = IntermediateTIState.DEFERRED
+ self.client.task_instances.defer(self.id, msg)
resp = None
- if isinstance(msg, TaskState):
- self._terminal_state = msg.state
- self._task_end_time_monotonic = time.monotonic()
- elif isinstance(msg, GetConnection):
- conn = self.client.connections.get(msg.conn_id)
- resp = conn.model_dump_json(exclude_unset=True).encode()
- elif isinstance(msg, GetVariable):
- var = self.client.variables.get(msg.key)
- resp = var.model_dump_json(exclude_unset=True).encode()
- elif isinstance(msg, GetXCom):
- xcom = self.client.xcoms.get(msg.dag_id, msg.run_id,
msg.task_id, msg.key, msg.map_index)
- resp = xcom.model_dump_json(exclude_unset=True).encode()
- elif isinstance(msg, DeferTask):
- self._terminal_state = IntermediateTIState.DEFERRED
- self.client.task_instances.defer(self.ti_id, msg)
- resp = None
- elif isinstance(msg, SetXCom):
- self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.value, msg.map_index)
- resp = None
- elif isinstance(msg, PutVariable):
- self.client.variables.set(msg.key, msg.value, msg.description)
- resp = None
- else:
- log.error("Unhandled request", msg=msg)
- continue
+ elif isinstance(msg, SetXCom):
+ self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.value, msg.map_index)
+ resp = None
+ elif isinstance(msg, PutVariable):
+ self.client.variables.set(msg.key, msg.value, msg.description)
+ resp = None
+ else:
+ log.error("Unhandled request", msg=msg)
+ return
- if resp:
- self.stdin.write(resp + b"\n")
+ if resp:
+ self.stdin.write(resp + b"\n")
# Sockets, even the `.makefile()` function don't correctly do line buffering
on reading. If a chunk is read
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index db32398c7c5..d210e0011fe 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,11 +23,11 @@ import os
import sys
from datetime import datetime, timezone
from io import FileIO
-from typing import TYPE_CHECKING, TextIO
+from typing import TYPE_CHECKING, Generic, TextIO, TypeVar
import attrs
import structlog
-from pydantic import ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict, TypeAdapter
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -77,17 +77,24 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
return
RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True),
task=task)
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
@attrs.define()
-class CommsDecoder:
+class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
"""Handle communication between the task in this process and the
supervisor parent process."""
input: TextIO
request_socket: FileIO = attrs.field(init=False, default=None)
- decoder: TypeAdapter[ToTask] = attrs.field(init=False, factory=lambda:
TypeAdapter(ToTask))
+    # We could be "clever" here and set the default to this based on type
parameters and a custom
+    # `__class_getitem__`, but that's a lot of code for the one subclass we've
got currently. So we'll just use a
+ # "sort of wrong default"
+ decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda:
TypeAdapter(ToTask), repr=False)
- def get_message(self) -> ToTask:
+ def get_message(self) -> ReceiveMsgType:
"""
Get a message from the parent.
@@ -106,7 +113,7 @@ class CommsDecoder:
self.request_socket = os.fdopen(msg.requests_fd, "wb",
buffering=0)
return msg
- def send_request(self, log: Logger, msg: ToSupervisor):
+ def send_request(self, log: Logger, msg: SendMsgType):
encoded_msg = msg.model_dump_json().encode() + b"\n"
log.debug("Sending request", json=encoded_msg)
@@ -123,7 +130,7 @@ class CommsDecoder:
# deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this
communication mechanism is readily
# accessible wherever needed during task execution without modifying every
layer of the call stack.
-SUPERVISOR_COMMS: CommsDecoder
+SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
# State machine!
# 1. Start up (receive details from supervisor)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index e44b4942e13..406b2ee2699 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -98,7 +98,7 @@ class TestWatchedSubprocess:
proc = WatchedSubprocess.start(
path=os.devnull,
- ti=TaskInstance(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
@@ -165,7 +165,7 @@ class TestWatchedSubprocess:
proc = WatchedSubprocess.start(
path=os.devnull,
- ti=TaskInstance(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
@@ -188,7 +188,7 @@ class TestWatchedSubprocess:
proc = WatchedSubprocess.start(
path=os.devnull,
- ti=TaskInstance(
+ what=TaskInstance(
id=uuid7(),
task_id="b",
dag_id="c",
@@ -224,7 +224,7 @@ class TestWatchedSubprocess:
spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat)
proc = WatchedSubprocess.start(
path=os.devnull,
- ti=TaskInstance(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
@@ -335,7 +335,7 @@ class TestWatchedSubprocess:
client = make_client(transport=httpx.MockTransport(handle_request))
with pytest.raises(ServerResponseError, match="Server returned error")
as err:
- WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)
+ WatchedSubprocess.start(path=os.devnull, what=ti, client=client)
assert err.value.response.status_code == 409
assert err.value.detail == {
@@ -388,7 +388,7 @@ class TestWatchedSubprocess:
proc = WatchedSubprocess.start(
path=os.devnull,
- ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1),
+ what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
)
@@ -440,7 +440,7 @@ class TestWatchedSubprocess:
mock_kill =
mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
proc = WatchedSubprocess(
- ti_id=TI_ID,
+ id=TI_ID,
pid=mock_process.pid,
stdin=mocker.MagicMock(),
client=client,
@@ -528,7 +528,7 @@ class TestWatchedSubprocess:
monkeypatch.setattr(WatchedSubprocess, "TASK_OVERTIME_THRESHOLD",
overtime_threshold)
mock_watched_subprocess = WatchedSubprocess(
- ti_id=TI_ID,
+ id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
process=mocker.Mock(),
@@ -541,13 +541,13 @@ class TestWatchedSubprocess:
# Call `wait` to trigger the overtime handling
# This will call the `kill` method if the task has been running for
too long
- mock_watched_subprocess._handle_task_overtime_if_needed()
+ mock_watched_subprocess._handle_process_overtime_if_needed()
# Validate process kill behavior and log messages
if expected_kill:
mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
mock_logger.warning.assert_called_once_with(
- "Task success overtime reached; terminating process",
+ "Workload success overtime reached; terminating process",
ti_id=TI_ID,
)
else:
@@ -565,7 +565,7 @@ class TestWatchedSubprocessKill:
@pytest.fixture
def watched_subprocess(self, mocker, mock_process):
proc = WatchedSubprocess(
- ti_id=TI_ID,
+ id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
client=mocker.Mock(),
@@ -656,7 +656,7 @@ class TestWatchedSubprocessKill:
proc = WatchedSubprocess.start(
path=os.devnull,
- ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1),
+ what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1),
client=MagicMock(spec=sdk_client.Client),
target=subprocess_main,
)
@@ -746,7 +746,7 @@ class TestHandleRequest:
def watched_subprocess(self, mocker):
"""Fixture to provide a WatchedSubprocess instance."""
return WatchedSubprocess(
- ti_id=TI_ID,
+ id=TI_ID,
pid=12345,
stdin=BytesIO(),
client=mocker.Mock(),