This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f585a8016bc Restructure WatchedSubprocess and CommsDecoder for reuse in DagParsing (#44874)
f585a8016bc is described below

commit f585a8016bcf4622eb0e7076fde3c49e9ac50d18
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Dec 12 14:04:04 2024 +0000

    Restructure WatchedSubprocess and CommsDecoder for reuse in DagParsing (#44874)
    
    The changes introduced here let these existing classes serve "double" duty:
    at execution time in the Task SDK, and at parse time in the DAG Processor
    (the actual switch to using them at parse time will be a separate, bigger PR).
    
    There are a few warts here, namely:
    
    - The default for `CommsDecoder`'s `decoder` argument is incorrect for
      subclasses; we might fix that later by making the default more dynamic.
    - The location of this code is not right for reuse at both Task SDK execution
      time and parse time. A bigger piece of work is planned to move all of this
      around before the release of Airflow 3.
    - Some of the functions on the base WatchedSubprocess class are TI-specific
      and should perhaps be moved to a separate subclass.
---
 .../src/airflow/sdk/execution_time/supervisor.py   | 134 +++++++++++----------
 .../src/airflow/sdk/execution_time/task_runner.py  |  21 ++--
 task_sdk/tests/execution_time/test_supervisor.py   |  26 ++--
 3 files changed, 96 insertions(+), 85 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 9b4933de808..cb5554681b3 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,7 +27,6 @@ import selectors
 import signal
 import sys
 import time
-import weakref
 from collections.abc import Generator
 from contextlib import suppress
 from datetime import datetime, timezone
@@ -64,6 +63,8 @@ from airflow.sdk.execution_time.comms import (
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
+    from airflow.typing_compat import Self
+
 
 __all__ = ["WatchedSubprocess", "supervise"]
 
@@ -263,7 +264,7 @@ def _fork_main(
 
 @attrs.define()
 class WatchedSubprocess:
-    ti_id: UUID
+    id: UUID
     pid: int
 
     stdin: BinaryIO
@@ -292,20 +293,16 @@ class WatchedSubprocess:
 
     selector: selectors.BaseSelector = 
attrs.field(factory=selectors.DefaultSelector)
 
-    procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = 
weakref.WeakValueDictionary()
-
-    def __attrs_post_init__(self):
-        self.procs[self.pid] = self
-
     @classmethod
     def start(
         cls,
         path: str | os.PathLike[str],
-        ti: TaskInstance,
+        what: TaskInstance,
         client: Client,
         target: Callable[[], None] = _subprocess_main,
         logger: FilteringBoundLogger | None = None,
-    ) -> WatchedSubprocess:
+        **constructor_kwargs,
+    ) -> Self:
         """Fork and start a new subprocess to execute the given task."""
         # Create socketpairs/"pipes" to connect to the stdin and out from the 
subprocess
         child_stdin, feed_stdin = mkpipe(remote_read=True)
@@ -324,31 +321,27 @@ class WatchedSubprocess:
             # around in the forked processes, especially things that might 
involve open files or sockets!
             del path
             del client
-            del ti
+            del what
             del logger
 
             # Run the child entrypoint
             _fork_main(child_stdin, child_stdout, child_stderr, 
child_logs.fileno(), target)
 
+        requests_fd = child_comms.fileno()
+
+        # Close the remaining parent-end of the sockets we've passed to the 
child via fork. We still have the
+        # other end of the pair open
+        cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, 
child_comms, child_logs)
+
         proc = cls(
-            ti_id=ti.id,
+            id=constructor_kwargs.get("id") or getattr(what, "id"),
             pid=pid,
             stdin=feed_stdin,
             process=psutil.Process(pid),
             client=client,
+            **constructor_kwargs,
         )
 
-        # We've forked, but the task won't start until we send it the 
StartupDetails message. But before we do
-        # that, we need to tell the server it's started (so it has the chance 
to tell us "no, stop!" for any
-        # reason)
-        try:
-            client.task_instances.start(ti.id, pid, 
datetime.now(tz=timezone.utc))
-            proc._last_successful_heartbeat = time.monotonic()
-        except Exception:
-            # On any error kill that subprocess!
-            proc.kill(signal.SIGKILL)
-            raise
-
         logger = logger or cast("FilteringBoundLogger", 
structlog.get_logger(logger_name="task").bind())
         proc._register_pipe_readers(
             logger=logger,
@@ -359,11 +352,8 @@ class WatchedSubprocess:
         )
 
         # Tell the task process what it needs to do!
-        proc._send_startup_message(ti, path, child_comms)
+        proc._on_child_started(what, path, requests_fd)
 
-        # Close the remaining parent-end of the sockets we've passed to the 
child via fork. We still have the
-        # other end of the pair open
-        proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, 
child_comms, child_logs)
         return proc
 
     def _register_pipe_readers(
@@ -401,12 +391,23 @@ class WatchedSubprocess:
         for sock in sockets:
             sock.close()
 
-    def _send_startup_message(self, ti: TaskInstance, path: str | 
os.PathLike[str], child_comms: socket):
+    def _on_child_started(self, ti: TaskInstance, path: str | 
os.PathLike[str], requests_fd: int):
         """Send startup message to the subprocess."""
+        try:
+            # We've forked, but the task won't start doing anything until we 
send it the StartupDetails
+            # message. But before we do that, we need to tell the server it's 
started (so it has the chance to
+            # tell us "no, stop!" for any reason)
+            self.client.task_instances.start(ti.id, self.pid, 
datetime.now(tz=timezone.utc))
+            self._last_successful_heartbeat = time.monotonic()
+        except Exception:
+            # On any error kill that subprocess!
+            self.kill(signal.SIGKILL)
+            raise
+
         msg = StartupDetails.model_construct(
             ti=ti,
-            file=str(path),
-            requests_fd=child_comms.fileno(),
+            file=os.fspath(path),
+            requests_fd=requests_fd,
         )
 
         # Send the message to tell the process what it needs to execute
@@ -490,7 +491,7 @@ class WatchedSubprocess:
         # by the subprocess in the `handle_requests` method.
         if self.final_state in TerminalTIState:
             self.client.task_instances.finish(
-                id=self.ti_id, state=self.final_state, 
when=datetime.now(tz=timezone.utc)
+                id=self.id, state=self.final_state, 
when=datetime.now(tz=timezone.utc)
             )
         return self._exit_code
 
@@ -525,9 +526,9 @@ class WatchedSubprocess:
                 # logs
                 self._send_heartbeat_if_needed()
 
-                self._handle_task_overtime_if_needed()
+                self._handle_process_overtime_if_needed()
 
-    def _handle_task_overtime_if_needed(self):
+    def _handle_process_overtime_if_needed(self):
         """Handle termination of auxiliary processes if the task exceeds the 
configured overtime."""
         # If the task has reached a terminal state, we can start monitoring 
the overtime
         if not self._terminal_state:
@@ -537,7 +538,7 @@ class WatchedSubprocess:
             self._task_end_time_monotonic
             and (time.monotonic() - self._task_end_time_monotonic) > 
self.TASK_OVERTIME_THRESHOLD
         ):
-            log.warning("Task success overtime reached; terminating process", 
ti_id=self.ti_id)
+            log.warning("Workload success overtime reached; terminating 
process", ti_id=self.id)
             self.kill(signal.SIGTERM, force=True)
 
     def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool 
= False):
@@ -579,7 +580,7 @@ class WatchedSubprocess:
         if self._exit_code is None:
             try:
                 self._exit_code = self._process.wait(timeout=0)
-                log.debug("Task process exited", exit_code=self._exit_code)
+                log.debug("Workload process exited", exit_code=self._exit_code)
             except psutil.TimeoutExpired:
                 if raise_on_timeout:
                     raise
@@ -593,7 +594,7 @@ class WatchedSubprocess:
 
         self._last_heartbeat_attempt = time.monotonic()
         try:
-            self.client.task_instances.heartbeat(self.ti_id, 
pid=self._process.pid)
+            self.client.task_instances.heartbeat(self.id, 
pid=self._process.pid)
             # Update the last heartbeat time on success
             self._last_successful_heartbeat = time.monotonic()
 
@@ -619,7 +620,7 @@ class WatchedSubprocess:
         log.warning(
             "Failed to send heartbeat. Will be retried",
             failed_heartbeats=self.failed_heartbeats,
-            ti_id=self.ti_id,
+            ti_id=self.id,
             max_retries=MAX_FAILED_HEARTBEATS,
             exc_info=True,
         )
@@ -646,7 +647,7 @@ class WatchedSubprocess:
         return TerminalTIState.FAILED
 
     def __rich_repr__(self):
-        yield "ti_id", self.ti_id
+        yield "id", self.id
         yield "pid", self.pid
         # only include this if it's not the default (third argument)
         yield "exit_code", self._exit_code, None
@@ -654,7 +655,7 @@ class WatchedSubprocess:
     __rich_repr__.angular = True  # type: ignore[attr-defined]
 
     def __repr__(self) -> str:
-        rep = f"<WatchedSubprocess ti_id={self.ti_id} pid={self.pid}"
+        rep = f"<WatchedSubprocess id={self.id} pid={self.pid}"
         if self._exit_code is not None:
             rep += f" exit_code={self._exit_code}"
         return rep + " >"
@@ -672,35 +673,38 @@ class WatchedSubprocess:
                 log.exception("Unable to decode message", line=line)
                 continue
 
+            self._handle_request(msg, log)
+
+    def _handle_request(self, msg, log):
+        resp = None
+        if isinstance(msg, TaskState):
+            self._terminal_state = msg.state
+            self._task_end_time_monotonic = time.monotonic()
+        elif isinstance(msg, GetConnection):
+            conn = self.client.connections.get(msg.conn_id)
+            resp = conn.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, GetVariable):
+            var = self.client.variables.get(msg.key)
+            resp = var.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, GetXCom):
+            xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, 
msg.key, msg.map_index)
+            resp = xcom.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, DeferTask):
+            self._terminal_state = IntermediateTIState.DEFERRED
+            self.client.task_instances.defer(self.id, msg)
             resp = None
-            if isinstance(msg, TaskState):
-                self._terminal_state = msg.state
-                self._task_end_time_monotonic = time.monotonic()
-            elif isinstance(msg, GetConnection):
-                conn = self.client.connections.get(msg.conn_id)
-                resp = conn.model_dump_json(exclude_unset=True).encode()
-            elif isinstance(msg, GetVariable):
-                var = self.client.variables.get(msg.key)
-                resp = var.model_dump_json(exclude_unset=True).encode()
-            elif isinstance(msg, GetXCom):
-                xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, 
msg.task_id, msg.key, msg.map_index)
-                resp = xcom.model_dump_json(exclude_unset=True).encode()
-            elif isinstance(msg, DeferTask):
-                self._terminal_state = IntermediateTIState.DEFERRED
-                self.client.task_instances.defer(self.ti_id, msg)
-                resp = None
-            elif isinstance(msg, SetXCom):
-                self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, 
msg.key, msg.value, msg.map_index)
-                resp = None
-            elif isinstance(msg, PutVariable):
-                self.client.variables.set(msg.key, msg.value, msg.description)
-                resp = None
-            else:
-                log.error("Unhandled request", msg=msg)
-                continue
+        elif isinstance(msg, SetXCom):
+            self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, 
msg.key, msg.value, msg.map_index)
+            resp = None
+        elif isinstance(msg, PutVariable):
+            self.client.variables.set(msg.key, msg.value, msg.description)
+            resp = None
+        else:
+            log.error("Unhandled request", msg=msg)
+            return
 
-            if resp:
-                self.stdin.write(resp + b"\n")
+        if resp:
+            self.stdin.write(resp + b"\n")
 
 
 # Sockets, even the `.makefile()` function don't correctly do line buffering 
on reading. If a chunk is read
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index db32398c7c5..d210e0011fe 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,11 +23,11 @@ import os
 import sys
 from datetime import datetime, timezone
 from io import FileIO
-from typing import TYPE_CHECKING, TextIO
+from typing import TYPE_CHECKING, Generic, TextIO, TypeVar
 
 import attrs
 import structlog
-from pydantic import ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict, TypeAdapter
 
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -77,17 +77,24 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
     return 
RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), 
task=task)
 
 
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
 @attrs.define()
-class CommsDecoder:
+class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     """Handle communication between the task in this process and the 
supervisor parent process."""
 
     input: TextIO
 
     request_socket: FileIO = attrs.field(init=False, default=None)
 
-    decoder: TypeAdapter[ToTask] = attrs.field(init=False, factory=lambda: 
TypeAdapter(ToTask))
+    # We could be "clever" here and derive the default from the type parameters via a custom
+    # `__class_getitem__`, but that's a lot of code for the one subclass we currently have. So we
+    # just use a "sort of wrong" default.
+    decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: 
TypeAdapter(ToTask), repr=False)
 
-    def get_message(self) -> ToTask:
+    def get_message(self) -> ReceiveMsgType:
         """
         Get a message from the parent.
 
@@ -106,7 +113,7 @@ class CommsDecoder:
                 self.request_socket = os.fdopen(msg.requests_fd, "wb", 
buffering=0)
         return msg
 
-    def send_request(self, log: Logger, msg: ToSupervisor):
+    def send_request(self, log: Logger, msg: SendMsgType):
         encoded_msg = msg.model_dump_json().encode() + b"\n"
 
         log.debug("Sending request", json=encoded_msg)
@@ -123,7 +130,7 @@ class CommsDecoder:
 #   deeply nested execution stack.
 # - By defining `SUPERVISOR_COMMS` as a global, it ensures that this 
communication mechanism is readily
 #   accessible wherever needed during task execution without modifying every 
layer of the call stack.
-SUPERVISOR_COMMS: CommsDecoder
+SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
 
 # State machine!
 # 1. Start up (receive details from supervisor)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index e44b4942e13..406b2ee2699 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -98,7 +98,7 @@ class TestWatchedSubprocess:
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id="4d828a62-a417-4936-a7a6-2b3fabacecab",
                 task_id="b",
                 dag_id="c",
@@ -165,7 +165,7 @@ class TestWatchedSubprocess:
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id="4d828a62-a417-4936-a7a6-2b3fabacecab",
                 task_id="b",
                 dag_id="c",
@@ -188,7 +188,7 @@ class TestWatchedSubprocess:
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id=uuid7(),
                 task_id="b",
                 dag_id="c",
@@ -224,7 +224,7 @@ class TestWatchedSubprocess:
         spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat)
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id=ti_id,
                 task_id="b",
                 dag_id="c",
@@ -335,7 +335,7 @@ class TestWatchedSubprocess:
         client = make_client(transport=httpx.MockTransport(handle_request))
 
         with pytest.raises(ServerResponseError, match="Server returned error") 
as err:
-            WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)
+            WatchedSubprocess.start(path=os.devnull, what=ti, client=client)
 
         assert err.value.response.status_code == 409
         assert err.value.detail == {
@@ -388,7 +388,7 @@ class TestWatchedSubprocess:
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
+            what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
             client=make_client(transport=httpx.MockTransport(handle_request)),
             target=subprocess_main,
         )
@@ -440,7 +440,7 @@ class TestWatchedSubprocess:
         mock_kill = 
mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
 
         proc = WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=mock_process.pid,
             stdin=mocker.MagicMock(),
             client=client,
@@ -528,7 +528,7 @@ class TestWatchedSubprocess:
         monkeypatch.setattr(WatchedSubprocess, "TASK_OVERTIME_THRESHOLD", 
overtime_threshold)
 
         mock_watched_subprocess = WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=12345,
             stdin=mocker.Mock(),
             process=mocker.Mock(),
@@ -541,13 +541,13 @@ class TestWatchedSubprocess:
 
         # Call `wait` to trigger the overtime handling
         # This will call the `kill` method if the task has been running for 
too long
-        mock_watched_subprocess._handle_task_overtime_if_needed()
+        mock_watched_subprocess._handle_process_overtime_if_needed()
 
         # Validate process kill behavior and log messages
         if expected_kill:
             mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
             mock_logger.warning.assert_called_once_with(
-                "Task success overtime reached; terminating process",
+                "Workload success overtime reached; terminating process",
                 ti_id=TI_ID,
             )
         else:
@@ -565,7 +565,7 @@ class TestWatchedSubprocessKill:
     @pytest.fixture
     def watched_subprocess(self, mocker, mock_process):
         proc = WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=12345,
             stdin=mocker.Mock(),
             client=mocker.Mock(),
@@ -656,7 +656,7 @@ class TestWatchedSubprocessKill:
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
+            what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
             client=MagicMock(spec=sdk_client.Client),
             target=subprocess_main,
         )
@@ -746,7 +746,7 @@ class TestHandleRequest:
     def watched_subprocess(self, mocker):
         """Fixture to provide a WatchedSubprocess instance."""
         return WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=12345,
             stdin=BytesIO(),
             client=mocker.Mock(),

Reply via email to