This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 97871a0378 Fix SparkKubernetesOperator when using initContainers (#38119)
97871a0378 is described below

commit 97871a0378be7b89e8a4aef6ede31c9a884413e8
Author: Roman Sheludko <[email protected]>
AuthorDate: Wed May 1 08:05:38 2024 +0200

    Fix SparkKubernetesOperator when using initContainers (#38119)
    
    * Fix SparkKubernetesOperator when using initContainers
    
    * add tests
    
    * lint
---
 .../kubernetes/operators/custom_object_launcher.py |  2 +-
 .../operators/test_custom_object_launcher.py       | 70 +++++++++++++++++++++-
 2 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py b/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
index 77d99a0fba..8e2edc2606 100644
--- a/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
+++ b/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
@@ -344,7 +344,7 @@ class CustomObjectLauncher(LoggingMixin):
             waiting_message = waiting_status.message
         except Exception:
             return
-        if waiting_reason != "ContainerCreating":
+        if waiting_reason not in ("ContainerCreating", "PodInitializing"):
             raise AirflowException(f"Spark Job Failed. Status: {waiting_reason}, Error: {waiting_message}")
 
     def delete_spark_job(self, spark_job_name=None):
diff --git a/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py b/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
index 3a57fdefdb..244fcf6fd2 100644
--- a/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
+++ b/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
@@ -16,17 +16,48 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
+from kubernetes.client import (
+    V1ContainerState,
+    V1ContainerStateWaiting,
+    V1ContainerStatus,
+    V1Pod,
+    V1PodStatus,
+)
 
 from airflow.exceptions import AirflowException
 from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import (
+    CustomObjectLauncher,
     SparkJobSpec,
     SparkResources,
 )
 
 
[email protected]
+def mock_launcher():
+    launcher = CustomObjectLauncher(
+        name="test-spark-job",
+        namespace="default",
+        kube_client=MagicMock(),
+        custom_obj_api=MagicMock(),
+        template_body={
+            "spark": {
+                "spec": {
+                    "image": "gcr.io/spark-operator/spark-py:v3.0.0",
+                    "driver": {},
+                    "executor": {},
+                },
+                "apiVersion": "sparkoperator.k8s.io/v1beta2",
+                "kind": "SparkApplication",
+            },
+        },
+    )
+    launcher.pod_spec = V1Pod()
+    return launcher
+
+
 class TestSparkJobSpec:
     
@patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.SparkJobSpec.update_resources")
     
@patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.SparkJobSpec.validate")
@@ -150,3 +181,40 @@ class TestSparkResources:
         assert spark_resources.executor["cpu"]["limit"] == "4"
         assert spark_resources.driver["gpu"]["quantity"] == 1
         assert spark_resources.executor["gpu"]["quantity"] == 2
+
+
+class TestCustomObjectLauncher:
+    def get_pod_status(self, reason: str, message: str | None = None):
+        return V1PodStatus(
+            container_statuses=[
+                V1ContainerStatus(
+                    image="test",
+                    image_id="test",
+                    name="test",
+                    ready=False,
+                    restart_count=0,
+                    state=V1ContainerState(
+                        waiting=V1ContainerStateWaiting(
+                            reason=reason,
+                            message=message,
+                        ),
+                    ),
+                ),
+            ]
+        )
+
+    
@patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager")
+    def test_check_pod_start_failure_no_error(self, mock_pod_manager, mock_launcher):
+        mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("ContainerCreating")
+        mock_launcher.check_pod_start_failure()
+
+        mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("PodInitializing")
+        mock_launcher.check_pod_start_failure()
+
+    
@patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager")
+    def test_check_pod_start_failure_with_error(self, mock_pod_manager, 
mock_launcher):
+        mock_pod_manager.return_value.read_pod.return_value.status = 
self.get_pod_status(
+            "CrashLoopBackOff", "Error message"
+        )
+        with pytest.raises(AirflowException):
+            mock_launcher.check_pod_start_failure()

Reply via email to