This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5cc2c8c1122 Support for "reconnecting" Supervisor Comms from task
process when `dag.test()` is used (#58147)
5cc2c8c1122 is described below
commit 5cc2c8c1122ff280338f31a0a0b593868517100c
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Nov 13 10:26:37 2025 +0000
Support for "reconnecting" Supervisor Comms from task process when
`dag.test()` is used (#58147)
This is a follow-up to #57212, which worked fine "at run time" but did not
work in many of our own unit tests, which rely on `dag.test` or `ti.run`.
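The way this is implemented is that when we use the InProcessTestSupervisor we
pre-emptively create a socket pair. We have to create it even if it's not being
used, as we can't know in advance whether the task will need it (e.g. when it
runs a VirtualEnv operator or uses run_as_user).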
And since this is all in one process, we create a thread to handle the socket
comms. Since this is only ever used for tests, performance or hitting the GIL
doesn't matter.
---
.../src/airflow/sdk/execution_time/supervisor.py | 58 +++++++++++++++++++---
.../src/airflow/sdk/execution_time/task_runner.py | 6 ++-
2 files changed, 55 insertions(+), 9 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 7211423807e..748cd181878 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,6 +27,7 @@ import os
import selectors
import signal
import sys
+import threading
import time
import weakref
from collections import deque
@@ -1495,6 +1496,44 @@ class InProcessTestSupervisor(ActivitySubprocess):
# Bypass the tenacity retries!
return super().request.__wrapped__(self, *args, **kwargs) # type:
ignore[attr-defined]
+ def _check_subprocess_exit(
+ self, raise_on_timeout: bool = False, expect_signal: None | int = None
+ ) -> int | None:
+ # InProcessSupervisor has no subprocess, so we don't need to poll
anything. This is called from
+ # _handle_socket_comms, so we need to override it
+ return None
+
+ def _handle_socket_comms(self):
+ while self._open_sockets:
+ self._service_subprocess(1.0)
+
+ @contextlib.contextmanager
+ def _setup_subprocess_socket(self):
+ thread = threading.Thread(target=self._handle_socket_comms,
daemon=True)
+
+ requests, child_sock = socketpair()
+
+ self._open_sockets[requests] = "requests"
+ self.stdin = requests
+
+ self.selector.register(
+ requests,
+ selectors.EVENT_READ,
+ length_prefixed_frame_reader(self.handle_requests(log),
on_close=self._on_socket_closed),
+ )
+ os.set_inheritable(child_sock.fileno(), True)
+ os.environ["__AIRFLOW_SUPERVISOR_FD"] = str(child_sock.fileno())
+
+ try:
+ thread.start()
+ yield child_sock
+ finally:
+ requests.close()
+ child_sock.close()
+ self._on_socket_closed(requests)
+ thread.join(0)
+ os.environ.pop("__AIRFLOW_SUPERVISOR_FD", None)
+
@classmethod
def start( # type: ignore[override]
cls,
@@ -1547,16 +1586,19 @@ class InProcessTestSupervisor(ActivitySubprocess):
start_date=start_date,
state=TaskInstanceState.RUNNING,
)
- context = ti.get_template_context()
- log = structlog.get_logger(logger_name="task")
- state, msg, error = run(ti, context, log)
- finalize(ti, state, context, log, error)
+ # Create a socketpair preemptively, in case the task process runs
VirtualEnv operator or run_as_user
+ with supervisor._setup_subprocess_socket():
+ context = ti.get_template_context()
+ log = structlog.get_logger(logger_name="task")
+
+ state, msg, error = run(ti, context, log)
+ finalize(ti, state, context, log, error)
- # In the normal subprocess model, the task runner calls this
before exiting.
- # Since we're running in-process, we manually notify the API
server that
- # the task has finished—unless the terminal state was already sent
explicitly.
- supervisor.update_task_state_if_needed()
+ # In the normal subprocess model, the task runner calls this
before exiting.
+ # Since we're running in-process, we manually notify the API
server that
+ # the task has finished—unless the terminal state was already
sent explicitly.
+ supervisor.update_task_state_if_needed()
return TaskRunResult(ti=ti, state=state, msg=msg, error=error)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 8ca403f3c96..26c1e5ee762 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1491,11 +1491,15 @@ def reinit_supervisor_comms() -> None:
run_as_user, or from inside the python code in a virtualenv (et al.)
operator to re-connect so those tasks
can continue to access variables etc.
"""
+ import socket
+
if "SUPERVISOR_COMMS" not in globals():
global SUPERVISOR_COMMS
log = structlog.get_logger(logger_name="task")
- SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)
+ fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))
+
+ SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log,
socket=socket.socket(fileno=fd))
logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
if isinstance(logs, SentFDs):