This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e57232ff8f Feature pass dictionary configuration in application_file in SparkKubernetesOperator (#35848)
e57232ff8f is described below

commit e57232ff8f5c312774a24d80c7ba0ad4e33cc204
Author: rom sharon <[email protected]>
AuthorDate: Sat Nov 25 14:05:55 2023 +0200

    Feature pass dictionary configuration in application_file in SparkKubernetesOperator (#35848)
---
 .../cncf/kubernetes/operators/spark_kubernetes.py  | 14 +++++++---
 .../kubernetes/operators/test_spark_kubernetes.py  | 32 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 82288ef631..879f54e0dc 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -43,7 +43,7 @@ class SparkKubernetesOperator(BaseOperator):
         https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication
 
     :param application_file: Defines Kubernetes 'custom_resource_definition' of 'sparkApplication' as either a
-        path to a '.yaml' file, '.json' file, YAML string or JSON string.
+        path to a '.yaml' file, '.json' file, YAML string or python dictionary.
     :param namespace: kubernetes namespace to put sparkApplication
     :param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
         for the to Kubernetes cluster.
@@ -59,7 +59,7 @@ class SparkKubernetesOperator(BaseOperator):
     def __init__(
         self,
         *,
-        application_file: str,
+        application_file: str | dict,
         namespace: str | None = None,
         kubernetes_conn_id: str = "kubernetes_default",
         api_group: str = "sparkoperator.k8s.io",
@@ -111,7 +111,10 @@ class SparkKubernetesOperator(BaseOperator):
                 raise
 
     def execute(self, context: Context):
-        body = _load_body_to_dict(self.application_file)
+        if isinstance(self.application_file, str):
+            body = _load_body_to_dict(self.application_file)
+        else:
+            body = self.application_file
         name = body["metadata"]["name"]
         namespace = self.namespace or self.hook.get_namespace()
 
@@ -177,7 +180,10 @@ class SparkKubernetesOperator(BaseOperator):
         return response
 
     def on_kill(self) -> None:
-        body = _load_body_to_dict(self.application_file)
+        if isinstance(self.application_file, str):
+            body = _load_body_to_dict(self.application_file)
+        else:
+            body = self.application_file
         name = body["metadata"]["name"]
         namespace = self.namespace or self.hook.get_namespace()
         self.hook.delete_custom_object(
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7e27940642..76369bfd50 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -171,3 +171,35 @@ def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict):
         namespace="default",
         name="spark-app",
     )
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_execute_with_application_file_dict(mock_kubernetes_hook):
+    op = SparkKubernetesOperator(task_id="task_id", application_file={"metadata": {"name": "spark-app"}})
+    mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+    op.execute({})
+
+    mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
+        group="sparkoperator.k8s.io",
+        version="v1beta2",
+        plural="sparkapplications",
+        body={"metadata": {"name": "spark-app"}},
+        namespace="default",
+    )
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_on_kill_with_application_file_dict(mock_kubernetes_hook):
+    op = SparkKubernetesOperator(task_id="task_id", application_file={"metadata": {"name": "spark-app"}})
+    mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+    op.on_kill()
+
+    mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with(
+        group="sparkoperator.k8s.io",
+        version="v1beta2",
+        plural="sparkapplications",
+        name="spark-app",
+        namespace="default",
+    )

Reply via email to