This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a73c6268a3b Fix KubernetesPodTrigger.get_task_state KeyError on mapped
TIs (#67296) (#67297)
a73c6268a3b is described below
commit a73c6268a3b2b718cb3363051a80e5e4e9511072
Author: Paul Mathew <[email protected]>
AuthorDate: Thu May 21 17:21:31 2026 -0400
Fix KubernetesPodTrigger.get_task_state KeyError on mapped TIs (#67296)
(#67297)
The execution API's /states endpoint encodes the response key as
``f"{task_id}_{map_index}"`` for mapped TIs but the trigger was looking
the value up by plain ``task_id``. For any mapped deferrable
KubernetesPodOperator task that lookup raised KeyError, which
cleanup()'s broad ``except Exception`` swallowed, thereby skipping
``hook.delete_pod()`` -- so Mark Failed in the UI left the pod running
until ``active_deadline_seconds`` expired.
Compose the lookup key with the ``_{map_index}`` suffix when the TI is
mapped, matching how the API serialises the response. cleanup() now
sees the real state, ``safe_to_cancel()`` returns the right value, and
mark-failed actually deletes the pod within the grace period.
Co-authored-by: Cursor <[email protected]>
---
.../providers/cncf/kubernetes/triggers/pod.py | 10 ++-
.../unit/cncf/kubernetes/triggers/test_pod.py | 86 +++++++++++++++++++++-
2 files changed, 94 insertions(+), 2 deletions(-)
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
index e801a2f0566..65b1b45bb04 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -397,8 +397,16 @@ class KubernetesPodTrigger(BaseTrigger):
run_ids=[self.task_instance.run_id],
map_index=self.task_instance.map_index,
)
+ # The /states endpoint suffixes the response key with
``_{map_index}`` for mapped TIs
+ # (see ``get_task_instance_states`` in airflow-core's
execution_api routes); non-mapped
+ # TIs keep the plain ``task_id``.
+ ti_key = (
+ f"{self.task_instance.task_id}_{self.task_instance.map_index}"
+ if self.task_instance.map_index >= 0
+ else self.task_instance.task_id
+ )
try:
- return
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ return task_states_response[self.task_instance.run_id][ti_key]
except KeyError:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
index 26b2db90e16..765a3f35e3d 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
@@ -34,7 +34,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager
import PodPhase
from airflow.triggers.base import TriggerEvent
from airflow.utils.state import TaskInstanceState
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_3_PLUS
TRIGGER_PATH =
"airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
HOOK_PATH =
"airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook"
@@ -827,6 +827,90 @@ class TestKubernetesPodTrigger:
)
assert await trigger.safe_to_cancel() is False
+ @pytest.mark.skipif(
+ not AIRFLOW_V_3_0_PLUS,
+ reason="get_task_state uses RuntimeTaskInstance.get_task_states on
Airflow 3.0+",
+ )
+ @pytest.mark.asyncio
+
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
+ async def test_get_task_state_uses_task_id_for_non_mapped_ti(self,
mock_get_task_states):
+ # Non-mapped TIs (``map_index < 0``) are keyed by plain ``task_id`` in
the
+ # response, matching the dict-key construction in the execution API's
+ # ``get_task_instance_states`` handler.
+ run_id = "manual__2026-05-21T00:00:00+00:00"
+ mock_get_task_states.return_value = {run_id: {"my_task":
TaskInstanceState.SUCCESS}}
+
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ )
+ trigger.task_instance = MagicMock(dag_id="my_dag", task_id="my_task",
run_id=run_id, map_index=-1)
+
+ assert await trigger.get_task_state() == TaskInstanceState.SUCCESS
+
+ @pytest.mark.skipif(
+ not AIRFLOW_V_3_0_PLUS,
+ reason="get_task_state uses RuntimeTaskInstance.get_task_states on
Airflow 3.0+",
+ )
+ @pytest.mark.asyncio
+
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
+ async def test_get_task_state_uses_composite_key_for_mapped_ti(self,
mock_get_task_states):
+ # Regression guard for #67296: mapped TIs (``map_index >= 0``) are
+ # keyed by ``f"{task_id}_{map_index}"`` in the response. Without the
+ # suffix this lookup would KeyError, which ``cleanup()`` would
+ # defensively swallow and skip ``hook.delete_pod()`` -- leaking the
+ # pod until ``active_deadline_seconds`` expires on user mark-failed.
+ run_id = "manual__2026-05-21T00:00:00+00:00"
+ mock_get_task_states.return_value = {run_id: {"map_group.task_a_2":
TaskInstanceState.FAILED}}
+
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ )
+ trigger.task_instance = MagicMock(
+ dag_id="my_dag", task_id="map_group.task_a", run_id=run_id,
map_index=2
+ )
+
+ assert await trigger.get_task_state() == TaskInstanceState.FAILED
+
+ @pytest.mark.skipif(
+ not AIRFLOW_V_3_0_PLUS,
+ reason="get_task_state uses RuntimeTaskInstance.get_task_states on
Airflow 3.0+",
+ )
+ @pytest.mark.asyncio
+
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
+ async def test_get_task_state_raises_when_mapped_key_missing(self,
mock_get_task_states):
+ # The wrapped ``AirflowException`` shape is preserved when the
+ # response is missing the expected (composite) key, so callers
+ # like ``safe_to_cancel`` keep the same behaviour they had before
+ # the lookup was fixed.
+ from airflow.exceptions import AirflowException
+
+ run_id = "manual__2026-05-21T00:00:00+00:00"
+ # Response has the run_id but not the (``map_group.task_a``, ``2``)
+ # entry -- e.g. supervisor has not observed the TI yet.
+ mock_get_task_states.return_value = {run_id: {"map_group.task_a_5":
"running"}}
+
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ )
+ trigger.task_instance = MagicMock(
+ dag_id="my_dag", task_id="map_group.task_a", run_id=run_id,
map_index=2
+ )
+
+ with pytest.raises(AirflowException, match="TaskInstance with dag_id"):
+ await trigger.get_task_state()
+
@pytest.mark.skipif(
AIRFLOW_V_3_3_PLUS,
reason="Legacy cleanup path runs only on Airflow < 3.3",