This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5cc2c8c1122 Support for "reconnecting" Supervisor Comms from task process when `dag.test()` is used (#58147)
5cc2c8c1122 is described below

commit 5cc2c8c1122ff280338f31a0a0b593868517100c
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Nov 13 10:26:37 2025 +0000

    Support for "reconnecting" Supervisor Comms from task process when `dag.test()` is used (#58147)
    
    This is a follow-up to #57212, which worked fine "at run time" but did not
    work in many of our own unit tests, which rely on `dag.test` or `ti.run`.
    
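    The way this is implemented is that when we use the InProcessTestSupervisor we
    pre-emptively create a socket pair. We have to create it even if it's not being
    used, as we can't know in advance whether the task will need to reconnect (for
    example, when it runs a VirtualEnv operator or uses run_as_user).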
    
    And since this is all in one process, we create a thread to handle the socket
    comms. Since this is only ever used in tests, performance (or contention on the
    GIL) doesn't matter.
---
 .../src/airflow/sdk/execution_time/supervisor.py   | 58 +++++++++++++++++++---
 .../src/airflow/sdk/execution_time/task_runner.py  |  6 ++-
 2 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 7211423807e..748cd181878 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,6 +27,7 @@ import os
 import selectors
 import signal
 import sys
+import threading
 import time
 import weakref
 from collections import deque
@@ -1495,6 +1496,44 @@ class InProcessTestSupervisor(ActivitySubprocess):
             # Bypass the tenacity retries!
             return super().request.__wrapped__(self, *args, **kwargs)  # type: 
ignore[attr-defined]
 
+    def _check_subprocess_exit(
+        self, raise_on_timeout: bool = False, expect_signal: None | int = None
+    ) -> int | None:
+        # InProcessSupervisor has no subprocess, so we don't need to poll 
anything. This is called from
+        # _handle_socket_comms, so we need to override it
+        return None
+
+    def _handle_socket_comms(self):
+        while self._open_sockets:
+            self._service_subprocess(1.0)
+
+    @contextlib.contextmanager
+    def _setup_subprocess_socket(self):
+        thread = threading.Thread(target=self._handle_socket_comms, 
daemon=True)
+
+        requests, child_sock = socketpair()
+
+        self._open_sockets[requests] = "requests"
+        self.stdin = requests
+
+        self.selector.register(
+            requests,
+            selectors.EVENT_READ,
+            length_prefixed_frame_reader(self.handle_requests(log), 
on_close=self._on_socket_closed),
+        )
+        os.set_inheritable(child_sock.fileno(), True)
+        os.environ["__AIRFLOW_SUPERVISOR_FD"] = str(child_sock.fileno())
+
+        try:
+            thread.start()
+            yield child_sock
+        finally:
+            requests.close()
+            child_sock.close()
+            self._on_socket_closed(requests)
+            thread.join(0)
+            os.environ.pop("__AIRFLOW_SUPERVISOR_FD", None)
+
     @classmethod
     def start(  # type: ignore[override]
         cls,
@@ -1547,16 +1586,19 @@ class InProcessTestSupervisor(ActivitySubprocess):
                 start_date=start_date,
                 state=TaskInstanceState.RUNNING,
             )
-            context = ti.get_template_context()
-            log = structlog.get_logger(logger_name="task")
 
-            state, msg, error = run(ti, context, log)
-            finalize(ti, state, context, log, error)
+            # Create a socketpair preemptively, in case the task process runs 
VirtualEnv operator or run_as_user
+            with supervisor._setup_subprocess_socket():
+                context = ti.get_template_context()
+                log = structlog.get_logger(logger_name="task")
+
+                state, msg, error = run(ti, context, log)
+                finalize(ti, state, context, log, error)
 
-            # In the normal subprocess model, the task runner calls this 
before exiting.
-            # Since we're running in-process, we manually notify the API 
server that
-            # the task has finished—unless the terminal state was already sent 
explicitly.
-            supervisor.update_task_state_if_needed()
+                # In the normal subprocess model, the task runner calls this 
before exiting.
+                # Since we're running in-process, we manually notify the API 
server that
+                # the task has finished—unless the terminal state was already 
sent explicitly.
+                supervisor.update_task_state_if_needed()
 
         return TaskRunResult(ti=ti, state=state, msg=msg, error=error)
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 8ca403f3c96..26c1e5ee762 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1491,11 +1491,15 @@ def reinit_supervisor_comms() -> None:
     run_as_user, or from inside the python code in a virtualenv (et al.) 
operator to re-connect so those tasks
     can continue to access variables etc.
     """
+    import socket
+
     if "SUPERVISOR_COMMS" not in globals():
         global SUPERVISOR_COMMS
         log = structlog.get_logger(logger_name="task")
 
-        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)
+        fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))
+
+        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, 
socket=socket.socket(fileno=fd))
 
     logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
     if isinstance(logs, SentFDs):

Reply via email to