This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 16206ef559e Correctly ensure that we give subprocesses time to exit
after signalling it (#44766)
16206ef559e is described below
commit 16206ef559efd72c8be56fcbb0f1b1f0d40159c1
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Sat Dec 7 16:41:10 2024 +0000
Correctly ensure that we give subprocesses time to exit after signalling it
(#44766)
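We had a bug that was hidden in our tests by our use of mocks: if the
subprocess produced any output, `self.selector.select()` would return straight
away instead of waiting for the maximum timeout, so the "escalation" signal was
sent as soon as the first piece of output was read rather than after the given
escalation interval. The fix keeps servicing the subprocess in a loop until it
either exits or the full escalation delay has elapsed.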
---
.../src/airflow/sdk/execution_time/supervisor.py | 28 ++-
task_sdk/tests/execution_time/test_supervisor.py | 206 +++++++++------------
2 files changed, 110 insertions(+), 124 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 3bc714e6415..1b2a7ba7577 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -445,12 +445,21 @@ class WatchedSubprocess:
try:
self._process.send_signal(sig)
- # Service subprocess events during the escalation delay
- self._service_subprocess(max_wait_time=escalation_delay,
raise_on_timeout=True)
- if self._exit_code is not None:
- log.info("Process exited", pid=self.pid,
exit_code=self._exit_code, signal=sig.name)
- return
- except psutil.TimeoutExpired:
+ start = time.monotonic()
+ end = start + escalation_delay
+ now = start
+
+ while now < end:
+ # Service subprocess events during the escalation delay.
This will return as soon as it's
+ # read from any of the sockets, so we need to re-run it if
the process is still alive
+ if (
+ exit_code :=
self._service_subprocess(max_wait_time=end - now, raise_on_timeout=False)
+ ) is not None:
+ log.info("Process exited", pid=self.pid,
exit_code=exit_code, signal=sig.name)
+ return
+
+ now = time.monotonic()
+
msg = "Process did not terminate in time"
if sig != escalation_path[-1]:
msg += "; escalating"
@@ -539,6 +548,7 @@ class WatchedSubprocess:
:param max_wait_time: Maximum time to block while waiting for events,
in seconds.
:param raise_on_timeout: If True, raise an exception if the subprocess
does not exit within the timeout.
+ :returns: The process exit code, or None if it's still alive
"""
events = self.selector.select(timeout=max_wait_time)
for key, _ in events:
@@ -559,9 +569,9 @@ class WatchedSubprocess:
key.fileobj.close() # type: ignore[union-attr]
# Check if the subprocess has exited
- self._check_subprocess_exit(raise_on_timeout=raise_on_timeout)
+ return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout)
- def _check_subprocess_exit(self, raise_on_timeout: bool = False):
+ def _check_subprocess_exit(self, raise_on_timeout: bool = False) -> int |
None:
"""Check if the subprocess has exited."""
if self._exit_code is None:
try:
@@ -570,7 +580,7 @@ class WatchedSubprocess:
except psutil.TimeoutExpired:
if raise_on_timeout:
raise
- pass
+ return self._exit_code
def _send_heartbeat_if_needed(self):
"""Send a heartbeat to the client if heartbeat interval has passed."""
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 96506a48bda..cb43904f172 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -573,152 +573,128 @@ class TestWatchedSubprocessKill:
proc.selector = mock_selector
return proc
+ def test_kill_process_already_exited(self, watched_subprocess,
mock_process):
+ """Test behavior when the process has already exited."""
+ mock_process.wait.side_effect = psutil.NoSuchProcess(pid=1234)
+
+ watched_subprocess.kill(signal.SIGINT, force=True)
+
+ mock_process.send_signal.assert_called_once_with(signal.SIGINT)
+ mock_process.wait.assert_called_once()
+ assert watched_subprocess._exit_code == -1
+
+ def test_kill_process_custom_signal(self, watched_subprocess,
mock_process):
+ """Test that the process is killed with the correct signal."""
+ mock_process.wait.return_value = 0
+
+ signal_to_send = signal.SIGUSR1
+ watched_subprocess.kill(signal_to_send, force=False)
+
+ mock_process.send_signal.assert_called_once_with(signal_to_send)
+ mock_process.wait.assert_called_once_with(timeout=0)
+
@pytest.mark.parametrize(
- ["signal_to_send", "wait_side_effect", "expected_signals"],
+ ["signal_to_send", "exit_after"],
[
pytest.param(
signal.SIGINT,
- [0],
- [signal.SIGINT],
+ signal.SIGINT,
id="SIGINT-success-without-escalation",
),
pytest.param(
signal.SIGINT,
- [psutil.TimeoutExpired(0.1), 0],
- [signal.SIGINT, signal.SIGTERM],
+ signal.SIGTERM,
id="SIGINT-escalates-to-SIGTERM",
),
pytest.param(
signal.SIGINT,
- [
- psutil.TimeoutExpired(0.1), # SIGINT times out
- psutil.TimeoutExpired(0.1), # SIGTERM times out
- 0, # SIGKILL succeeds
- ],
- [signal.SIGINT, signal.SIGTERM, signal.SIGKILL],
+ None,
id="SIGINT-escalates-to-SIGTERM-then-SIGKILL",
),
pytest.param(
signal.SIGTERM,
- [
- psutil.TimeoutExpired(0.1), # SIGTERM times out
- 0, # SIGKILL succeeds
- ],
- [signal.SIGTERM, signal.SIGKILL],
+ None,
id="SIGTERM-escalates-to-SIGKILL",
),
pytest.param(
signal.SIGKILL,
- [0],
- [signal.SIGKILL],
+ None,
id="SIGKILL-success-without-escalation",
),
],
)
- def test_force_kill_escalation(
- self,
- watched_subprocess,
- mock_process,
- mocker,
- signal_to_send,
- wait_side_effect,
- expected_signals,
- captured_logs,
- ):
- """Test escalation path for SIGINT, SIGTERM, and SIGKILL when
force=True."""
- # Mock the process wait method to return the exit code or raise an
exception
- mock_process.wait.side_effect = wait_side_effect
-
- watched_subprocess.kill(signal_to_send=signal_to_send,
escalation_delay=0.1, force=True)
-
- # Check that the correct signals were sent
- mock_process.send_signal.assert_has_calls([mocker.call(sig) for sig in
expected_signals])
-
- # Check that the process was waited on for each signal
- mock_process.wait.assert_has_calls([mocker.call(timeout=0)] *
len(expected_signals))
-
- ## Validate log messages
- # If escalation occurred, we should see a warning log for each signal
sent
- if len(expected_signals) > 1:
- assert {
- "event": "Process did not terminate in time; escalating",
- "level": "warning",
- "logger": "supervisor",
- "pid": 12345,
- "signal": expected_signals[-2].name,
- "timestamp": mocker.ANY,
- } in captured_logs
+ def test_kill_escalation_path(self, signal_to_send, exit_after, mocker,
captured_logs, monkeypatch):
+ def subprocess_main():
+ import signal
+
+ def _handler(sig, frame):
+ print(f"Signal {sig} received", file=sys.stderr)
+ if exit_after == sig:
+ sleep(0.1)
+ exit(sig)
+ sleep(5)
+ print("Should not get here")
+
+ signal.signal(signal.SIGINT, _handler)
+ signal.signal(signal.SIGTERM, _handler)
+ try:
+ sys.stdin.readline()
+ print("Ready")
+ sleep(10)
+ except Exception as e:
+ print(e)
+ # Shouldn't get here
+ exit(5)
- # Regardless of escalation, we should see an info log for the final
signal sent
- assert {
- "event": "Process exited",
- "level": "info",
- "logger": "supervisor",
- "pid": 12345,
- "signal": expected_signals[-1].name,
- "exit_code": 0,
- "timestamp": mocker.ANY,
- } in captured_logs
+ ti_id = uuid7()
- # Validate `selector.select` calls
- assert watched_subprocess.selector.select.call_count ==
len(expected_signals)
- watched_subprocess.selector.select.assert_has_calls(
- [mocker.call(timeout=0.1)] * len(expected_signals)
+ proc = WatchedSubprocess.start(
+ path=os.devnull,
+ ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d",
try_number=1),
+ client=MagicMock(spec=sdk_client.Client),
+ target=subprocess_main,
)
+ # Ensure we get one normal run, to give the proc time to register it's
custom sighandler
+ proc._service_subprocess(max_wait_time=1)
+ proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5,
force=True)
- assert watched_subprocess._exit_code == 0
-
- def test_force_kill_with_selector_events(self, watched_subprocess,
mock_process, mocker):
- """Test force escalation with selector events handled during wait."""
- # Mock selector to return events during escalation
- mock_key = mocker.Mock()
- mock_key.fileobj = mocker.Mock()
+ # Wait for the subprocess to finish
+ assert proc.wait() == exit_after or -signal.SIGKILL
+ exit_after = exit_after or signal.SIGKILL
- # Simulate EOF
- mock_key.data = mocker.Mock(return_value=False)
-
- watched_subprocess.selector.select.side_effect = [
- [(mock_key, None)], # Event during SIGINT
- [], # No event during SIGTERM
- [(mock_key, None)], # Event during SIGKILL
- ]
-
- mock_process.wait.side_effect = [
- psutil.TimeoutExpired(0.1), # SIGINT times out
- psutil.TimeoutExpired(0.1), # SIGTERM times out
- 0, # SIGKILL succeeds
+ logs = [{"event": m["event"], "chan": m.get("chan"), "logger":
m["logger"]} for m in captured_logs]
+ expected_logs = [
+ {"chan": "stdout", "event": "Ready", "logger": "task"},
]
+ # Work out what logs we expect to see
+ if signal_to_send == signal.SIGINT:
+ expected_logs.append({"chan": "stderr", "event": "Signal 2
received", "logger": "task"})
+ if signal_to_send == signal.SIGTERM or (
+ signal_to_send == signal.SIGINT and exit_after != signal.SIGINT
+ ):
+ if signal_to_send == signal.SIGINT:
+ expected_logs.append(
+ {
+ "chan": None,
+ "event": "Process did not terminate in time;
escalating",
+ "logger": "supervisor",
+ }
+ )
+ expected_logs.append({"chan": "stderr", "event": "Signal 15
received", "logger": "task"})
+ if exit_after == signal.SIGKILL:
+ if signal_to_send in {signal.SIGINT, signal.SIGTERM}:
+ expected_logs.append(
+ {
+ "chan": None,
+ "event": "Process did not terminate in time;
escalating",
+ "logger": "supervisor",
+ }
+ )
+ # expected_logs.push({"chan": "stderr", "event": "Signal 9
received", "logger": "task"})
+ ...
- watched_subprocess.kill(signal.SIGINT, escalation_delay=0.1,
force=True)
-
- # Validate selector interactions
- assert watched_subprocess.selector.select.call_count == 3
- mock_key.data.assert_has_calls([mocker.call(mock_key.fileobj),
mocker.call(mock_key.fileobj)])
-
- # Validate signal escalation
- mock_process.send_signal.assert_has_calls(
- [mocker.call(signal.SIGINT), mocker.call(signal.SIGTERM),
mocker.call(signal.SIGKILL)]
- )
-
- def test_kill_process_already_exited(self, watched_subprocess,
mock_process):
- """Test behavior when the process has already exited."""
- mock_process.wait.side_effect = psutil.NoSuchProcess(pid=1234)
-
- watched_subprocess.kill(signal.SIGINT, force=True)
-
- mock_process.send_signal.assert_called_once_with(signal.SIGINT)
- mock_process.wait.assert_called_once()
- assert watched_subprocess._exit_code == -1
-
- def test_kill_process_custom_signal(self, watched_subprocess,
mock_process):
- """Test that the process is killed with the correct signal."""
- mock_process.wait.return_value = 0
-
- signal_to_send = signal.SIGUSR1
- watched_subprocess.kill(signal_to_send, force=False)
-
- mock_process.send_signal.assert_called_once_with(signal_to_send)
- mock_process.wait.assert_called_once_with(timeout=0)
+ expected_logs.extend(({"chan": None, "event": "Process exited",
"logger": "supervisor"},))
+ assert logs == expected_logs
def test_service_subprocess(self, watched_subprocess, mock_process,
mocker):
"""Test `_service_subprocess` processes selector events and handles
subprocess exit."""