This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 16206ef559e Correctly ensure that we give subprocesses time to exit after signalling it (#44766)
16206ef559e is described below

commit 16206ef559efd72c8be56fcbb0f1b1f0d40159c1
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Sat Dec 7 16:41:10 2024 +0000

    Correctly ensure that we give subprocesses time to exit after signalling it (#44766)
    
    We had a bug hidden in our tests by our use of mocks -- if the subprocess
    returned any output, then `self.selector.select()` would return straight away,
    not waiting for the maximum timeout, which would result in the "escalation"
    signal being sent as soon as the first output was read, rather than after the
    given escalation delay had elapsed.
---
 .../src/airflow/sdk/execution_time/supervisor.py   |  28 ++-
 task_sdk/tests/execution_time/test_supervisor.py   | 206 +++++++++------------
 2 files changed, 110 insertions(+), 124 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 3bc714e6415..1b2a7ba7577 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -445,12 +445,21 @@ class WatchedSubprocess:
             try:
                 self._process.send_signal(sig)
 
-                # Service subprocess events during the escalation delay
-                self._service_subprocess(max_wait_time=escalation_delay, 
raise_on_timeout=True)
-                if self._exit_code is not None:
-                    log.info("Process exited", pid=self.pid, 
exit_code=self._exit_code, signal=sig.name)
-                    return
-            except psutil.TimeoutExpired:
+                start = time.monotonic()
+                end = start + escalation_delay
+                now = start
+
+                while now < end:
+                    # Service subprocess events during the escalation delay. 
This will return as soon as it's
+                    # read from any of the sockets, so we need to re-run it if 
the process is still alive
+                    if (
+                        exit_code := 
self._service_subprocess(max_wait_time=end - now, raise_on_timeout=False)
+                    ) is not None:
+                        log.info("Process exited", pid=self.pid, 
exit_code=exit_code, signal=sig.name)
+                        return
+
+                    now = time.monotonic()
+
                 msg = "Process did not terminate in time"
                 if sig != escalation_path[-1]:
                     msg += "; escalating"
@@ -539,6 +548,7 @@ class WatchedSubprocess:
 
         :param max_wait_time: Maximum time to block while waiting for events, 
in seconds.
         :param raise_on_timeout: If True, raise an exception if the subprocess 
does not exit within the timeout.
+        :returns: The process exit code, or None if it's still alive
         """
         events = self.selector.select(timeout=max_wait_time)
         for key, _ in events:
@@ -559,9 +569,9 @@ class WatchedSubprocess:
                 key.fileobj.close()  # type: ignore[union-attr]
 
         # Check if the subprocess has exited
-        self._check_subprocess_exit(raise_on_timeout=raise_on_timeout)
+        return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout)
 
-    def _check_subprocess_exit(self, raise_on_timeout: bool = False):
+    def _check_subprocess_exit(self, raise_on_timeout: bool = False) -> int | 
None:
         """Check if the subprocess has exited."""
         if self._exit_code is None:
             try:
@@ -570,7 +580,7 @@ class WatchedSubprocess:
             except psutil.TimeoutExpired:
                 if raise_on_timeout:
                     raise
-                pass
+        return self._exit_code
 
     def _send_heartbeat_if_needed(self):
         """Send a heartbeat to the client if heartbeat interval has passed."""
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 96506a48bda..cb43904f172 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -573,152 +573,128 @@ class TestWatchedSubprocessKill:
         proc.selector = mock_selector
         return proc
 
+    def test_kill_process_already_exited(self, watched_subprocess, 
mock_process):
+        """Test behavior when the process has already exited."""
+        mock_process.wait.side_effect = psutil.NoSuchProcess(pid=1234)
+
+        watched_subprocess.kill(signal.SIGINT, force=True)
+
+        mock_process.send_signal.assert_called_once_with(signal.SIGINT)
+        mock_process.wait.assert_called_once()
+        assert watched_subprocess._exit_code == -1
+
+    def test_kill_process_custom_signal(self, watched_subprocess, 
mock_process):
+        """Test that the process is killed with the correct signal."""
+        mock_process.wait.return_value = 0
+
+        signal_to_send = signal.SIGUSR1
+        watched_subprocess.kill(signal_to_send, force=False)
+
+        mock_process.send_signal.assert_called_once_with(signal_to_send)
+        mock_process.wait.assert_called_once_with(timeout=0)
+
     @pytest.mark.parametrize(
-        ["signal_to_send", "wait_side_effect", "expected_signals"],
+        ["signal_to_send", "exit_after"],
         [
             pytest.param(
                 signal.SIGINT,
-                [0],
-                [signal.SIGINT],
+                signal.SIGINT,
                 id="SIGINT-success-without-escalation",
             ),
             pytest.param(
                 signal.SIGINT,
-                [psutil.TimeoutExpired(0.1), 0],
-                [signal.SIGINT, signal.SIGTERM],
+                signal.SIGTERM,
                 id="SIGINT-escalates-to-SIGTERM",
             ),
             pytest.param(
                 signal.SIGINT,
-                [
-                    psutil.TimeoutExpired(0.1),  # SIGINT times out
-                    psutil.TimeoutExpired(0.1),  # SIGTERM times out
-                    0,  # SIGKILL succeeds
-                ],
-                [signal.SIGINT, signal.SIGTERM, signal.SIGKILL],
+                None,
                 id="SIGINT-escalates-to-SIGTERM-then-SIGKILL",
             ),
             pytest.param(
                 signal.SIGTERM,
-                [
-                    psutil.TimeoutExpired(0.1),  # SIGTERM times out
-                    0,  # SIGKILL succeeds
-                ],
-                [signal.SIGTERM, signal.SIGKILL],
+                None,
                 id="SIGTERM-escalates-to-SIGKILL",
             ),
             pytest.param(
                 signal.SIGKILL,
-                [0],
-                [signal.SIGKILL],
+                None,
                 id="SIGKILL-success-without-escalation",
             ),
         ],
     )
-    def test_force_kill_escalation(
-        self,
-        watched_subprocess,
-        mock_process,
-        mocker,
-        signal_to_send,
-        wait_side_effect,
-        expected_signals,
-        captured_logs,
-    ):
-        """Test escalation path for SIGINT, SIGTERM, and SIGKILL when 
force=True."""
-        # Mock the process wait method to return the exit code or raise an 
exception
-        mock_process.wait.side_effect = wait_side_effect
-
-        watched_subprocess.kill(signal_to_send=signal_to_send, 
escalation_delay=0.1, force=True)
-
-        # Check that the correct signals were sent
-        mock_process.send_signal.assert_has_calls([mocker.call(sig) for sig in 
expected_signals])
-
-        # Check that the process was waited on for each signal
-        mock_process.wait.assert_has_calls([mocker.call(timeout=0)] * 
len(expected_signals))
-
-        ## Validate log messages
-        # If escalation occurred, we should see a warning log for each signal 
sent
-        if len(expected_signals) > 1:
-            assert {
-                "event": "Process did not terminate in time; escalating",
-                "level": "warning",
-                "logger": "supervisor",
-                "pid": 12345,
-                "signal": expected_signals[-2].name,
-                "timestamp": mocker.ANY,
-            } in captured_logs
+    def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, 
captured_logs, monkeypatch):
+        def subprocess_main():
+            import signal
+
+            def _handler(sig, frame):
+                print(f"Signal {sig} received", file=sys.stderr)
+                if exit_after == sig:
+                    sleep(0.1)
+                    exit(sig)
+                sleep(5)
+                print("Should not get here")
+
+            signal.signal(signal.SIGINT, _handler)
+            signal.signal(signal.SIGTERM, _handler)
+            try:
+                sys.stdin.readline()
+                print("Ready")
+                sleep(10)
+            except Exception as e:
+                print(e)
+            # Shouldn't get here
+            exit(5)
 
-        # Regardless of escalation, we should see an info log for the final 
signal sent
-        assert {
-            "event": "Process exited",
-            "level": "info",
-            "logger": "supervisor",
-            "pid": 12345,
-            "signal": expected_signals[-1].name,
-            "exit_code": 0,
-            "timestamp": mocker.ANY,
-        } in captured_logs
+        ti_id = uuid7()
 
-        # Validate `selector.select` calls
-        assert watched_subprocess.selector.select.call_count == 
len(expected_signals)
-        watched_subprocess.selector.select.assert_has_calls(
-            [mocker.call(timeout=0.1)] * len(expected_signals)
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
+            client=MagicMock(spec=sdk_client.Client),
+            target=subprocess_main,
         )
+        # Ensure we get one normal run, to give the proc time to register it's 
custom sighandler
+        proc._service_subprocess(max_wait_time=1)
+        proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5, 
force=True)
 
-        assert watched_subprocess._exit_code == 0
-
-    def test_force_kill_with_selector_events(self, watched_subprocess, 
mock_process, mocker):
-        """Test force escalation with selector events handled during wait."""
-        # Mock selector to return events during escalation
-        mock_key = mocker.Mock()
-        mock_key.fileobj = mocker.Mock()
+        # Wait for the subprocess to finish
+        assert proc.wait() == exit_after or -signal.SIGKILL
+        exit_after = exit_after or signal.SIGKILL
 
-        # Simulate EOF
-        mock_key.data = mocker.Mock(return_value=False)
-
-        watched_subprocess.selector.select.side_effect = [
-            [(mock_key, None)],  # Event during SIGINT
-            [],  # No event during SIGTERM
-            [(mock_key, None)],  # Event during SIGKILL
-        ]
-
-        mock_process.wait.side_effect = [
-            psutil.TimeoutExpired(0.1),  # SIGINT times out
-            psutil.TimeoutExpired(0.1),  # SIGTERM times out
-            0,  # SIGKILL succeeds
+        logs = [{"event": m["event"], "chan": m.get("chan"), "logger": 
m["logger"]} for m in captured_logs]
+        expected_logs = [
+            {"chan": "stdout", "event": "Ready", "logger": "task"},
         ]
+        # Work out what logs we expect to see
+        if signal_to_send == signal.SIGINT:
+            expected_logs.append({"chan": "stderr", "event": "Signal 2 
received", "logger": "task"})
+        if signal_to_send == signal.SIGTERM or (
+            signal_to_send == signal.SIGINT and exit_after != signal.SIGINT
+        ):
+            if signal_to_send == signal.SIGINT:
+                expected_logs.append(
+                    {
+                        "chan": None,
+                        "event": "Process did not terminate in time; 
escalating",
+                        "logger": "supervisor",
+                    }
+                )
+            expected_logs.append({"chan": "stderr", "event": "Signal 15 
received", "logger": "task"})
+        if exit_after == signal.SIGKILL:
+            if signal_to_send in {signal.SIGINT, signal.SIGTERM}:
+                expected_logs.append(
+                    {
+                        "chan": None,
+                        "event": "Process did not terminate in time; 
escalating",
+                        "logger": "supervisor",
+                    }
+                )
+            # expected_logs.push({"chan": "stderr", "event": "Signal 9 
received", "logger": "task"})
+            ...
 
-        watched_subprocess.kill(signal.SIGINT, escalation_delay=0.1, 
force=True)
-
-        # Validate selector interactions
-        assert watched_subprocess.selector.select.call_count == 3
-        mock_key.data.assert_has_calls([mocker.call(mock_key.fileobj), 
mocker.call(mock_key.fileobj)])
-
-        # Validate signal escalation
-        mock_process.send_signal.assert_has_calls(
-            [mocker.call(signal.SIGINT), mocker.call(signal.SIGTERM), 
mocker.call(signal.SIGKILL)]
-        )
-
-    def test_kill_process_already_exited(self, watched_subprocess, 
mock_process):
-        """Test behavior when the process has already exited."""
-        mock_process.wait.side_effect = psutil.NoSuchProcess(pid=1234)
-
-        watched_subprocess.kill(signal.SIGINT, force=True)
-
-        mock_process.send_signal.assert_called_once_with(signal.SIGINT)
-        mock_process.wait.assert_called_once()
-        assert watched_subprocess._exit_code == -1
-
-    def test_kill_process_custom_signal(self, watched_subprocess, 
mock_process):
-        """Test that the process is killed with the correct signal."""
-        mock_process.wait.return_value = 0
-
-        signal_to_send = signal.SIGUSR1
-        watched_subprocess.kill(signal_to_send, force=False)
-
-        mock_process.send_signal.assert_called_once_with(signal_to_send)
-        mock_process.wait.assert_called_once_with(timeout=0)
+        expected_logs.extend(({"chan": None, "event": "Process exited", 
"logger": "supervisor"},))
+        assert logs == expected_logs
 
     def test_service_subprocess(self, watched_subprocess, mock_process, 
mocker):
         """Test `_service_subprocess` processes selector events and handles 
subprocess exit."""

Reply via email to