kaxil commented on code in PR #61627:
URL: https://github.com/apache/airflow/pull/61627#discussion_r3295434802
##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -2232,7 +2232,34 @@ def supervise_task(
sentry_integration=sentry_integration,
)
- exit_code = process.wait()
+ # Forward termination signals to the task subprocess so that the operator's
+ # on_kill() hook is invoked on graceful shutdown (e.g. K8s pod SIGTERM).
+ # Without this, the supervisor exits on SIGTERM without notifying the child,
+ # leaving spawned resources (pods, subprocesses, etc.) running.
+ prev_sigterm = signal.getsignal(signal.SIGTERM)
Review Comment:
Re-flagging from the prior review (@ashb made the same suggestion in id
2938993077, and it was not taken): `signal.signal()` returns the previous
handler, so the separate `signal.getsignal()` calls are redundant.
```python
prev_sigterm = signal.signal(signal.SIGTERM, _forward_signal)
prev_sigint = signal.signal(signal.SIGINT, _forward_signal)
```
Four lines collapse to two and you avoid the small race between the
`getsignal()` lookup and the `signal()` install.
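For illustration, a minimal sketch of the install-and-restore pattern built on that return value. The `forward_signals_to` helper and its `child_pid` parameter are hypothetical names, not code from this PR, and the sketch assumes it runs in the main thread (where Python allows handlers to be installed).

```python
import os
import signal
from contextlib import contextmanager, suppress


@contextmanager
def forward_signals_to(child_pid, signums=(signal.SIGTERM, signal.SIGINT)):
    """Forward the given signals to child_pid, restoring old handlers on exit.

    Hypothetical helper for illustration only; not the PR's implementation.
    """

    def _forward_signal(signum, frame):
        # The child may already have exited; ignore that case.
        with suppress(ProcessLookupError):
            os.kill(child_pid, signum)

    # signal.signal() returns the handler it replaced, so no separate
    # signal.getsignal() lookup is needed.
    previous = {signum: signal.signal(signum, _forward_signal) for signum in signums}
    try:
        yield
    finally:
        for signum, handler in previous.items():
            # The returned handler is None when the old one was not installed
            # from Python; fall back to the default disposition in that case.
            signal.signal(signum, signal.SIG_DFL if handler is None else handler)
```

Wrapping the `process.wait()` call in `with forward_signals_to(process.pid):` would keep install and restore in one place, assuming the subprocess object exposes a `pid` as the test code in this PR suggests.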
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -488,6 +488,98 @@ def on_kill(self) -> None:
captured = capfd.readouterr()
assert "On kill hook called!" in captured.out
+ def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+ self,
+ client_with_ti_start,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ create_runtime_ti,
+ make_ti_context_dict,
+ capfd,
+ ):
+ """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+ This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+ the supervisor should forward the signal to the child process so that the
+ operator's on_kill() hook is triggered for resource cleanup.
+ """
+ import threading
+
+ ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ CommsDecoder()._get_response()
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ for i in range(30):
+ print(f"Iteration {i}")
+ sleep(1)
+
+ def on_kill(self) -> None:
+ print("On kill hook called via signal forwarding!")
+
+ task = CustomOperator(task_id="test-signal-forward")
+ runtime_ti = create_runtime_ti(
+ dag_id="c",
+ task=task,
+ conf={},
+ )
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ proc = ActivitySubprocess.start(
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
+ what=TaskInstance(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ ),
+ client=make_client(transport=httpx.MockTransport(handle_request)),
+ target=subprocess_main,
+ )
+
+ # Install signal forwarding handler (same mechanism as supervise() does)
+ prev_sigterm = signal.getsignal(signal.SIGTERM)
+
+ def _forward_signal(signum, frame):
+ with suppress(ProcessLookupError):
+ os.kill(proc.pid, signum)
+
+ signal.signal(signal.SIGTERM, _forward_signal)
Review Comment:
Re-flagging from prior review (also @ashb in id 2939013171): this test still
installs its own `_forward_signal` handler in the test process and SIGTERMs
itself. The production handler registered inside `supervise_task()` is never
exercised. If the forwarding logic in `supervisor.py` regressed to a no-op,
this test would still pass because the test's own handler does the forwarding.
To actually cover the production path, the SIGTERM needs to arrive while
`supervise_task()` is the one with `signal.signal(...)` installed. One way: run
`supervise_task()` in a subprocess and SIGTERM that subprocess, then assert on
its captured stdout. The current shape just re-implements the feature and
asserts the re-implementation works.
##########
task-sdk/tests/task_sdk/dags/signal_forward_test.py:
##########
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import os
+import signal
+import sys
+
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.definitions.dag import dag
+
+
+class SignalForwardOperator(BaseOperator):
+ """Operator that sends SIGTERM to its parent (the supervisor) to test signal forwarding."""
+
+ def execute(self, context):
+ # Print sentinel so the test knows execute() is running
+ print("EXECUTE_STARTED", flush=True)
+ # Send SIGTERM to the supervisor (parent process)
+ os.kill(os.getppid(), signal.SIGTERM)
+ # Keep running long enough for the signal to be forwarded back and on_kill() to fire
+ import time
+
+ time.sleep(30)
+
+ def on_kill(self) -> None:
+ print("ON_KILL_CALLED_VIA_SIGNAL_FORWARDING", flush=True)
+ sys.exit(0)
+
+
+@dag()
+def signal_forward_test():
+ SignalForwardOperator(task_id="signal_task")
+
+
+signal_forward_test()
Review Comment:
This new DAG file is orphaned: nothing in this PR imports it, loads it via
`DagBag`, or asserts on its `EXECUTE_STARTED` /
`ON_KILL_CALLED_VIA_SIGNAL_FORWARDING` sentinels (`grep -r signal_forward_test
task-sdk/` returns only this file).
It looks like the intended consumer is the new
`test_on_kill_hook_called_when_supervisor_receives_sigterm` test, but that test
inlines a `CustomOperator` instead. Either wire this DAG into a test (which
would also let you replace the `sleep(2)` synchronization with an
`EXECUTE_STARTED` poll) or drop the file before merge -- shipping dead code
makes future cleanup harder.
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -488,6 +488,98 @@ def on_kill(self) -> None:
captured = capfd.readouterr()
assert "On kill hook called!" in captured.out
+ def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+ self,
+ client_with_ti_start,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ create_runtime_ti,
+ make_ti_context_dict,
+ capfd,
+ ):
+ """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+ This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+ the supervisor should forward the signal to the child process so that the
+ operator's on_kill() hook is triggered for resource cleanup.
+ """
+ import threading
+
+ ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ CommsDecoder()._get_response()
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ for i in range(30):
+ print(f"Iteration {i}")
+ sleep(1)
+
+ def on_kill(self) -> None:
+ print("On kill hook called via signal forwarding!")
+
+ task = CustomOperator(task_id="test-signal-forward")
+ runtime_ti = create_runtime_ti(
+ dag_id="c",
+ task=task,
+ conf={},
+ )
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ proc = ActivitySubprocess.start(
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
+ what=TaskInstance(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ ),
+ client=make_client(transport=httpx.MockTransport(handle_request)),
+ target=subprocess_main,
+ )
+
+ # Install signal forwarding handler (same mechanism as supervise() does)
+ prev_sigterm = signal.getsignal(signal.SIGTERM)
+
+ def _forward_signal(signum, frame):
+ with suppress(ProcessLookupError):
+ os.kill(proc.pid, signum)
+
+ signal.signal(signal.SIGTERM, _forward_signal)
+
+ # Send SIGTERM to ourselves (the supervisor) from a background thread,
+ # giving the subprocess time to start executing first. Then forcefully
+ # terminate the subprocess so the test does not hang.
+ def send_signals():
+ sleep(2)
Review Comment:
Re-flagging from the prior review: the two `sleep(2)` calls are fixed-delay
synchronization. On a slow CI runner the first sleep can elapse before the
child's `execute()` has actually started printing, and under load the second
sleep can elapse before `on_kill()` flushes its output. Either way this is a
flake source.
It looks like the new `signal_forward_test.py` DAG (with the
`EXECUTE_STARTED` sentinel) was intended to address this, but nothing in this
PR consumes that DAG (see the separate comment on that file). The actual test
still relies on these sleeps. Please either wire the DAG-based approach up, or
replace the sleeps here with a poll-until-condition on the child's stdout.
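A poll-until-condition helper could look roughly like the sketch below. `wait_until` and its parameters are hypothetical names, not existing helpers in the test suite.

```python
import time
from typing import Callable


def wait_until(
    condition: Callable[[], bool],
    timeout: float = 10.0,
    interval: float = 0.05,
) -> bool:
    """Poll condition() until it returns True or timeout seconds elapse.

    Returns True if the condition was met, False if the timeout expired.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition():
            return True
        time.sleep(interval)
    # One final check so a condition that became true during the last
    # sleep is not reported as a timeout.
    return condition()
```

The test would then wait on something like "the child has printed its first line" before sending SIGTERM, and on "the `on_kill()` message has appeared" before terminating the child. Since `capfd.readouterr()` resets the captured buffer on each call, the condition would need to accumulate output across polls.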
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.