SameerMesiah97 commented on code in PR #61627:
URL: https://github.com/apache/airflow/pull/61627#discussion_r2779759901


##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -482,6 +482,102 @@ def on_kill(self) -> None:
         captured = capfd.readouterr()
         assert "On kill hook called!" in captured.out
 
+    def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+        self,
+        client_with_ti_start,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        create_runtime_ti,
+        make_ti_context_dict,
+        capfd,
+    ):
+        """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+        This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+        the supervisor should forward the signal to the child process so that the
+        operator's on_kill() hook is triggered for resource cleanup.
+        """
+        import threading
+
+        ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(200, json=make_ti_context_dict())
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            CommsDecoder()._get_response()
+
+            class CustomOperator(BaseOperator):
+                def execute(self, context):
+                    for i in range(1000):
+                        print(f"Iteration {i}")
+                        sleep(1)

Review Comment:
   I get why the loop needs to run “long enough” so the subprocess is still 
alive when the signal is delivered, but 1000 iterations at 1s each feels like 
overkill. If the subprocess is not terminated as expected, this could run for 
~17 minutes (1000 s) and materially stall CI. Is there any reason this can’t be 
reduced to a smaller number (e.g. 30–60 iterations) while still leaving plenty 
of headroom for signal delivery?
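
   To make the arithmetic concrete, here is a minimal sketch of a bounded loop. 
The names `MAX_ITERATIONS`, `SLEEP_SECONDS`, `worst_case_seconds` and 
`run_bounded_loop` are hypothetical and not part of the PR; the bound of 60 is 
just one value from the suggested range.

```python
import time

# Hypothetical values for illustration; not taken from the PR.
MAX_ITERATIONS = 60
SLEEP_SECONDS = 1.0


def worst_case_seconds(iterations: int, sleep_seconds: float) -> float:
    """Upper bound on the loop's runtime if the signal never arrives."""
    return iterations * sleep_seconds


def run_bounded_loop(iterations=MAX_ITERATIONS, sleep_seconds=SLEEP_SECONDS, sleep=time.sleep) -> int:
    """Stand-in for the operator's execute() loop; returns iterations completed."""
    completed = 0
    for i in range(iterations):
        print(f"Iteration {i}")
        sleep(sleep_seconds)  # injectable so the loop can be exercised without waiting
        completed += 1
    return completed
```

   With these numbers, `worst_case_seconds(1000, 1.0)` is 1000 seconds, whereas 
the bounded variant caps a stuck test at 60 seconds.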



##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -2078,7 +2078,33 @@ def supervise(
             sentry_integration=sentry_integration,
         )
 
-        exit_code = process.wait()
+        # Forward termination signals to the task subprocess so that the operator's
+        # on_kill() hook is invoked on graceful shutdown (e.g. K8s pod SIGTERM).
+        # Without this, the supervisor exits on SIGTERM without notifying the child,
+        # leaving spawned resources (pods, subprocesses, etc.) running.
+        prev_sigterm = signal.getsignal(signal.SIGTERM)
+        prev_sigint = signal.getsignal(signal.SIGINT)
+
+        def _forward_signal(signum, frame):
+            log.info(
+                "Received signal, forwarding to task subprocess",
+                signal=signal.Signals(signum).name,
+                pid=process.pid,
+            )
+            try:
+                os.kill(process.pid, signum)
+            except ProcessLookupError:
+                pass

Review Comment:
   Are you swallowing the exception here because you anticipate that the child 
process may already have exited by the time `os.kill` is invoked? If so, I 
would add a comment explaining that, like this:
   
    `# Child process may have already exited during shutdown races.`
   
   This is more of a nit, but silent exception swallowing tends to raise 
eyebrows for readers who have only partial context.
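
   As a rough sketch of how the documented handler body could read (this is 
not the PR's code: `forward_signal`, its boolean return value, and the 
injectable `kill` parameter are assumptions added here so the behaviour can be 
checked in isolation):

```python
import os
import signal


def forward_signal(pid: int, signum: int, kill=os.kill) -> bool:
    """Forward signum to pid; return False if the process is already gone."""
    try:
        kill(pid, signum)
    except ProcessLookupError:
        # Child process may have already exited during shutdown races.
        return False
    return True
```

   Returning a flag instead of a bare `pass` is optional; the comment alone 
already tells the reader that the swallow is deliberate.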



##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -482,6 +482,102 @@ def on_kill(self) -> None:
         captured = capfd.readouterr()
         assert "On kill hook called!" in captured.out
 
+    def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+        self,
+        client_with_ti_start,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        create_runtime_ti,
+        make_ti_context_dict,
+        capfd,
+    ):
+        """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+        This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+        the supervisor should forward the signal to the child process so that the
+        operator's on_kill() hook is triggered for resource cleanup.
+        """
+        import threading
+
+        ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(200, json=make_ti_context_dict())
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            CommsDecoder()._get_response()
+
+            class CustomOperator(BaseOperator):
+                def execute(self, context):
+                    for i in range(1000):
+                        print(f"Iteration {i}")
+                        sleep(1)
+
+                def on_kill(self) -> None:
+                    print("On kill hook called via signal forwarding!")
+
+            task = CustomOperator(task_id="test-signal-forward")
+            runtime_ti = create_runtime_ti(
+                dag_id="c",
+                task=task,
+                conf={},
+            )
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id=ti_id,
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+                dag_version_id=uuid7(),
+            ),
+            client=make_client(transport=httpx.MockTransport(handle_request)),
+            target=subprocess_main,
+        )
+
+        # Install signal forwarding handler (same mechanism as supervise() does)
+        prev_sigterm = signal.getsignal(signal.SIGTERM)
+
+        def _forward_signal(signum, frame):
+            try:
+                os.kill(proc.pid, signum)
+            except ProcessLookupError:
+                pass
+
+        signal.signal(signal.SIGTERM, _forward_signal)

Review Comment:
   Your new implementation handles both SIGTERM and SIGINT, but you appear to 
be testing only SIGTERM here. Is this because you expect K8s to send only 
SIGTERM? I would suggest explaining that in a comment here so that a casual 
reader does not assume this is a gap. Not a blocking suggestion.
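
   If covering both signals turns out to be cheap, one possible shape is 
sketched below. It is a standalone illustration, not the supervisor's code: 
`FORWARDED_SIGNALS`, `install_forwarders`, `restore_handlers` and the 
injectable `kill` parameter are hypothetical names, and the handlers must be 
installed from the main thread.

```python
import os
import signal

# Hypothetical: the signals the handler is expected to forward.
FORWARDED_SIGNALS = (signal.SIGTERM, signal.SIGINT)


def install_forwarders(child_pid, signals=FORWARDED_SIGNALS, kill=os.kill):
    """Install one forwarding handler per signal; return the previous handlers."""
    previous = {}

    def _forward(signum, frame):
        try:
            kill(child_pid, signum)
        except ProcessLookupError:
            # Child process may have already exited during shutdown races.
            pass

    for sig in signals:
        previous[sig] = signal.signal(sig, _forward)
    return previous


def restore_handlers(previous):
    """Put back the handlers saved by install_forwarders()."""
    for sig, handler in previous.items():
        # signal.signal() returns None for handlers not installed from Python.
        signal.signal(sig, handler if handler is not None else signal.SIG_DFL)
```

   In a pytest suite the same idea could be expressed by parametrizing the 
existing test over the two signals with `pytest.mark.parametrize`; otherwise a 
one-line comment stating why only SIGTERM is exercised would be enough.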



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
