This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3d15dcfeffd Bugfix/fix spark k8s pod duplicate issue (#61110)
3d15dcfeffd is described below
commit 3d15dcfeffd7b38b0c52028967f3593d7f6222e2
Author: Nataneljpwd <[email protected]>
AuthorDate: Thu Feb 5 21:00:20 2026 +0000
Bugfix/fix spark k8s pod duplicate issue (#61110)
* fix an edge case where, if a pod was pending in the SparkKubernetes operator,
  the task won't fail and will recover
* formatting
* fixed tests and removed redundant tests
* fixed pod status phase
* address comment
* added another test for the success case
* resolve CR comments
* fixed mypy issue
* address cr comments
* fix last test
* remove deletion timestamp
---------
Co-authored-by: Natanel Rudyuklakir <[email protected]>
---
.../cncf/kubernetes/operators/spark_kubernetes.py | 23 ++-
.../providers/cncf/kubernetes/utils/pod_manager.py | 1 +
.../kubernetes/operators/test_spark_kubernetes.py | 174 ++++++++++++++++++---
3 files changed, 170 insertions(+), 28 deletions(-)
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index a01f05ae82d..bb75f797525 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -248,23 +248,32 @@ class SparkKubernetesOperator(KubernetesPodOperator):
             self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
             + ",spark-role=driver"
         )
-        pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
+        # since we did not specify a resource version, we make sure to get the latest data
+        # we make sure we get only running or pending pods.
+        field_selector = self._get_field_selector()
+        pod_list = self.client.list_namespaced_pod(
+            self.namespace, label_selector=label_selector, field_selector=field_selector
+        ).items
 
         pod = None
         if len(pod_list) > 1:
-            # When multiple pods match the same labels, select one deterministically,
-            # preferring a Running pod, then creation time, with name as a tie-breaker.
+            # When multiple pods match the same labels, select one deterministically.
+            # Prefer Succeeded, then Running (excluding terminating), then Pending.
+            # Terminating pods can be identified via deletion_timestamp.
+            # Pending pods are included to handle recent driver restarts without failing the task.
             pod = max(
                 pod_list,
                 key=lambda p: (
-                    p.status.phase == PodPhase.RUNNING,
+                    p.metadata.deletion_timestamp is None,  # not a terminating pod in pending
+                    p.status.phase == PodPhase.SUCCEEDED,  # if the job succeeded while the worker was down
+                    p.status.phase == PodPhase.PENDING,
                     p.metadata.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
                     p.metadata.name or "",
                 ),
             )
             self.log.warning(
                 "Found %d Spark driver pods matching labels %s; "
-                "selecting pod %s for reattachment based on status and creation time.",
+                "selecting pod %s for reattachment based on status.",
                 len(pod_list),
                 label_selector,
                 pod.metadata.name,
@@ -279,6 +288,10 @@ class SparkKubernetesOperator(KubernetesPodOperator):
             self.log.info("`try_number` of pod: %s", pod.metadata.labels.get("try_number", "unknown"))
         return pod
 
+    def _get_field_selector(self) -> str:
+        # exclude terminal failure states, to get only running, pending and succeeded states.
+        return f"status.phase!={PodPhase.FAILED},status.phase!={PodPhase.UNKNOWN}"
+
     def process_pod_deletion(self, pod, *, reraise=True):
         if pod is not None:
             if self.delete_on_termination:
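
The ranking in the hunk above reads more easily outside diff form. Below is a minimal, self-contained Python sketch of the same max() key ordering; the FakeMeta/FakeStatus/FakePod dataclasses are illustrative stand-ins, not Airflow or Kubernetes client types.

    from dataclasses import dataclass
    from datetime import datetime, timezone
    from typing import Optional

    @dataclass
    class FakeMeta:
        name: str
        creation_timestamp: datetime
        deletion_timestamp: Optional[datetime] = None

    @dataclass
    class FakeStatus:
        phase: str

    @dataclass
    class FakePod:
        metadata: FakeMeta
        status: FakeStatus

    ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
    pods = [
        FakePod(FakeMeta("terminating", ts, deletion_timestamp=ts), FakeStatus("Running")),
        FakePod(FakeMeta("pending", ts), FakeStatus("Pending")),
        FakePod(FakeMeta("succeeded", ts), FakeStatus("Succeeded")),
    ]
    best = max(
        pods,
        key=lambda p: (
            p.metadata.deletion_timestamp is None,  # non-terminating pods rank first
            p.status.phase == "Succeeded",          # then a driver that already finished
            p.status.phase == "Pending",            # then a freshly restarted (Pending) driver
            p.metadata.creation_timestamp,          # then the newest pod
            p.metadata.name or "",                  # name breaks any remaining tie
        ),
    )
    assert best.metadata.name == "succeeded"

Because booleans sort False < True, each earlier criterion dominates every later one, which keeps the selection deterministic when several driver pods share identical labels after a scheduler or worker crash.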
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index cd25f06a76e..01f8f4a3fa2 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -90,6 +90,7 @@ class PodPhase:
     RUNNING = "Running"
     FAILED = "Failed"
     SUCCEEDED = "Succeeded"
+    UNKNOWN = "Unknown"
 
     terminal_states = {FAILED, SUCCEEDED}
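
The new UNKNOWN phase exists so that _get_field_selector() can exclude it server-side. A hedged sketch of how such a field selector is used with the official kubernetes Python client follows; CoreV1Api.list_namespaced_pod does accept label_selector and field_selector keyword arguments, but the kubeconfig loading, namespace, and label values here are placeholder assumptions rather than values taken from this commit.

    from kubernetes import client, config

    config.load_kube_config()  # or config.load_incluster_config() when running inside a pod
    v1 = client.CoreV1Api()

    # Mirrors _get_field_selector(): drop Failed/Unknown pods at the API server,
    # so only Pending, Running, and Succeeded drivers come back.
    field_selector = "status.phase!=Failed,status.phase!=Unknown"
    label_selector = "spark-role=driver"  # placeholder label, for illustration only

    pods = v1.list_namespaced_pod(
        "default",
        label_selector=label_selector,
        field_selector=field_selector,
    ).items
    for pod in pods:
        print(pod.metadata.name, pod.status.phase, pod.metadata.deletion_timestamp)

Filtering at the API server keeps Failed and Unknown drivers out of the candidate list entirely, so the reattachment logic never has to rank them.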
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
index 1e7f5c1da23..31c38811385 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -863,7 +863,11 @@ class TestSparkKubernetesOperator:
         op.execute(context)
         label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
         op.find_spark_job(context)
-        mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
+        mock_get_kube_client.list_namespaced_pod.assert_called_with(
+            "default",
+            label_selector=label_selector,
+            field_selector=op._get_field_selector(),
+        )
 
     @patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook")
     def test_adds_task_context_labels_to_driver_and_executor(
@@ -941,7 +945,11 @@ class TestSparkKubernetesOperator:
         op.execute(context)
 
         label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
-        mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
+        mock_get_kube_client.list_namespaced_pod.assert_called_with(
+            "default",
+            label_selector=label_selector,
+            field_selector=op._get_field_selector(),
+        )
         mock_create_namespaced_crd.assert_not_called()
@@ -983,21 +991,140 @@ class TestSparkKubernetesOperator:
         running_pod.metadata.labels = {"try_number": "1"}
         running_pod.status.phase = "Running"
 
-        # Pending pod should not be selected.
+        # Terminating pod should not be selected.
+        terminating_pod = mock.MagicMock()
+        terminating_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        terminating_pod.metadata.deletion_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        terminating_pod.metadata.name = "spark-driver-pending"
+        terminating_pod.metadata.labels = {"try_number": "1"}
+        terminating_pod.status.phase = "Running"
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            running_pod,
+            terminating_pod,
+        ]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is running_pod
+
+    def test_find_spark_job_picks_pending_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job picks a Pending Spark driver pod over a Terminating one.
+        """
+
+        task_name = "test_find_spark_job_prefers_pending_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Pending pod should be selected.
         pending_pod = mock.MagicMock()
         pending_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
-        pending_pod.metadata.name = "spark-driver-pending"
+        pending_pod.metadata.name = "spark-driver"
         pending_pod.metadata.labels = {"try_number": "1"}
         pending_pod.status.phase = "Pending"
 
+        # Terminating pod should not be selected.
+        terminating_pod = mock.MagicMock()
+        terminating_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        terminating_pod.metadata.deletion_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        terminating_pod.metadata.name = "spark-driver"
+        terminating_pod.metadata.labels = {"try_number": "1"}
+        terminating_pod.status.phase = "Running"
+
         mock_get_kube_client.list_namespaced_pod.return_value.items = [
-            running_pod,
+            terminating_pod,  # comes first but should be ignored, as it is terminating
             pending_pod,
         ]
 
         returned_pod = op.find_spark_job(context)
 
-        assert returned_pod is running_pod
+        assert returned_pod is pending_pod
+
+    def test_find_spark_job_picks_succeeded(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job picks a Succeeded Spark driver pod over Running and Terminating pods.
+        """
+
+        task_name = "test_find_spark_job_prefers_succeeded_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Succeeded pod should be selected.
+        succeeded_pod = mock.MagicMock()
+        succeeded_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        succeeded_pod.metadata.name = "spark-driver"
+        succeeded_pod.metadata.labels = {"try_number": "1"}
+        succeeded_pod.status.phase = "Succeeded"
+
+        # Running pod should not be selected.
+        running_pod = mock.MagicMock()
+        running_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        running_pod.metadata.name = "spark-driver"
+        running_pod.metadata.labels = {"try_number": "1"}
+        running_pod.status.phase = "Running"
+
+        # Terminating pod should not be selected.
+        terminating_pod = mock.MagicMock()
+        terminating_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        terminating_pod.metadata.deletion_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        terminating_pod.metadata.name = "spark-driver"
+        terminating_pod.metadata.labels = {"try_number": "1"}
+        terminating_pod.status.phase = "Running"
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            terminating_pod,
+            running_pod,
+            succeeded_pod,
+        ]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is succeeded_pod
 
     def test_find_spark_job_picks_latest_pod(
         self,
@@ -1029,30 +1156,31 @@ class TestSparkKubernetesOperator:
             get_logs=True,
             reattach_on_restart=True,
         )
-        context = create_context(op)
-
-        # Older pod that should be ignored.
-        old_mock_pod = mock.MagicMock()
-        old_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
-        old_mock_pod.metadata.name = "spark-driver-old"
-        old_mock_pod.status.phase = PodPhase.RUNNING
+        context = create_context(op)
 
-        # Newer pod that should be picked up.
-        new_mock_pod = mock.MagicMock()
-        new_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
-        new_mock_pod.metadata.name = "spark-driver-new"
-        new_mock_pod.status.phase = PodPhase.RUNNING
+        # Latest pod should be selected.
+        new_pod = mock.MagicMock()
+        new_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 3, tzinfo=timezone.utc)
+        new_pod.metadata.name = "spark-driver"
+        new_pod.metadata.labels = {"try_number": "1"}
+        new_pod.status.phase = "Pending"
 
-        # Same try_number to simulate abrupt failure scenarios (e.g. scheduler crash)
-        # where cleanup did not occur and multiple pods share identical labels.
-        old_mock_pod.metadata.labels = {"try_number": "1"}
-        new_mock_pod.metadata.labels = {"try_number": "1"}
+        # Older pod should not be selected.
+        old_pod = mock.MagicMock()
+        old_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        old_pod.metadata.name = "spark-driver"
+        old_pod.metadata.labels = {"try_number": "1"}
+        old_pod.status.phase = "Running"
 
-        mock_get_kube_client.list_namespaced_pod.return_value.items = [old_mock_pod, new_mock_pod]
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            old_pod,
+            new_pod,
+        ]
 
         returned_pod = op.find_spark_job(context)
 
-        assert returned_pod is new_mock_pod
+        assert returned_pod is new_pod
 
     def test_find_spark_job_tiebreaks_by_name(
         self,