SameerMesiah97 commented on code in PR #61627:
URL: https://github.com/apache/airflow/pull/61627#discussion_r2779759901
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -482,6 +482,102 @@ def on_kill(self) -> None:
captured = capfd.readouterr()
assert "On kill hook called!" in captured.out
+ def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+ self,
+ client_with_ti_start,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ create_runtime_ti,
+ make_ti_context_dict,
+ capfd,
+ ):
+ """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+ This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+ the supervisor should forward the signal to the child process so that the
+ operator's on_kill() hook is triggered for resource cleanup.
+ """
+ import threading
+
+ ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ CommsDecoder()._get_response()
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ for i in range(1000):
+ print(f"Iteration {i}")
+ sleep(1)
Review Comment:
I get why the loop needs to run “long enough” so the subprocess is alive
when the signal is delivered, but 1000 iterations at 1s each feels a bit
overkill. If for some reason the subprocess doesn’t get terminated as expected,
this could run for ~15 minutes and materially stall CI. Is there any reason
this can’t be reduced to a smaller number (e.g. 30–60 iterations) while still
leaving plenty of headroom for signal delivery?
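For illustration only (the `stop_event` name and the 60-iteration cap are my own, not from this PR): a bounded loop with an event-based wait keeps the "runs long enough" property while capping the worst-case CI stall:

```python
import threading

stop_event = threading.Event()

def execute_loop(max_iterations: int = 60) -> int:
    """Spin until stop_event is set, but never past max_iterations ticks."""
    for i in range(max_iterations):
        if stop_event.is_set():
            return i
        # wait() doubles as the per-iteration sleep and, unlike time.sleep(1),
        # returns immediately once the event is set (e.g. by on_kill()).
        stop_event.wait(0.01)
    return max_iterations
```

With something like this, a missed signal costs at most the cap rather than ~15 minutes of wall time.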
##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -2078,7 +2078,33 @@ def supervise(
sentry_integration=sentry_integration,
)
- exit_code = process.wait()
+ # Forward termination signals to the task subprocess so that the operator's
+ # on_kill() hook is invoked on graceful shutdown (e.g. K8s pod SIGTERM).
+ # Without this, the supervisor exits on SIGTERM without notifying the child,
+ # leaving spawned resources (pods, subprocesses, etc.) running.
+ prev_sigterm = signal.getsignal(signal.SIGTERM)
+ prev_sigint = signal.getsignal(signal.SIGINT)
+
+ def _forward_signal(signum, frame):
+ log.info(
+ "Received signal, forwarding to task subprocess",
+ signal=signal.Signals(signum).name,
+ pid=process.pid,
+ )
+ try:
+ os.kill(process.pid, signum)
+ except ProcessLookupError:
+ pass
Review Comment:
Are you swallowing the exception here because you anticipate that the child
worker may have already exited before `os.kill` is invoked? If so, I would
add a comment here explaining that. Like this:
`# Child process may have already exited during shutdown races.`
This is more of a nit, but silent exception swallowing tends to raise eyebrows
for readers who have only partial context.
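Concretely, something along these lines (the `forward_signal_to` helper is hypothetical, just to show where the comment would sit relative to the except clause):

```python
import os

def forward_signal_to(pid: int, signum: int) -> bool:
    """Forward signum to pid; report whether the process was still alive."""
    try:
        os.kill(pid, signum)
        return True
    except ProcessLookupError:
        # Child process may have already exited during shutdown races.
        return False
```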
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -482,6 +482,102 @@ def on_kill(self) -> None:
captured = capfd.readouterr()
assert "On kill hook called!" in captured.out
+ def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+ self,
+ client_with_ti_start,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ create_runtime_ti,
+ make_ti_context_dict,
+ capfd,
+ ):
+ """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+ This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+ the supervisor should forward the signal to the child process so that the
+ operator's on_kill() hook is triggered for resource cleanup.
+ """
+ import threading
+
+ ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ CommsDecoder()._get_response()
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ for i in range(1000):
+ print(f"Iteration {i}")
+ sleep(1)
+
+ def on_kill(self) -> None:
+ print("On kill hook called via signal forwarding!")
+
+ task = CustomOperator(task_id="test-signal-forward")
+ runtime_ti = create_runtime_ti(
+ dag_id="c",
+ task=task,
+ conf={},
+ )
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ proc = ActivitySubprocess.start(
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
+ what=TaskInstance(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ ),
+ client=make_client(transport=httpx.MockTransport(handle_request)),
+ target=subprocess_main,
+ )
+
+ # Install signal forwarding handler (same mechanism as supervise() does)
+ prev_sigterm = signal.getsignal(signal.SIGTERM)
+
+ def _forward_signal(signum, frame):
+ try:
+ os.kill(proc.pid, signum)
+ except ProcessLookupError:
+ pass
+
+ signal.signal(signal.SIGTERM, _forward_signal)
Review Comment:
Your new implementation handles both SIGTERM and SIGINT, but you appear to
be testing only SIGTERM here. Is that because you expect K8s to send only
SIGTERM? I would suggest explaining that in a comment so that a casual
reader does not assume this is a gap. Not a blocking suggestion.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]