This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d506349f868 Make BaseOperator on_kill functionality work with TaskSDK (#53718)
d506349f868 is described below

commit d506349f8683c357a3f5f51e5310c6f9048a85b5
Author: GPK <[email protected]>
AuthorDate: Mon Jul 28 20:48:30 2025 +0100

    Make BaseOperator on_kill functionality work with TaskSDK (#53718)
    
    * Make BaseOperator on_kill functionality work with TaskSDK
    
    * Fix static checks
    
    * Resolve review comments
    
    * Update task-sdk/tests/task_sdk/execution_time/test_supervisor.py
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
    
    ---------
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
 .../src/airflow/sdk/execution_time/task_runner.py  | 13 ++++
 .../task_sdk/execution_time/test_supervisor.py     | 82 +++++++++++++++++++++-
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index be680a08d04..9af6da18bd1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -860,6 +860,8 @@ def run(
     log: Logger,
 ) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
     """Run the task in this process."""
+    import signal
+
     from airflow.exceptions import (
         AirflowException,
         AirflowFailException,
@@ -877,6 +879,17 @@ def run(
         assert ti.task is not None
         assert isinstance(ti.task, BaseOperator)
 
+    parent_pid = os.getpid()
+
+    def _on_term(signum, frame):
+        pid = os.getpid()
+        if pid != parent_pid:
+            return
+
+        ti.task.on_kill()
+
+    signal.signal(signal.SIGTERM, _on_term)
+
     msg: ToSupervisor | None = None
     state: TaskInstanceState
     error: BaseException | None = None
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 0331a7a2b9c..c091cec52ab 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -44,7 +44,7 @@ from task_sdk import FAKE_BUNDLE, make_client
 from uuid6 import uuid7
 
 from airflow.executors.workloads import BundleInfo
-from airflow.sdk import timezone
+from airflow.sdk import BaseOperator, timezone
 from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.client import ServerResponseError
 from airflow.sdk.api.datamodels._generated import (
@@ -121,6 +121,7 @@ from airflow.sdk.execution_time.supervisor import (
     set_supervisor_comms,
     supervise,
 )
+from airflow.sdk.execution_time.task_runner import run
 
 from tests_common.test_utils.config import conf_vars
 
@@ -341,6 +342,85 @@ class TestWatchedSubprocess:
             ]
         )
 
+    def test_on_kill_hook_called_when_sigkilled(
+        self,
+        client_with_ti_start,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        create_runtime_ti,
+        make_ti_context_dict,
+        capfd,
+    ):
+        main_pid = os.getpid()
+        ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/heartbeat":
+                return httpx.Response(
+                    status_code=409,
+                    json={
+                        "detail": {
+                            "reason": "not_running",
+                            "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate",
+                            "current_state": "failed",
+                        }
+                    },
+                )
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(200, json=make_ti_context_dict())
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            # Ensure we follow the "protocol" and get the startup message before we do anything
+            CommsDecoder()._get_response()
+
+            class CustomOperator(BaseOperator):
+                def execute(self, context):
+                    for i in range(1000):
+                        print(f"Iteration {i}")
+                        sleep(1)
+
+                def on_kill(self) -> None:
+                    print("On kill hook called!")
+
+            task = CustomOperator(task_id="print-params")
+            runtime_ti = create_runtime_ti(
+                dag_id="c",
+                task=task,
+                conf={
+                    "x": 3,
+                    "text": "Hello World!",
+                    "flag": False,
+                    "a_simple_list": ["one", "two", "three", "actually one value is made per line"],
+                },
+            )
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+            assert os.getpid() != main_pid
+            os.kill(os.getpid(), signal.SIGTERM)
+            # Ensure that the signal is serviced before we finish and exit the subprocess.
+            sleep(0.5)
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id=ti_id,
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+                dag_version_id=uuid7(),
+            ),
+            client=make_client(transport=httpx.MockTransport(handle_request)),
+            target=subprocess_main,
+        )
+
+        proc.wait()
+        captured = capfd.readouterr()
+        assert "On kill hook called!" in captured.out
+
     def test_subprocess_sigkilled(self, client_with_ti_start):
         main_pid = os.getpid()
 

Reply via email to