This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 97871a0378 Fix SparkKubernetesOperator when using initContainers (#38119)
97871a0378 is described below
commit 97871a0378be7b89e8a4aef6ede31c9a884413e8
Author: Roman Sheludko <[email protected]>
AuthorDate: Wed May 1 08:05:38 2024 +0200
Fix SparkKubernetesOperator when using initContainers (#38119)
* Fix SparkKubernetesOperator when using initContainers
* add tests
* lint
---
.../kubernetes/operators/custom_object_launcher.py | 2 +-
.../operators/test_custom_object_launcher.py | 70 +++++++++++++++++++++-
2 files changed, 70 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py b/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
index 77d99a0fba..8e2edc2606 100644
--- a/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
+++ b/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
@@ -344,7 +344,7 @@ class CustomObjectLauncher(LoggingMixin):
waiting_message = waiting_status.message
except Exception:
return
- if waiting_reason != "ContainerCreating":
+ if waiting_reason not in ("ContainerCreating", "PodInitializing"):
raise AirflowException(f"Spark Job Failed. Status:
{waiting_reason}, Error: {waiting_message}")
def delete_spark_job(self, spark_job_name=None):
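Context for the one-line change above (not part of the commit itself): while a pod's initContainers run, Kubernetes reports the main containers' waiting reason as "PodInitializing" rather than "ContainerCreating", so the old check raised AirflowException for a job that was still starting normally. Below is a minimal illustrative sketch of the relaxed condition against a kubernetes.client pod status; the standalone helper and the container name are hypothetical, only the two accepted waiting reasons come from the diff above.

from kubernetes.client import (
    V1ContainerState,
    V1ContainerStateWaiting,
    V1ContainerStatus,
    V1PodStatus,
)

# Waiting reasons that mean "still starting", per the change above.
STARTUP_WAITING_REASONS = ("ContainerCreating", "PodInitializing")


def is_still_starting(status: V1PodStatus) -> bool:
    # Hypothetical standalone helper; the real check lives in
    # CustomObjectLauncher.check_pod_start_failure.
    waiting = status.container_statuses[0].state.waiting
    return waiting is not None and waiting.reason in STARTUP_WAITING_REASONS


# A driver pod blocked on its initContainers reports "PodInitializing";
# before this fix that reason was treated as a startup failure.
status = V1PodStatus(
    container_statuses=[
        V1ContainerStatus(
            image="test",
            image_id="test",
            name="spark-kubernetes-driver",
            ready=False,
            restart_count=0,
            state=V1ContainerState(waiting=V1ContainerStateWaiting(reason="PodInitializing")),
        )
    ]
)
assert is_still_starting(status)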
diff --git a/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py b/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
index 3a57fdefdb..244fcf6fd2 100644
--- a/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
+++ b/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
@@ -16,17 +16,48 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
+from kubernetes.client import (
+ V1ContainerState,
+ V1ContainerStateWaiting,
+ V1ContainerStatus,
+ V1Pod,
+ V1PodStatus,
+)
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import (
+ CustomObjectLauncher,
SparkJobSpec,
SparkResources,
)
+@pytest.fixture
+def mock_launcher():
+ launcher = CustomObjectLauncher(
+ name="test-spark-job",
+ namespace="default",
+ kube_client=MagicMock(),
+ custom_obj_api=MagicMock(),
+ template_body={
+ "spark": {
+ "spec": {
+ "image": "gcr.io/spark-operator/spark-py:v3.0.0",
+ "driver": {},
+ "executor": {},
+ },
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "kind": "SparkApplication",
+ },
+ },
+ )
+ launcher.pod_spec = V1Pod()
+ return launcher
+
+
class TestSparkJobSpec:
@patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.SparkJobSpec.update_resources")
@patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.SparkJobSpec.validate")
@@ -150,3 +181,40 @@ class TestSparkResources:
assert spark_resources.executor["cpu"]["limit"] == "4"
assert spark_resources.driver["gpu"]["quantity"] == 1
assert spark_resources.executor["gpu"]["quantity"] == 2
+
+
+class TestCustomObjectLauncher:
+ def get_pod_status(self, reason: str, message: str | None = None):
+ return V1PodStatus(
+ container_statuses=[
+ V1ContainerStatus(
+ image="test",
+ image_id="test",
+ name="test",
+ ready=False,
+ restart_count=0,
+ state=V1ContainerState(
+ waiting=V1ContainerStateWaiting(
+ reason=reason,
+ message=message,
+ ),
+ ),
+ ),
+ ]
+ )
+
+ @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager")
+ def test_check_pod_start_failure_no_error(self, mock_pod_manager, mock_launcher):
+ mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("ContainerCreating")
+ mock_launcher.check_pod_start_failure()
+
+ mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("PodInitializing")
+ mock_launcher.check_pod_start_failure()
+
+ @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager")
+ def test_check_pod_start_failure_with_error(self, mock_pod_manager, mock_launcher):
+ mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status(
+ "CrashLoopBackOff", "Error message"
+ )
+ with pytest.raises(AirflowException):
+ mock_launcher.check_pod_start_failure()
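For anyone picking this change up locally, one way to run just the new tests is a pytest keyword filter (a sketch; assumes a working Airflow development environment with the cncf.kubernetes provider test dependencies installed):

    pytest tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py -k "check_pod_start_failure"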