This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e3d08811832 Add cancel_on_kill and safe_to_cancel support to 
KubernetesPodOperator and trigger (#62401)
e3d08811832 is described below

commit e3d08811832161c7a93ba6e2b3727814ef4bc915
Author: manipatnam <[email protected]>
AuthorDate: Thu Mar 19 03:35:07 2026 +0530

    Add cancel_on_kill and safe_to_cancel support to KubernetesPodOperator and 
trigger (#62401)
    
    * Add cancel_on_kill and safe_to_cancel support to KubernetesPodOperator 
and trigger
    
    * spell check
    
    * Session | None
    
    * Run prek
    
    * Resolved the issues highlighted in the comments
    
    * Renamed cancel_on_kill to on_kill_action based on inputs and made it enum
    
    * Resolved comments
    
    ---------
    
    Co-authored-by: Jens Scheffler <[email protected]>
---
 providers/cncf/kubernetes/docs/operators.rst       |  32 +++++
 .../providers/cncf/kubernetes/hooks/kubernetes.py  |   7 +-
 .../providers/cncf/kubernetes/operators/pod.py     |  14 +++
 .../providers/cncf/kubernetes/triggers/pod.py      | 121 ++++++++++++++++++
 .../providers/cncf/kubernetes/utils/pod_manager.py |   7 ++
 .../unit/cncf/kubernetes/triggers/test_pod.py      | 140 +++++++++++++++++++++
 6 files changed, 319 insertions(+), 2 deletions(-)

diff --git a/providers/cncf/kubernetes/docs/operators.rst 
b/providers/cncf/kubernetes/docs/operators.rst
index 9a2a9b21e3c..72521af704f 100644
--- a/providers/cncf/kubernetes/docs/operators.rst
+++ b/providers/cncf/kubernetes/docs/operators.rst
@@ -155,6 +155,38 @@ Example to fetch and display container log periodically
     :end-before: [END howto_operator_async_log]
 
 
+Pod cleanup on kill
+^^^^^^^^^^^^^^^^^^^
+
+The ``on_kill_action`` parameter controls what happens to the Kubernetes pod 
when a
+running task is killed (e.g. manually marked as success or failed from the 
Airflow UI).
+It accepts a subset of the string values used by ``on_finish_action``:
+
+- ``"delete_pod"`` (default) — the pod is deleted when the task is killed.
+- ``"keep_pod"`` — the pod is left running when the task is killed.
+
+In **sync mode**, ``on_kill_action`` is checked in the operator's ``on_kill`` callback; with ``"keep_pod"`` the callback returns without deleting the pod.
+
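+In **deferrable mode**, ``on_kill_action`` is forwarded to the trigger. When the trigger
+is cancelled (e.g. the deferred task is manually marked as success or failed), the action
+is applied, unless the task is still in the ``deferred`` state (e.g. the trigger is being moved to another triggerer) or its state cannot be determined, in which case the pod is kept.
+The ``on_finish_action`` parameter is **not** consulted during a kill — it only governs cleanup after normal task completion.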
+
+If you want to prevent the pod from being deleted when a task is killed (for 
example,
+for debugging), set ``on_kill_action="keep_pod"``:
+
+.. code-block:: python
+
+    k = KubernetesPodOperator(
+        task_id="long_running_task",
+        image="my-image:latest",
+        on_finish_action="delete_pod",
+        on_kill_action="keep_pod",  # pod will NOT be deleted when the task is 
killed
+    )
+
+The ``termination_grace_period`` parameter is also respected during cleanup, 
giving the
+pod time to shut down gracefully before being forcefully terminated.
+
 How does XCom work?
 ^^^^^^^^^^^^^^^^^^^
 The 
:class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator` 
handles
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 934868efdf8..1c608a8f93c 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -994,18 +994,21 @@ class AsyncKubernetesHook(KubernetesHook):
                 raise KubernetesApiError from e
 
     @generic_api_retry
-    async def delete_pod(self, name: str, namespace: str):
+    async def delete_pod(self, name: str, namespace: str, 
grace_period_seconds: int | None = None):
         """
         Delete pod's object.
 
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
+        :param grace_period_seconds: Optional duration in seconds the pod 
needs to terminate gracefully.
         """
         async with self.get_conn() as connection:
             try:
                 v1_api = async_client.CoreV1Api(connection)
                 await v1_api.delete_namespaced_pod(
-                    name=name, namespace=namespace, 
body=client.V1DeleteOptions()
+                    name=name,
+                    namespace=namespace,
+                    
body=client.V1DeleteOptions(grace_period_seconds=grace_period_seconds),
                 )
             except async_client.ApiException as e:
                 # If the pod is already deleted
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
index 668c15e9043..edb213cc566 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -73,6 +73,7 @@ from airflow.providers.cncf.kubernetes.utils.container import 
(
 from airflow.providers.cncf.kubernetes.utils.pod_manager import (
     EMPTY_XCOM_RESULT,
     OnFinishAction,
+    OnKillAction,
     PodLaunchFailedException,
     PodManager,
     PodNotFoundException,
@@ -234,6 +235,10 @@ class KubernetesPodOperator(BaseOperator):
         If "delete_pod", the pod will be deleted regardless its state; if 
"delete_succeeded_pod",
         only succeeded pod will be deleted. You can set to "keep_pod" to keep 
the pod. "delete_active_pod" deletes
         pods that are still active (Pending or Running).
+    :param on_kill_action: What to do when the task is killed by the user 
(e.g. manually marked as
+        success/failed from the Airflow UI). If "delete_pod" (default), the 
pod will be deleted.
+        If "keep_pod", the pod will not be deleted. In deferrable mode this is 
forwarded to the
+        trigger which controls cleanup when a deferred task is cancelled.
     :param termination_message_policy: The termination message policy of the 
base container.
         Default value is "File"
     :param active_deadline_seconds: The active_deadline_seconds which 
translates to active_deadline_seconds
@@ -348,6 +353,7 @@ class KubernetesPodOperator(BaseOperator):
         poll_interval: float = 2,
         log_pod_spec_on_failure: bool = True,
         on_finish_action: str = "delete_pod",
+        on_kill_action: str = "delete_pod",
         is_delete_operator_pod: None | bool = None,
         termination_message_policy: str = "File",
         active_deadline_seconds: int | None = None,
@@ -442,6 +448,7 @@ class KubernetesPodOperator(BaseOperator):
         self.remote_pod: k8s.V1Pod | None = None
         self.log_pod_spec_on_failure = log_pod_spec_on_failure
         self.on_finish_action = OnFinishAction(on_finish_action)
+        self.on_kill_action = OnKillAction(on_kill_action)
         # The `is_delete_operator_pod` parameter should have been removed in 
provider version 10.0.0.
         # TODO: remove it from here and from the operator's parameters list 
when the next major version bumped
         self._is_delete_operator_pod = self.on_finish_action == 
OnFinishAction.DELETE_POD
@@ -909,6 +916,8 @@ class KubernetesPodOperator(BaseOperator):
             schedule_timeout=self.schedule_timeout_seconds,
             base_container_name=self.base_container_name,
             on_finish_action=self.on_finish_action.value,
+            on_kill_action=self.on_kill_action.value,
+            termination_grace_period=self.termination_grace_period,
             last_log_time=last_log_time,
             logging_interval=self.logging_interval,
             trigger_kwargs=self.trigger_kwargs,
@@ -1282,6 +1291,11 @@ class KubernetesPodOperator(BaseOperator):
 
     def on_kill(self) -> None:
         self._killed = True
+        if self.on_kill_action == OnKillAction.KEEP_POD:
+            self.log.info(
+                "Skipping pod deletion since on_kill_action is set to %r.", 
self.on_kill_action.value
+            )
+            return
         if self.pod:
             pod = self.pod
             kwargs = {
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
index 6ad7b79b6f4..a1ec923185e 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -25,20 +25,32 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, cast
 
 import tenacity
+from asgiref.sync import sync_to_async
 
 from airflow.providers.cncf.kubernetes.exceptions import 
KubernetesApiPermissionError
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import 
AsyncKubernetesHook
 from airflow.providers.cncf.kubernetes.utils.pod_manager import (
     AsyncPodManager,
     OnFinishAction,
+    OnKillAction,
     PodLaunchTimeoutException,
     PodPhase,
 )
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from kubernetes_asyncio.client.models import V1Pod
     from pendulum import DateTime
+    from sqlalchemy.orm.session import Session
+
+if not AIRFLOW_V_3_0_PLUS:
+    from sqlalchemy import select
+
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
 
 
 class ContainerState(str, Enum):
@@ -75,6 +87,10 @@ class KubernetesPodTrigger(BaseTrigger):
     :param on_finish_action: What to do when the pod reaches its final state, 
or the execution is interrupted.
         If "delete_pod", the pod will be deleted regardless its state; if 
"delete_succeeded_pod",
         only succeeded pod will be deleted. You can set to "keep_pod" to keep 
the pod.
+    :param on_kill_action: What to do when the trigger is cancelled (e.g. when 
a deferred task is
+        manually marked as success/failed). If "delete_pod" (default), the pod 
will be deleted.
+        If "keep_pod", the pod will not be deleted.
+    :param termination_grace_period: Optional grace period in seconds for pod 
termination during cleanup.
     :param logging_interval: number of seconds to wait before kicking it back 
to
         the operator to print latest logs. If ``None`` will wait until 
container done.
     :param last_log_time: where to resume logs from
@@ -98,6 +114,8 @@ class KubernetesPodTrigger(BaseTrigger):
         startup_check_interval: float = 5,
         schedule_timeout: int = 120,
         on_finish_action: str = "delete_pod",
+        on_kill_action: str = "delete_pod",
+        termination_grace_period: int | None = None,
         last_log_time: DateTime | None = None,
         logging_interval: int | None = None,
         trigger_kwargs: dict | None = None,
@@ -120,7 +138,10 @@ class KubernetesPodTrigger(BaseTrigger):
         self.last_log_time = last_log_time
         self.logging_interval = logging_interval
         self.on_finish_action = OnFinishAction(on_finish_action)
+        self.on_kill_action = OnKillAction(on_kill_action)
+        self.termination_grace_period = termination_grace_period
         self.trigger_kwargs = trigger_kwargs or {}
+        self._fired_event = False
         self._since_time = None
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -143,6 +164,8 @@ class KubernetesPodTrigger(BaseTrigger):
                 "schedule_timeout": self.schedule_timeout,
                 "trigger_start_time": self.trigger_start_time,
                 "on_finish_action": self.on_finish_action.value,
+                "on_kill_action": self.on_kill_action.value,
+                "termination_grace_period": self.termination_grace_period,
                 "last_log_time": self.last_log_time,
                 "logging_interval": self.logging_interval,
                 "trigger_kwargs": self.trigger_kwargs,
@@ -181,10 +204,12 @@ class KubernetesPodTrigger(BaseTrigger):
                 )
             else:
                 event = await self._wait_for_container_completion()
+            self._fired_event = True
             yield event
             return
         except PodLaunchTimeoutException as e:
             message = self._format_exception_description(e)
+            self._fired_event = True
             yield TriggerEvent(
                 {
                     "name": self.pod_name,
@@ -201,6 +226,7 @@ class KubernetesPodTrigger(BaseTrigger):
                 "Please ensure the triggerer's service account is included in 
the 'pod-launcher-role' as defined in the latest Airflow Helm chart. "
                 f"Original error: {e}"
             )
+            self._fired_event = True
             yield TriggerEvent(
                 {
                     "name": self.pod_name,
@@ -217,6 +243,7 @@ class KubernetesPodTrigger(BaseTrigger):
                 self.pod_name,
                 self.pod_namespace,
             )
+            self._fired_event = True
             yield TriggerEvent(
                 {
                     "name": self.pod_name,
@@ -334,6 +361,100 @@ class KubernetesPodTrigger(BaseTrigger):
     def pod_manager(self) -> AsyncPodManager:
         return AsyncPodManager(async_hook=self.hook)
 
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for this trigger from the database 
(Airflow 2.x only)."""
+            task_instance = session.scalar(
+                select(TaskInstance).where(
+                    TaskInstance.dag_id == self.task_instance.dag_id,
+                    TaskInstance.task_id == self.task_instance.task_id,
+                    TaskInstance.run_id == self.task_instance.run_id,
+                    TaskInstance.map_index == self.task_instance.map_index,
+                )
+            )
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance."""
+        if AIRFLOW_V_3_0_PLUS:
+            from airflow.sdk.execution_time.task_runner import 
RuntimeTaskInstance
+
+            task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.task_instance.dag_id,
+                task_ids=[self.task_instance.task_id],
+                run_ids=[self.task_instance.run_id],
+                map_index=self.task_instance.map_index,
+            )
+            try:
+                return 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+            except KeyError:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+        else:
+            task_instance = await sync_to_async(self.get_task_instance)()  # 
type: ignore[call-arg]
+            return task_instance.state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed 
by this trigger.
+
+        Cancel is NOT safe when the task is still in DEFERRED state, because 
it means the
+        triggerer is redistributing triggers and the trigger will be recreated 
on another triggerer.
+        Cancel IS safe when the task state has changed (e.g. user marked it as 
success/failed).
+        """
+        task_state = await self.get_task_state()
+        return task_state != TaskInstanceState.DEFERRED
+
+    async def cleanup(self) -> None:
+        """Clean up the pod when the trigger is cancelled."""
+        if self._fired_event:
+            self.log.debug("Skipping cleanup since an event has already been 
fired.")
+            return
+
+        if self.on_kill_action == OnKillAction.KEEP_POD:
+            self.log.debug("Skipping cleanup since on_kill_action is set to 
%r.", self.on_kill_action.value)
+            return
+
+        try:
+            safe = await self.safe_to_cancel()
+        except Exception:
+            self.log.warning(
+                "Could not determine task state during cleanup; skipping pod 
deletion to be safe.",
+                exc_info=True,
+            )
+            return
+
+        if not safe:
+            self.log.debug(
+                "Skipping cleanup since the task is still in deferred state 
(likely a triggerer restart)."
+            )
+            return
+
+        self.log.info("Deleting pod %s in namespace %s.", self.pod_name, 
self.pod_namespace)
+        try:
+            await self.hook.delete_pod(
+                name=self.pod_name,
+                namespace=self.pod_namespace,
+                grace_period_seconds=self.termination_grace_period,
+            )
+        except Exception:
+            self.log.exception("Unexpected error while deleting pod %s", 
self.pod_name)
+
     def define_container_state(self, pod: V1Pod) -> ContainerState:
         if pod.status is None or pod.status.container_statuses is None:
             return ContainerState.UNDEFINED
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 78d95d38d0b..2c2405b7fd4 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -1030,6 +1030,13 @@ class OnFinishAction(str, enum.Enum):
     DELETE_SUCCEEDED_POD = "delete_succeeded_pod"
 
 
+class OnKillAction(str, enum.Enum):
+    """Action to take when the task is killed by the user."""
+
+    DELETE_POD = "delete_pod"
+    KEEP_POD = "keep_pod"
+
+
 def is_log_group_marker(line: str) -> bool:
     """Check if the line is a log group marker like `::group::` or 
`::endgroup::`."""
     return line.startswith("::group::") or line.startswith("::endgroup::")
diff --git 
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py 
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
index 98fb9a79fb8..37c163bcbba 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
@@ -32,6 +32,7 @@ from pendulum import DateTime
 from airflow.providers.cncf.kubernetes.triggers.pod import ContainerState, 
KubernetesPodTrigger
 from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
 from airflow.triggers.base import TriggerEvent
+from airflow.utils.state import TaskInstanceState
 
 TRIGGER_PATH = 
"airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
 HOOK_PATH = 
"airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook"
@@ -125,6 +126,8 @@ class TestKubernetesPodTrigger:
             "schedule_timeout": STARTUP_TIMEOUT_SECS,
             "trigger_start_time": TRIGGER_START_TIME,
             "on_finish_action": ON_FINISH_ACTION,
+            "on_kill_action": "delete_pod",
+            "termination_grace_period": None,
             "last_log_time": None,
             "logging_interval": None,
             "trigger_kwargs": {},
@@ -555,3 +558,140 @@ class TestKubernetesPodTrigger:
         with context:
             await trigger._get_pod()
         assert mock_hook.get_pod.call_count == call_count
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.hook")
+    async def test_cleanup_does_not_delete_when_fired_event(self, mock_hook):
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+            on_kill_action="delete_pod",
+            on_finish_action="delete_pod",
+        )
+        trigger._fired_event = True
+        await trigger.cleanup()
+        mock_hook.delete_pod.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.hook")
+    async def test_cleanup_does_not_delete_when_on_kill_action_keep_pod(self, 
mock_hook):
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+            on_kill_action="keep_pod",
+            on_finish_action="delete_pod",
+        )
+        await trigger.cleanup()
+        mock_hook.delete_pod.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock, 
return_value=False)
+    @mock.patch(f"{TRIGGER_PATH}.hook")
+    async def test_cleanup_does_not_delete_during_triggerer_restart(self, 
mock_hook, mock_safe):
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+            on_kill_action="delete_pod",
+            on_finish_action="delete_pod",
+        )
+        await trigger.cleanup()
+        mock_hook.delete_pod.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock, 
return_value=True)
+    @mock.patch(f"{TRIGGER_PATH}.hook")
+    async def test_cleanup_deletes_pod_on_manual_mark(self, mock_hook, 
mock_safe):
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+            on_kill_action="delete_pod",
+            on_finish_action="delete_pod",
+        )
+        mock_hook.delete_pod = mock.AsyncMock()
+        await trigger.cleanup()
+        mock_hook.delete_pod.assert_called_once_with(
+            name=POD_NAME,
+            namespace=NAMESPACE,
+            grace_period_seconds=None,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock, 
return_value=True)
+    @mock.patch(f"{TRIGGER_PATH}.hook")
+    async def 
test_cleanup_deletes_pod_even_when_on_finish_action_keep_pod(self, mock_hook, 
mock_safe):
+        """on_finish_action is not consulted during kill -- on_kill_action is 
the sole control."""
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+            on_kill_action="delete_pod",
+            on_finish_action="keep_pod",
+        )
+        mock_hook.delete_pod = mock.AsyncMock()
+        await trigger.cleanup()
+        mock_hook.delete_pod.assert_called_once_with(
+            name=POD_NAME,
+            namespace=NAMESPACE,
+            grace_period_seconds=None,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.get_task_state", new_callable=mock.AsyncMock)
+    async def test_safe_to_cancel_returns_true_when_task_not_deferred(self, 
mock_get_state):
+        """safe_to_cancel should return True when the task is no longer 
DEFERRED (e.g. user marked success)."""
+        mock_get_state.return_value = TaskInstanceState.SUCCESS
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+        )
+        assert await trigger.safe_to_cancel() is True
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_PATH}.get_task_state", new_callable=mock.AsyncMock)
+    async def test_safe_to_cancel_returns_false_when_task_still_deferred(self, 
mock_get_state):
+        """safe_to_cancel should return False when the task is still DEFERRED 
(triggerer restart)."""
+        mock_get_state.return_value = TaskInstanceState.DEFERRED
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+        )
+        assert await trigger.safe_to_cancel() is False
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock, 
side_effect=Exception("API down")
+    )
+    @mock.patch(f"{TRIGGER_PATH}.hook")
+    async def test_cleanup_skips_deletion_when_safe_to_cancel_raises(self, 
mock_hook, mock_safe):
+        """When safe_to_cancel() raises, cleanup should skip pod deletion 
(fail-safe)."""
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+            on_kill_action="delete_pod",
+            on_finish_action="delete_pod",
+        )
+        await trigger.cleanup()
+        mock_hook.delete_pod.assert_not_called()

Reply via email to