This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d506349f868 Make BaseOperator on_kill functionality work with TaskSDK (#53718)
d506349f868 is described below
commit d506349f8683c357a3f5f51e5310c6f9048a85b5
Author: GPK <[email protected]>
AuthorDate: Mon Jul 28 20:48:30 2025 +0100
Make BaseOperator on_kill functionality work with TaskSDK (#53718)
* Make BaseOperator on_kill functionality work with TaskSDK
* Fix static checks
* Resolve review comments
* Update task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---------
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
.../src/airflow/sdk/execution_time/task_runner.py | 13 ++++
.../task_sdk/execution_time/test_supervisor.py | 82 +++++++++++++++++++++-
2 files changed, 94 insertions(+), 1 deletion(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index be680a08d04..9af6da18bd1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -860,6 +860,8 @@ def run(
log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
"""Run the task in this process."""
+ import signal
+
from airflow.exceptions import (
AirflowException,
AirflowFailException,
@@ -877,6 +879,17 @@ def run(
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)
+ parent_pid = os.getpid()
+
+ def _on_term(signum, frame):
+ pid = os.getpid()
+ if pid != parent_pid:
+ return
+
+ ti.task.on_kill()
+
+ signal.signal(signal.SIGTERM, _on_term)
+
msg: ToSupervisor | None = None
state: TaskInstanceState
error: BaseException | None = None
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 0331a7a2b9c..c091cec52ab 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -44,7 +44,7 @@ from task_sdk import FAKE_BUNDLE, make_client
from uuid6 import uuid7
from airflow.executors.workloads import BundleInfo
-from airflow.sdk import timezone
+from airflow.sdk import BaseOperator, timezone
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import (
@@ -121,6 +121,7 @@ from airflow.sdk.execution_time.supervisor import (
set_supervisor_comms,
supervise,
)
+from airflow.sdk.execution_time.task_runner import run
from tests_common.test_utils.config import conf_vars
@@ -341,6 +342,85 @@ class TestWatchedSubprocess:
]
)
+ def test_on_kill_hook_called_when_sigkilled(
+ self,
+ client_with_ti_start,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ create_runtime_ti,
+ make_ti_context_dict,
+ capfd,
+ ):
+ main_pid = os.getpid()
+ ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/heartbeat":
+ return httpx.Response(
+ status_code=409,
+ json={
+ "detail": {
+ "reason": "not_running",
+ "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate",
+ "current_state": "failed",
+ }
+ },
+ )
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ # Ensure we follow the "protocol" and get the startup message before we do anything
+ CommsDecoder()._get_response()
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ for i in range(1000):
+ print(f"Iteration {i}")
+ sleep(1)
+
+ def on_kill(self) -> None:
+ print("On kill hook called!")
+
+ task = CustomOperator(task_id="print-params")
+ runtime_ti = create_runtime_ti(
+ dag_id="c",
+ task=task,
+ conf={
+ "x": 3,
+ "text": "Hello World!",
+ "flag": False,
+ "a_simple_list": ["one", "two", "three", "actually one value is made per line"],
+ },
+ )
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ assert os.getpid() != main_pid
+ os.kill(os.getpid(), signal.SIGTERM)
+ # Ensure that the signal is serviced before we finish and exit the subprocess.
+ sleep(0.5)
+
+ proc = ActivitySubprocess.start(
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
+ what=TaskInstance(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ ),
+ client=make_client(transport=httpx.MockTransport(handle_request)),
+ target=subprocess_main,
+ )
+
+ proc.wait()
+ captured = capfd.readouterr()
+ assert "On kill hook called!" in captured.out
+
def test_subprocess_sigkilled(self, client_with_ti_start):
main_pid = os.getpid()