This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e57232ff8f Feature pass dictionary configuration in application_file in SparkKubernetesOperator (#35848)
e57232ff8f is described below
commit e57232ff8f5c312774a24d80c7ba0ad4e33cc204
Author: rom sharon <[email protected]>
AuthorDate: Sat Nov 25 14:05:55 2023 +0200
Feature pass dictionary configuration in application_file in SparkKubernetesOperator (#35848)
---
.../cncf/kubernetes/operators/spark_kubernetes.py | 14 +++++++---
.../kubernetes/operators/test_spark_kubernetes.py | 32 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 82288ef631..879f54e0dc 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -43,7 +43,7 @@ class SparkKubernetesOperator(BaseOperator):
https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication
:param application_file: Defines Kubernetes 'custom_resource_definition'
of 'sparkApplication' as either a
- path to a '.yaml' file, '.json' file, YAML string or JSON string.
+ path to a '.yaml' file, '.json' file, YAML string or python dictionary.
:param namespace: kubernetes namespace to put sparkApplication
:param kubernetes_conn_id: The :ref:`kubernetes connection id
<howto/connection:kubernetes>`
for the to Kubernetes cluster.
@@ -59,7 +59,7 @@ class SparkKubernetesOperator(BaseOperator):
def __init__(
self,
*,
- application_file: str,
+ application_file: str | dict,
namespace: str | None = None,
kubernetes_conn_id: str = "kubernetes_default",
api_group: str = "sparkoperator.k8s.io",
@@ -111,7 +111,10 @@ class SparkKubernetesOperator(BaseOperator):
raise
def execute(self, context: Context):
- body = _load_body_to_dict(self.application_file)
+ if isinstance(self.application_file, str):
+ body = _load_body_to_dict(self.application_file)
+ else:
+ body = self.application_file
name = body["metadata"]["name"]
namespace = self.namespace or self.hook.get_namespace()
@@ -177,7 +180,10 @@ class SparkKubernetesOperator(BaseOperator):
return response
def on_kill(self) -> None:
- body = _load_body_to_dict(self.application_file)
+ if isinstance(self.application_file, str):
+ body = _load_body_to_dict(self.application_file)
+ else:
+ body = self.application_file
name = body["metadata"]["name"]
namespace = self.namespace or self.hook.get_namespace()
self.hook.delete_custom_object(
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7e27940642..76369bfd50 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -171,3 +171,35 @@ def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict):
namespace="default",
name="spark-app",
)
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_execute_with_application_file_dict(mock_kubernetes_hook):
+    op = SparkKubernetesOperator(task_id="task_id", application_file={"metadata": {"name": "spark-app"}})
+ mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+ op.execute({})
+
+    mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
+ group="sparkoperator.k8s.io",
+ version="v1beta2",
+ plural="sparkapplications",
+ body={"metadata": {"name": "spark-app"}},
+ namespace="default",
+ )
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_on_kill_with_application_file_dict(mock_kubernetes_hook):
+    op = SparkKubernetesOperator(task_id="task_id", application_file={"metadata": {"name": "spark-app"}})
+ mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+ op.on_kill()
+
+    mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with(
+ group="sparkoperator.k8s.io",
+ version="v1beta2",
+ plural="sparkapplications",
+ name="spark-app",
+ namespace="default",
+ )