This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new aa25affec6 Add SparkKubernetesOperator crd implementation (#22253)
aa25affec6 is described below
commit aa25affec68fe1ddcaa162ecfbd4199156bb88d1
Author: Hamed <[email protected]>
AuthorDate: Fri Jan 12 13:26:30 2024 +0000
Add SparkKubernetesOperator crd implementation (#22253)
---------
Co-authored-by: Ephraim Anierobi <[email protected]>
Co-authored-by: Hussein Awala <[email protected]>
Co-authored-by: bolkedebruin <[email protected]>
Co-authored-by: Hamed Saljooghinejad <[email protected]>
---
.pre-commit-config.yaml | 1 +
.../kubernetes/operators/custom_object_launcher.py | 367 +++++++++++++
airflow/providers/cncf/kubernetes/operators/pod.py | 12 +-
.../cncf/kubernetes/operators/spark_kubernetes.py | 357 ++++++++-----
airflow/providers/cncf/kubernetes/provider.yaml | 1 +
.../cncf/kubernetes/resource_convert/__init__.py | 16 +
.../cncf/kubernetes/resource_convert/configmap.py | 52 ++
.../kubernetes/resource_convert/env_variable.py | 39 ++
.../cncf/kubernetes/resource_convert/secret.py | 40 ++
.../operators.rst | 304 +++++++++++
docs/spelling_wordlist.txt | 7 +
.../operators/spark_application_template.yaml | 147 ++++++
.../operators/spark_application_test.json | 58 +++
.../operators/spark_application_test.yaml | 56 ++
.../operators/test_custom_object_launcher.py | 16 +
.../kubernetes/operators/test_spark_kubernetes.py | 577 +++++++++++++++------
.../cncf/kubernetes/resource_convert/__init__.py | 16 +
.../kubernetes/resource_convert/test_configmap.py | 16 +
.../resource_convert/test_env_variable.py | 16 +
.../kubernetes/resource_convert/test_secret.py | 16 +
.../cncf/kubernetes/example_spark_kubernetes.py | 5 +-
.../cncf/kubernetes/spark_job_template.yaml | 149 ++++++
22 files changed, 1980 insertions(+), 288 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index da64ff457f..291c087b89 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -570,6 +570,7 @@ repos:
^docs/apache-airflow-providers-fab/auth-manager/webserver-authentication.rst$|
^docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst$|
^docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst$|
+ ^docs/apache-airflow-providers-cncf-kubernetes/operators.rst$|
^docs/conf.py$|
^docs/exts/removemarktransform.py$|
^scripts/ci/pre_commit/pre_commit_vendor_k8s_json_schema.py$|
diff --git
a/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
b/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
new file mode 100644
index 0000000000..cd0317c08c
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
@@ -0,0 +1,367 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Launches Custom object."""
+from __future__ import annotations
+
+import time
+from copy import deepcopy
+from datetime import datetime as dt
+from functools import cached_property
+
+import tenacity
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
+from kubernetes.client.rest import ApiException
+
+from airflow.exceptions import AirflowException
+from airflow.providers.cncf.kubernetes.resource_convert.configmap import (
+ convert_configmap,
+ convert_configmap_to_volume,
+)
+from airflow.providers.cncf.kubernetes.resource_convert.env_variable import
convert_env_vars
+from airflow.providers.cncf.kubernetes.resource_convert.secret import (
+ convert_image_pull_secrets,
+ convert_secret,
+)
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+def should_retry_start_spark_job(exception: BaseException) -> bool:
+ """Check if an Exception indicates a transient error and warrants
retrying."""
+ if isinstance(exception, ApiException):
+ return exception.status == 409
+ return False
+
+
+class SparkJobSpec:
+ """Spark job spec."""
+
+ def __init__(self, **entries):
+ self.__dict__.update(entries)
+ self.validate()
+ self.update_resources()
+
+ def validate(self):
+ if self.spec.get("dynamicAllocation", {}).get("enabled"):
+ if not all(
+ [
+ self.spec["dynamicAllocation"]["initialExecutors"],
+ self.spec["dynamicAllocation"]["minExecutors"],
+ self.spec["dynamicAllocation"]["maxExecutors"],
+ ]
+ ):
+ raise AirflowException("Make sure initial/min/max value for
dynamic allocation is passed")
+
+ def update_resources(self):
+ if self.spec["driver"].get("container_resources"):
+ spark_resources = SparkResources(
+ self.spec["driver"].pop("container_resources"),
+ self.spec["executor"].pop("container_resources"),
+ )
+ self.spec["driver"].update(spark_resources.resources["driver"])
+ self.spec["executor"].update(spark_resources.resources["executor"])
+
+
+class KubernetesSpec:
+ """Spark kubernetes spec."""
+
+ def __init__(self, **entries):
+ self.__dict__.update(entries)
+ self.set_attribute()
+
+ def set_attribute(self):
+ self.env_vars = convert_env_vars(self.env_vars) if self.env_vars else
[]
+ self.image_pull_secrets = (
+ convert_image_pull_secrets(self.image_pull_secrets) if
self.image_pull_secrets else []
+ )
+ if self.config_map_mounts:
+ vols, vols_mounts =
convert_configmap_to_volume(self.config_map_mounts)
+ self.volumes.extend(vols)
+ self.volume_mounts.extend(vols_mounts)
+ if self.from_env_config_map:
+ self.env_from.extend([convert_configmap(c_name) for c_name in
self.from_env_config_map])
+ if self.from_env_secret:
+ self.env_from.extend([convert_secret(c) for c in
self.from_env_secret])
+
+
+class SparkResources:
+ """spark resources."""
+
+ def __init__(
+ self,
+ driver: dict | None = None,
+ executor: dict | None = None,
+ ):
+ self.default = {
+ "gpu": {"name": None, "quantity": 0},
+ "cpu": {"request": None, "limit": None},
+ "memory": {"request": None, "limit": None},
+ }
+ self.driver = deepcopy(self.default)
+ self.executor = deepcopy(self.default)
+ if driver:
+ self.driver.update(driver)
+ if executor:
+ self.executor.update(executor)
+ self.convert_resources()
+
+ @property
+ def resources(self):
+ """Return job resources."""
+ return {"driver": self.driver_resources, "executor":
self.executor_resources}
+
+ @property
+ def driver_resources(self):
+ """Return resources to use."""
+ driver = {}
+ if self.driver["cpu"].get("request"):
+ driver["cores"] = self.driver["cpu"]["request"]
+ if self.driver["cpu"].get("limit"):
+ driver["coreLimit"] = self.driver["cpu"]["limit"]
+ if self.driver["memory"].get("limit"):
+ driver["memory"] = self.driver["memory"]["limit"]
+ if self.driver["gpu"].get("name") and
self.driver["gpu"].get("quantity"):
+ driver["gpu"] = {"name": self.driver["gpu"]["name"], "quantity":
self.driver["gpu"]["quantity"]}
+ return driver
+
+ @property
+ def executor_resources(self):
+ """Return resources to use."""
+ executor = {}
+ if self.executor["cpu"].get("request"):
+ executor["cores"] = self.executor["cpu"]["request"]
+ if self.executor["cpu"].get("limit"):
+ executor["coreLimit"] = self.executor["cpu"]["limit"]
+ if self.executor["memory"].get("limit"):
+ executor["memory"] = self.executor["memory"]["limit"]
+ if self.executor["gpu"].get("name") and
self.executor["gpu"].get("quantity"):
+ executor["gpu"] = {
+ "name": self.executor["gpu"]["name"],
+ "quantity": self.executor["gpu"]["quantity"],
+ }
+ return executor
+
+ def convert_resources(self):
+ if isinstance(self.driver["memory"].get("limit"), str):
+ if "G" in self.driver["memory"]["limit"] or "Gi" in
self.driver["memory"]["limit"]:
+ self.driver["memory"]["limit"] =
float(self.driver["memory"]["limit"].rstrip("Gi G")) * 1024
+ elif "m" in self.driver["memory"]["limit"]:
+ self.driver["memory"]["limit"] =
float(self.driver["memory"]["limit"].rstrip("m"))
+ # Adjusting the memory value as operator adds 40% to the given
value
+ self.driver["memory"]["limit"] =
str(int(self.driver["memory"]["limit"] / 1.4)) + "m"
+
+ if isinstance(self.executor["memory"].get("limit"), str):
+ if "G" in self.executor["memory"]["limit"] or "Gi" in
self.executor["memory"]["limit"]:
+ self.executor["memory"]["limit"] = (
+ float(self.executor["memory"]["limit"].rstrip("Gi G")) *
1024
+ )
+ elif "m" in self.executor["memory"]["limit"]:
+ self.executor["memory"]["limit"] =
float(self.executor["memory"]["limit"].rstrip("m"))
+ # Adjusting the memory value as operator adds 40% to the given
value
+ self.executor["memory"]["limit"] =
str(int(self.executor["memory"]["limit"] / 1.4)) + "m"
+
+ if self.driver["cpu"].get("request"):
+ self.driver["cpu"]["request"] =
int(float(self.driver["cpu"]["request"]))
+ if self.driver["cpu"].get("limit"):
+ self.driver["cpu"]["limit"] = str(self.driver["cpu"]["limit"])
+ if self.executor["cpu"].get("request"):
+ self.executor["cpu"]["request"] =
int(float(self.executor["cpu"]["request"]))
+ if self.executor["cpu"].get("limit"):
+ self.executor["cpu"]["limit"] = str(self.executor["cpu"]["limit"])
+
+ if self.driver["gpu"].get("quantity"):
+ self.driver["gpu"]["quantity"] =
int(float(self.driver["gpu"]["quantity"]))
+ if self.executor["gpu"].get("quantity"):
+ self.executor["gpu"]["quantity"] =
int(float(self.executor["gpu"]["quantity"]))
+
+
+class CustomObjectStatus:
+ """Status of the PODs."""
+
+ SUBMITTED = "SUBMITTED"
+ RUNNING = "RUNNING"
+ FAILED = "FAILED"
+ SUCCEEDED = "SUCCEEDED"
+
+
+class CustomObjectLauncher(LoggingMixin):
+ """Launches PODS."""
+
+ def __init__(
+ self,
+ name: str | None,
+ namespace: str | None,
+ kube_client: CoreV1Api,
+ custom_obj_api: CustomObjectsApi,
+ template_body: str | None = None,
+ ):
+ """
+ Creates custom object launcher(sparkapplications crd).
+
+ :param kube_client: kubernetes client.
+ """
+ super().__init__()
+ self.name = name
+ self.namespace = namespace
+ self.template_body = template_body
+ self.body: dict = self.get_body()
+ self.kind = self.body["kind"]
+ self.plural = f"{self.kind.lower()}s"
+ if self.body.get("apiVersion"):
+ self.api_group, self.api_version =
self.body["apiVersion"].split("/")
+ else:
+ self.api_group = self.body["apiGroup"]
+ self.api_version = self.body["version"]
+ self._client = kube_client
+ self.custom_obj_api = custom_obj_api
+ self.spark_obj_spec: dict = {}
+ self.pod_spec: k8s.V1Pod | None = None
+
+ @cached_property
+ def pod_manager(self) -> PodManager:
+ return PodManager(kube_client=self._client)
+
+ def get_body(self):
+ self.body: dict = SparkJobSpec(**self.template_body["spark"])
+ self.body.metadata = {"name": self.name, "namespace": self.namespace}
+ if self.template_body.get("kubernetes"):
+ k8s_spec: dict = KubernetesSpec(**self.template_body["kubernetes"])
+ self.body.spec["volumes"] = k8s_spec.volumes
+ if k8s_spec.image_pull_secrets:
+ self.body.spec["imagePullSecrets"] =
k8s_spec.image_pull_secrets
+ for item in ["driver", "executor"]:
+ # Env List
+ self.body.spec[item]["env"] = k8s_spec.env_vars
+ self.body.spec[item]["envFrom"] = k8s_spec.env_from
+ # Volumes
+ self.body.spec[item]["volumeMounts"] = k8s_spec.volume_mounts
+ # Add affinity
+ self.body.spec[item]["affinity"] = k8s_spec.affinity
+ self.body.spec[item]["tolerations"] = k8s_spec.tolerations
+ self.body.spec[item]["nodeSelector"] = k8s_spec.node_selector
+ # Labels
+ self.body.spec[item]["labels"] = self.body.spec["labels"]
+
+ return self.body.__dict__
+
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_random_exponential(),
+ reraise=True,
+ retry=tenacity.retry_if_exception(should_retry_start_spark_job),
+ )
+ def start_spark_job(self, image=None, code_path=None, startup_timeout: int
= 600):
+ """
+ Launches the pod synchronously and waits for completion.
+
+ :param image: image name
+ :param code_path: path to the .py file for python and jar file for
scala
+ :param startup_timeout: Timeout for startup of the pod (if pod is
pending for too long, fails task)
+ :return:
+ """
+ try:
+ if image:
+ self.body["spec"]["image"] = image
+ if code_path:
+ self.body["spec"]["mainApplicationFile"] = code_path
+ self.log.debug("Spark Job Creation Request Submitted")
+ self.spark_obj_spec =
self.custom_obj_api.create_namespaced_custom_object(
+ group=self.api_group,
+ version=self.api_version,
+ namespace=self.namespace,
+ plural=self.plural,
+ body=self.body,
+ )
+ self.log.debug("Spark Job Creation Response: %s",
self.spark_obj_spec)
+
+ # Wait for the driver pod to come alive
+ self.pod_spec = k8s.V1Pod(
+ metadata=k8s.V1ObjectMeta(
+ labels=self.spark_obj_spec["spec"]["driver"]["labels"],
+ name=self.spark_obj_spec["metadata"]["name"] + "-driver",
+ namespace=self.namespace,
+ )
+ )
+ curr_time = dt.now()
+ while self.spark_job_not_running(self.spark_obj_spec):
+ self.log.warning(
+ "Spark job submitted but not yet started. job_id: %s",
+ self.spark_obj_spec["metadata"]["name"],
+ )
+ self.check_pod_start_failure()
+ delta = dt.now() - curr_time
+ if delta.total_seconds() >= startup_timeout:
+ pod_status =
self.pod_manager.read_pod(self.pod_spec).status.container_statuses
+ raise AirflowException(f"Job took too long to start. pod
status: {pod_status}")
+ time.sleep(10)
+ except Exception as e:
+ self.log.exception("Exception when attempting to create spark job")
+ raise e
+
+ return self.pod_spec, self.spark_obj_spec
+
+ def spark_job_not_running(self, spark_obj_spec):
+ """Tests if spark_obj_spec has not started."""
+ spark_job_info =
self.custom_obj_api.get_namespaced_custom_object_status(
+ group=self.api_group,
+ version=self.api_version,
+ namespace=self.namespace,
+ name=spark_obj_spec["metadata"]["name"],
+ plural=self.plural,
+ )
+ driver_state = spark_job_info.get("status",
{}).get("applicationState", {}).get("state", "SUBMITTED")
+ if driver_state == CustomObjectStatus.FAILED:
+ err = spark_job_info.get("status", {}).get("applicationState",
{}).get("errorMessage", "N/A")
+ try:
+ self.pod_manager.fetch_container_logs(
+ pod=self.pod_spec, container_name="spark-kubernetes-driver"
+ )
+ except Exception:
+ pass
+ raise AirflowException(f"Spark Job Failed. Error stack: {err}")
+ return driver_state == CustomObjectStatus.SUBMITTED
+
+ def check_pod_start_failure(self):
+ try:
+ waiting_status = (
+
self.pod_manager.read_pod(self.pod_spec).status.container_statuses[0].state.waiting
+ )
+ waiting_reason = waiting_status.reason
+ waiting_message = waiting_status.message
+ except Exception:
+ return
+ if waiting_reason != "ContainerCreating":
+ raise AirflowException(f"Spark Job Failed. Status:
{waiting_reason}, Error: {waiting_message}")
+
+ def delete_spark_job(self, spark_job_name=None):
+ """Deletes spark job."""
+ spark_job_name = spark_job_name or self.spark_obj_spec.get("metadata",
{}).get("name")
+ if not spark_job_name:
+ self.log.warning("Spark job not found: %s", spark_job_name)
+ return
+ try:
+ self.custom_obj_api.delete_namespaced_custom_object(
+ group=self.api_group,
+ version=self.api_version,
+ namespace=self.namespace,
+ plural=self.plural,
+ name=spark_job_name,
+ )
+ except ApiException as e:
+ # If the pod is already deleted
+ if e.status != 404:
+ raise
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 6b483cd4be..8a701b675a 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -536,11 +536,13 @@ class KubernetesPodOperator(BaseOperator):
def execute_sync(self, context: Context):
result = None
try:
- self.pod_request_obj = self.build_pod_request_obj(context)
- self.pod = self.get_or_create_pod( # must set `self.pod` for
`on_kill`
- pod_request_obj=self.pod_request_obj,
- context=context,
- )
+ if self.pod_request_obj is None:
+ self.pod_request_obj = self.build_pod_request_obj(context)
+ if self.pod is None:
+ self.pod = self.get_or_create_pod( # must set `self.pod` for
`on_kill`
+ pod_request_obj=self.pod_request_obj,
+ context=context,
+ )
# push to xcom now so that if there is an error we still have the
values
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 879f54e0dc..83d6f484b3 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,179 +17,264 @@
# under the License.
from __future__ import annotations
-import datetime
+import re
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any
-from kubernetes.client import ApiException
-from kubernetes.watch import Watch
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
+from airflow.providers.cncf.kubernetes import pod_generator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook,
_load_body_to_dict
+from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import
CustomObjectLauncher
+from airflow.providers.cncf.kubernetes.operators.pod import
KubernetesPodOperator
+from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN,
PodGenerator
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.helpers import prune_dict
if TYPE_CHECKING:
- from kubernetes.client.models import CoreV1EventList
+ import jinja2
from airflow.utils.context import Context
-class SparkKubernetesOperator(BaseOperator):
+class SparkKubernetesOperator(KubernetesPodOperator):
"""
Creates sparkApplication object in kubernetes cluster.
.. seealso::
For more detail about Spark Application Object have a look at the
reference:
-
https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication
+
https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.3.3-3.1.1/docs/api-docs.md#sparkapplication
- :param application_file: Defines Kubernetes 'custom_resource_definition'
of 'sparkApplication' as either a
- path to a '.yaml' file, '.json' file, YAML string or python dictionary.
+ :param application_file: filepath to kubernetes custom_resource_definition
of sparkApplication
+ :param kubernetes_conn_id: the connection to Kubernetes cluster
+ :param image: Docker image you wish to launch. Defaults to hub.docker.com,
+ :param code_path: path to the spark code in image,
:param namespace: kubernetes namespace to put sparkApplication
- :param kubernetes_conn_id: The :ref:`kubernetes connection id
<howto/connection:kubernetes>`
- for the to Kubernetes cluster.
- :param api_group: kubernetes api group of sparkApplication
- :param api_version: kubernetes api version of sparkApplication
- :param watch: whether to watch the job status and logs or not
+ :param cluster_context: context of the cluster
+ :param application_file: yaml file if passed
+ :param get_logs: get the stdout of the container as logs of the tasks.
+ :param do_xcom_push: If True, the content of the file
+ /airflow/xcom/return.json in the container will also be pushed to an
+ XCom when the container completes.
+ :param success_run_history_limit: Number of past successful runs of the
application to keep.
+ :param delete_on_termination: What to do when the pod reaches its final
+ state, or the execution is interrupted. If True (default), delete the
+ pod; if False, leave the pod.
+ :param startup_timeout_seconds: timeout in seconds to startup the pod.
+ :param log_events_on_failure: Log the pod's events if a failure occurs
+ :param reattach_on_restart: if the scheduler dies while the pod is
running, reattach and monitor
"""
- template_fields: Sequence[str] = ("application_file", "namespace")
- template_ext: Sequence[str] = (".yaml", ".yml", ".json")
+ template_fields = ["application_file", "namespace", "template_spec"]
+ template_fields_renderers = {"template_spec": "py"}
+ template_ext = ("yaml", "yml", "json")
ui_color = "#f4a460"
def __init__(
self,
*,
- application_file: str | dict,
- namespace: str | None = None,
+ image: str | None = None,
+ code_path: str | None = None,
+ namespace: str = "default",
+ name: str = "default",
+ application_file: str | None = None,
+ template_spec=None,
+ get_logs: bool = True,
+ do_xcom_push: bool = False,
+ success_run_history_limit: int = 1,
+ startup_timeout_seconds=600,
+ log_events_on_failure: bool = False,
+ reattach_on_restart: bool = True,
+ delete_on_termination: bool = True,
kubernetes_conn_id: str = "kubernetes_default",
- api_group: str = "sparkoperator.k8s.io",
- api_version: str = "v1beta2",
- in_cluster: bool | None = None,
- cluster_context: str | None = None,
- config_file: str | None = None,
- watch: bool = False,
**kwargs,
) -> None:
- super().__init__(**kwargs)
- self.namespace = namespace
- self.kubernetes_conn_id = kubernetes_conn_id
- self.api_group = api_group
- self.api_version = api_version
- self.plural = "sparkapplications"
+ if kwargs.get("xcom_push") is not None:
+ raise AirflowException("'xcom_push' was deprecated, use
'do_xcom_push' instead")
+ super().__init__(name=name, **kwargs)
+ self.image = image
+ self.code_path = code_path
self.application_file = application_file
- self.in_cluster = in_cluster
- self.cluster_context = cluster_context
- self.config_file = config_file
- self.watch = watch
+ self.template_spec = template_spec
+ self.name = self.create_job_name()
+ self.kubernetes_conn_id = kubernetes_conn_id
+ self.startup_timeout_seconds = startup_timeout_seconds
+ self.reattach_on_restart = reattach_on_restart
+ self.delete_on_termination = delete_on_termination
+ self.do_xcom_push = do_xcom_push
+ self.namespace = namespace
+ self.get_logs = get_logs
+ self.log_events_on_failure = log_events_on_failure
+ self.success_run_history_limit = success_run_history_limit
+ self.template_body = self.manage_template_specs()
+
+ def _render_nested_template_fields(
+ self,
+ content: Any,
+ context: Context,
+ jinja_env: jinja2.Environment,
+ seen_oids: set,
+ ) -> None:
+ if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
+ seen_oids.add(id(content))
+ self._do_render_template_fields(content, ("value", "name"),
context, jinja_env, seen_oids)
+ return
+
+ super()._render_nested_template_fields(content, context, jinja_env,
seen_oids)
+
+ def manage_template_specs(self):
+ if self.application_file:
+ template_body = _load_body_to_dict(open(self.application_file))
+ elif self.template_spec:
+ template_body = self.template_spec
+ else:
+ raise AirflowException("either application_file or template_spec
should be passed")
+ if "spark" not in template_body:
+ template_body = {"spark": template_body}
+ return template_body
+
+ def create_job_name(self):
+ initial_name =
PodGenerator.make_unique_pod_id(self.task_id)[:MAX_LABEL_LEN]
+ return re.sub(r"[^a-z0-9-]+", "-", initial_name.lower())
+
+ @staticmethod
+ def _get_pod_identifying_label_string(labels) -> str:
+ filtered_labels = {label_id: label for label_id, label in
labels.items() if label_id != "try_number"}
+ return ",".join([label_id + "=" + label for label_id, label in
sorted(filtered_labels.items())])
+
+ @staticmethod
+ def create_labels_for_pod(context: dict | None = None, include_try_number:
bool = True) -> dict:
+ """
+ Generate labels for the pod to track the pod in case of Operator crash.
+
+ :param include_try_number: add try number to labels
+ :param context: task context provided by airflow DAG
+ :return: dict.
+ """
+ if not context:
+ return {}
+
+ ti = context["ti"]
+ run_id = context["run_id"]
+
+ labels = {
+ "dag_id": ti.dag_id,
+ "task_id": ti.task_id,
+ "run_id": run_id,
+ "spark_kubernetes_operator": "True",
+ # 'execution_date': context['ts'],
+ # 'try_number': context['ti'].try_number,
+ }
+
+ # If running on Airflow 2.3+:
+ map_index = getattr(ti, "map_index", -1)
+ if map_index >= 0:
+ labels["map_index"] = map_index
+
+ if include_try_number:
+ labels.update(try_number=ti.try_number)
+
+ # In the case of sub dags this is just useful
+ if context["dag"].is_subdag:
+ labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+ # Ensure that label is valid for Kube,
+ # and if not truncate/remove invalid chars and replace with short hash.
+ for label_id, label in labels.items():
+ safe_label = pod_generator.make_safe_label_value(str(label))
+ labels[label_id] = safe_label
+ return labels
+
+ @cached_property
+ def pod_manager(self) -> PodManager:
+ return PodManager(kube_client=self.client)
+
+ @staticmethod
+ def _try_numbers_match(context, pod) -> bool:
+ return pod.metadata.labels["try_number"] == context["ti"].try_number
+
+ def find_spark_job(self, context):
+ labels = self.create_labels_for_pod(context, include_try_number=False)
+ label_selector = self._get_pod_identifying_label_string(labels) +
",spark-role=driver"
+ pod_list = self.client.list_namespaced_pod(self.namespace,
label_selector=label_selector).items
+
+ pod = None
+ if len(pod_list) > 1: # and self.reattach_on_restart:
+ raise AirflowException(f"More than one pod running with labels:
{label_selector}")
+ elif len(pod_list) == 1:
+ pod = pod_list[0]
+ self.log.info(
+ "Found matching driver pod %s with labels %s",
pod.metadata.name, pod.metadata.labels
+ )
+ self.log.info("`try_number` of task_instance: %s",
context["ti"].try_number)
+ self.log.info("`try_number` of pod: %s",
pod.metadata.labels["try_number"])
+ return pod
+
+ def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context)
-> k8s.V1Pod:
+ if self.reattach_on_restart:
+ driver_pod = self.find_spark_job(context)
+ if driver_pod:
+ return driver_pod
+
+ driver_pod, spark_obj_spec = launcher.start_spark_job(
+ image=self.image, code_path=self.code_path,
startup_timeout=self.startup_timeout_seconds
+ )
+ return driver_pod
+
+ def process_pod_deletion(self, pod, *, reraise=True):
+ if pod is not None:
+ if self.delete_on_termination:
+ self.log.info("Deleting spark job: %s",
pod.metadata.name.replace("-driver", ""))
+
self.launcher.delete_spark_job(pod.metadata.name.replace("-driver", ""))
+ else:
+ self.log.info("skipping deleting spark job: %s",
pod.metadata.name)
@cached_property
def hook(self) -> KubernetesHook:
- return KubernetesHook(
+ hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
- in_cluster=self.in_cluster,
- config_file=self.config_file,
- cluster_context=self.cluster_context,
+ in_cluster=self.in_cluster or self.template_body.get("kubernetes",
{}).get("in_cluster", False),
+ config_file=self.config_file
+ or self.template_body.get("kubernetes",
{}).get("kube_config_file", None),
+ cluster_context=self.cluster_context
+ or self.template_body.get("kubernetes", {}).get("cluster_context",
None),
)
+ return hook
- def _get_namespace_event_stream(self, namespace, query_kwargs=None):
- try:
- return Watch().stream(
- self.hook.core_v1_client.list_namespaced_event,
- namespace=namespace,
- watch=True,
- **(query_kwargs or {}),
- )
- except ApiException as e:
- if e.status == 410: # Resource version is too old
- events: CoreV1EventList =
self.hook.core_v1_client.list_namespaced_event(
- namespace=namespace, watch=False
- )
- resource_version = events.metadata.resource_version
- query_kwargs["resource_version"] = resource_version
- return self._get_namespace_event_stream(namespace,
query_kwargs)
- else:
- raise
+ @cached_property
+ def client(self) -> CoreV1Api:
+ return self.hook.core_v1_client
- def execute(self, context: Context):
- if isinstance(self.application_file, str):
- body = _load_body_to_dict(self.application_file)
- else:
- body = self.application_file
- name = body["metadata"]["name"]
- namespace = self.namespace or self.hook.get_namespace()
-
- response = None
- is_job_created = False
- if self.watch:
- try:
- namespace_event_stream = self._get_namespace_event_stream(
- namespace=namespace,
- query_kwargs={
- "field_selector":
f"involvedObject.kind=SparkApplication,involvedObject.name={name}"
- },
- )
-
- response = self.hook.create_custom_object(
- group=self.api_group,
- version=self.api_version,
- plural=self.plural,
- body=body,
- namespace=namespace,
- )
- is_job_created = True
- for event in namespace_event_stream:
- obj = event["object"]
- if event["object"].last_timestamp >=
datetime.datetime.strptime(
- response["metadata"]["creationTimestamp"],
"%Y-%m-%dT%H:%M:%S%z"
- ):
- self.log.info(obj.message)
- if obj.reason == "SparkDriverRunning":
- pod_log_stream = Watch().stream(
-
self.hook.core_v1_client.read_namespaced_pod_log,
- name=f"{name}-driver",
- namespace=namespace,
- timestamps=True,
- )
- for line in pod_log_stream:
- self.log.info(line)
- elif obj.reason in [
- "SparkApplicationSubmissionFailed",
- "SparkApplicationFailed",
- "SparkApplicationDeleted",
- ]:
- is_job_created = False
- raise AirflowException(obj.message)
- elif obj.reason == "SparkApplicationCompleted":
- break
- else:
- continue
- except Exception:
- if is_job_created:
- self.on_kill()
- raise
+ @cached_property
+ def custom_obj_api(self) -> CustomObjectsApi:
+ return CustomObjectsApi()
- else:
- response = self.hook.create_custom_object(
- group=self.api_group,
- version=self.api_version,
- plural=self.plural,
- body=body,
- namespace=namespace,
- )
+ def execute(self, context: Context):
+ self.log.info("Creating sparkApplication.")
+ self.launcher = CustomObjectLauncher(
+ name=self.name,
+ namespace=self.namespace,
+ kube_client=self.client,
+ custom_obj_api=self.custom_obj_api,
+ template_body=self.template_body,
+ )
+ self.pod = self.get_or_create_spark_crd(self.launcher, context)
+ self.BASE_CONTAINER_NAME = "spark-kubernetes-driver"
+ self.pod_request_obj = self.launcher.pod_spec
- return response
+ return super().execute(context=context)
def on_kill(self) -> None:
- if isinstance(self.application_file, str):
- body = _load_body_to_dict(self.application_file)
- else:
- body = self.application_file
- name = body["metadata"]["name"]
- namespace = self.namespace or self.hook.get_namespace()
- self.hook.delete_custom_object(
- group=self.api_group,
- version=self.api_version,
- plural=self.plural,
- namespace=namespace,
- name=name,
- )
+ if self.launcher:
+ self.log.debug("Deleting spark job for task %s", self.task_id)
+ self.launcher.delete_spark_job()
+
+ def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
+ """Add an "already checked" annotation to ensure we don't reattach on
retries."""
+ pod.metadata.labels["already_checked"] = "True"
+ body = PodGenerator.serialize_pod(pod)
+ self.client.patch_namespaced_pod(pod.metadata.name,
pod.metadata.namespace, body)
+
+ def dry_run(self) -> None:
+ """Prints out the spark job that would be created by this operator."""
+ print(prune_dict(self.launcher.body, mode="strict"))
diff --git a/airflow/providers/cncf/kubernetes/provider.yaml
b/airflow/providers/cncf/kubernetes/provider.yaml
index a54edd6f3b..f9d570091e 100644
--- a/airflow/providers/cncf/kubernetes/provider.yaml
+++ b/airflow/providers/cncf/kubernetes/provider.yaml
@@ -112,6 +112,7 @@ integrations:
operators:
- integration-name: Kubernetes
python-modules:
+ - airflow.providers.cncf.kubernetes.operators.custom_object_launcher
- airflow.providers.cncf.kubernetes.operators.pod
- airflow.providers.cncf.kubernetes.operators.spark_kubernetes
- airflow.providers.cncf.kubernetes.operators.resource
diff --git a/airflow/providers/cncf/kubernetes/resource_convert/__init__.py
b/airflow/providers/cncf/kubernetes/resource_convert/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/resource_convert/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/cncf/kubernetes/resource_convert/configmap.py
b/airflow/providers/cncf/kubernetes/resource_convert/configmap.py
new file mode 100644
index 0000000000..ffd92922d6
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/resource_convert/configmap.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from kubernetes.client import models as k8s
+
+
def convert_configmap(configmap_name) -> k8s.V1EnvFromSource:
    """
    Build a k8s ``V1EnvFromSource`` referencing the named config map.

    :param configmap_name: name of the config map to reference
    :return: env-from source wrapping the config map reference
    """
    configmap_ref = k8s.V1ConfigMapEnvSource(name=configmap_name)
    return k8s.V1EnvFromSource(config_map_ref=configmap_ref)
+
+
def convert_configmap_to_volume(
    configmap_info: dict[str, str],
) -> tuple[list[k8s.V1Volume], list[k8s.V1VolumeMount]]:
    """
    Build matching k8s volumes and volume mounts from config map info.

    Each config map becomes one volume (backed by the config map) and one
    mount at the requested path; both lists share the config map name and
    keep the mapping's iteration order.

    :param configmap_info: a dictionary of {config_map_name: mount_path}
    :return: a ``(volumes, volume_mounts)`` pair, one entry per config map
    """
    volumes = [
        k8s.V1Volume(
            name=config_name,
            config_map=k8s.V1ConfigMapVolumeSource(name=config_name),
        )
        for config_name in configmap_info
    ]
    volume_mounts = [
        k8s.V1VolumeMount(mount_path=mount_path, name=config_name)
        for config_name, mount_path in configmap_info.items()
    ]
    return volumes, volume_mounts
diff --git a/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py
b/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py
new file mode 100644
index 0000000000..f7d48bf32a
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from kubernetes.client import models as k8s
+
+from airflow.exceptions import AirflowException
+
+
def convert_env_vars(env_vars) -> list[k8s.V1EnvVar]:
    """
    Convert environment variables to a list of ``k8s.V1EnvVar``.

    Accepts either a mapping of variable name to value, or a list that
    already consists entirely of ``V1EnvVar`` objects (returned unchanged).

    :param env_vars: a dict of {name: value} pairs, or a list of ``V1EnvVar``
    :return: a list of ``k8s.V1EnvVar`` objects
    :raises AirflowException: if *env_vars* is neither a dict nor a list made
        up entirely of ``V1EnvVar`` objects
    """
    if isinstance(env_vars, dict):
        return [k8s.V1EnvVar(name=name, value=value) for name, value in env_vars.items()]
    # A list is passed through untouched, but only if it is already in the
    # target representation; a generator (not a list) feeds all() lazily.
    if isinstance(env_vars, list) and all(isinstance(e, k8s.V1EnvVar) for e in env_vars):
        return env_vars
    raise AirflowException(f"Expected dict or list of V1EnvVar, got {type(env_vars)}")
diff --git a/airflow/providers/cncf/kubernetes/resource_convert/secret.py
b/airflow/providers/cncf/kubernetes/resource_convert/secret.py
new file mode 100644
index 0000000000..7d0e0f91ad
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/resource_convert/secret.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from kubernetes.client import models as k8s
+
+
def convert_secret(secret_name: str) -> k8s.V1EnvFromSource:
    """
    Build a k8s ``V1EnvFromSource`` referencing the named secret.

    :param secret_name: name of the secret to reference
    :return: env-from source wrapping the secret reference
    """
    secret_ref = k8s.V1SecretEnvSource(name=secret_name)
    return k8s.V1EnvFromSource(secret_ref=secret_ref)
+
+
def convert_image_pull_secrets(image_pull_secrets: str) -> list[k8s.V1LocalObjectReference]:
    """
    Convert a comma-separated string of secret names into k8s local object references.

    Surrounding whitespace is stripped from each name, and empty entries are
    skipped — so an empty input string or a trailing comma no longer produces
    a reference with an empty name (which Kubernetes would reject).

    :param image_pull_secrets: comma separated string that contains secrets
    :return: a ``V1LocalObjectReference`` per non-empty secret name
    """
    names = (name.strip() for name in image_pull_secrets.split(","))
    return [k8s.V1LocalObjectReference(name=name) for name in names if name]
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
index 44685c31ab..690a857ea0 100644
--- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
+++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
@@ -16,6 +16,8 @@
under the License.
+.. contents:: Table of Contents
+ :depth: 2
.. _howto/operator:kubernetespodoperator:
@@ -199,3 +201,305 @@ For further information, look at:
* `Kubernetes Documentation <https://kubernetes.io/docs/home/>`__
* `Pull an Image from a Private Registry
<https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/>`__
+
+SparkKubernetesOperator
+==========================
+The
:class:`~airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator`
allows
+you to create and run Spark jobs on a Kubernetes cluster. It is based on the
+`spark-on-k8s-operator <https://github.com/GoogleCloudPlatform/spark-on-k8s-operator>`__ project.
+
+This operator simplifies the interface and accepts different parameters to
+configure and run a Spark application on Kubernetes.
+Similar to the KubernetesOperator, we have added the logic to wait for a job
after submission,
+manage error handling, retrieve logs from the driver pod and the ability to
delete a spark job.
+It also supports out-of-the-box Kubernetes functionalities such as handling of
volumes, config maps, secrets, etc.
+
+
+How does this operator work?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+The operator initiates a Spark task by generating a SparkApplication Custom
Resource Definition (CRD) within Kubernetes.
+This SparkApplication task subsequently generates driver and required executor
pods, using the parameters specified by the user.
+The operator continuously monitors the task's progress until it either
succeeds or fails.
+It retrieves logs from the driver pod and displays them in the Airflow UI.
+
+
+Usage examples
+^^^^^^^^^^^^^^
+In order to create a SparkKubernetesOperator task, you must provide a basic
template that includes Spark configuration and
+Kubernetes-related resource configuration. This template, which can be in
either YAML or JSON format, serves as a
+starting point for the operator. Below is a sample template that you can
utilize:
+
+spark_job_template.yaml
+
+.. code-block:: yaml
+
+ spark:
+ apiVersion: sparkoperator.k8s.io/v1beta2
+ version: v1beta2
+ kind: SparkApplication
+ apiGroup: sparkoperator.k8s.io
+ metadata:
+ namespace: ds
+ spec:
+ type: Python
+ pythonVersion: "3"
+ mode: cluster
+ sparkVersion: 3.0.0
+ successfulRunHistoryLimit: 1
+ restartPolicy:
+ type: Never
+ imagePullPolicy: Always
+ hadoopConf: {}
+ imagePullSecrets: []
+ dynamicAllocation:
+ enabled: false
+ initialExecutors: 1
+ minExecutors: 1
+ maxExecutors: 1
+ labels: {}
+ driver:
+ serviceAccount: default
+ container_resources:
+ gpu:
+ name: null
+ quantity: 0
+ cpu:
+ request: null
+ limit: null
+ memory:
+ request: null
+ limit: null
+ executor:
+ instances: 1
+ container_resources:
+ gpu:
+ name: null
+ quantity: 0
+ cpu:
+ request: null
+ limit: null
+ memory:
+ request: null
+ limit: null
+ kubernetes:
+ # example:
+ # env_vars:
+ # - name: TEST_NAME
+ # value: TEST_VALUE
+ env_vars: []
+
+ # example:
+ # env_from:
+ # - name: test
+ # valueFrom:
+ # secretKeyRef:
+ # name: mongo-secret
+ # key: mongo-password
+ env_from: []
+
+ # example:
+ # node_selector:
+ # karpenter.sh/provisioner-name: spark
+ node_selector: {}
+
+ # example:
https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/
+ # affinity:
+ # nodeAffinity:
+ # requiredDuringSchedulingIgnoredDuringExecution:
+ # nodeSelectorTerms:
+ # - matchExpressions:
+ # - key: beta.kubernetes.io/instance-type
+ # operator: In
+ # values:
+ # - r5.xlarge
+ affinity:
+ nodeAffinity: {}
+ podAffinity: {}
+ podAntiAffinity: {}
+
+ # example:
https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/
+ # type: list
+ # tolerations:
+ # - key: "key1"
+ # operator: "Equal"
+ # value: "value1"
+ # effect: "NoSchedule"
+ tolerations: []
+
+ # example:
+ # config_map_mounts:
+ # snowflake-default: /mnt/tmp
+ config_map_mounts: {}
+
+ # example:
+ # volume_mounts:
+ # - name: config
+ # mountPath: /airflow
+ volume_mounts: []
+
+ # https://kubernetes.io/docs/concepts/storage/volumes/
+ # example:
+ # volumes:
+ # - name: config
+ # persistentVolumeClaim:
+ # claimName: airflow
+ volumes: []
+
+ # read config map into an env variable
+ # example:
+ # from_env_config_map:
+ # - configmap_1
+ # - configmap_2
+ from_env_config_map: []
+
+ # load secret into an env variable
+ # example:
+ # from_env_secret:
+ # - secret_1
+ # - secret_2
+ from_env_secret: []
+
+ in_cluster: true
+ conn_id: kubernetes_default
+ kube_config_file: null
+ cluster_context: null
+
+.. important::
+
+ * The template file consists of two primary categories: ``spark`` and
``kubernetes``.
+
+ * spark: This segment encompasses the task's Spark configuration,
mirroring the structure of the Spark API template.
+
+ * kubernetes: This segment encompasses the task's Kubernetes resource
configuration, directly corresponding to the Kubernetes API Documentation. Each
resource type includes an example within the template.
+
+ * The designated base image to be utilized is
``gcr.io/spark-operator/spark-py:v3.1.1``.
+
+ * Ensure that the Spark code is either embedded within the image, mounted
using a persistentVolume, or accessible from an external location such as an S3
bucket.
+
+Next, create the task using the following:
+
+.. code-block:: python
+
+ SparkKubernetesOperator(
+ task_id="spark_task",
+ image="gcr.io/spark-operator/spark-py:v3.1.1", # OR custom image
using that
+ code_path="local://path/to/spark/code.py",
+        application_file="spark_job_template.yaml",  # OR spark_job_template.json
+ dag=dag,
+ )
+
+Note: Alternatively, ``application_file`` can also be a JSON file. See the example below.
+
+spark_job_template.json
+
+.. code-block:: json
+
+ {
+ "spark": {
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "version": "v1beta2",
+ "kind": "SparkApplication",
+ "apiGroup": "sparkoperator.k8s.io",
+ "metadata": {
+ "namespace": "ds"
+ },
+ "spec": {
+ "type": "Python",
+ "pythonVersion": "3",
+ "mode": "cluster",
+ "sparkVersion": "3.0.0",
+ "successfulRunHistoryLimit": 1,
+ "restartPolicy": {
+ "type": "Never"
+ },
+ "imagePullPolicy": "Always",
+ "hadoopConf": {},
+ "imagePullSecrets": [],
+ "dynamicAllocation": {
+ "enabled": false,
+ "initialExecutors": 1,
+ "minExecutors": 1,
+ "maxExecutors": 1
+ },
+ "labels": {},
+ "driver": {
+ "serviceAccount": "default",
+ "container_resources": {
+ "gpu": {
+ "name": null,
+ "quantity": 0
+ },
+ "cpu": {
+ "request": null,
+ "limit": null
+ },
+ "memory": {
+ "request": null,
+ "limit": null
+ }
+ }
+ },
+ "executor": {
+ "instances": 1,
+ "container_resources": {
+ "gpu": {
+ "name": null,
+ "quantity": 0
+ },
+ "cpu": {
+ "request": null,
+ "limit": null
+ },
+ "memory": {
+ "request": null,
+ "limit": null
+ }
+ }
+ }
+ }
+ },
+ "kubernetes": {
+ "env_vars": [],
+ "env_from": [],
+ "node_selector": {},
+ "affinity": {
+ "nodeAffinity": {},
+ "podAffinity": {},
+ "podAntiAffinity": {}
+ },
+ "tolerations": [],
+ "config_map_mounts": {},
+ "volume_mounts": [
+ {
+ "name": "config",
+ "mountPath": "/airflow"
+ }
+ ],
+ "volumes": [
+ {
+ "name": "config",
+ "persistentVolumeClaim": {
+ "claimName": "hsaljoog-airflow"
+ }
+ }
+ ],
+ "from_env_config_map": [],
+ "from_env_secret": [],
+ "in_cluster": true,
+ "conn_id": "kubernetes_default",
+ "kube_config_file": null,
+ "cluster_context": null
+ }
+ }
+
+
+
+An alternative method, apart from using YAML or JSON files, is to directly
pass the ``template_spec`` field instead of application_file
+if you prefer not to employ a file for configuration.
+
+
+Reference
+^^^^^^^^^
+For further information, look at:
+
+* `Kubernetes Documentation <https://kubernetes.io/docs/home/>`__
+* `Spark-on-k8s-operator Documentation - User guide
<https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/master/docs/user-guide.md>`__
+* `Spark-on-k8s-operator Documentation - API
<https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/master/docs/api-docs.md>`__
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 49fbaf4c98..6e5eaa07fd 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -314,6 +314,7 @@ coverals
cp
cpu
cpus
+CRD
crd
createDisposition
CreateQueryOperator
@@ -552,6 +553,7 @@ enum
Env
env
envFrom
+EnvFromSource
EnvVar
envvar
eof
@@ -900,6 +902,7 @@ kubeclient
kubeconfig
Kubernetes
kubernetes
+KubernetesPodOperator
Kusto
kv
kwarg
@@ -939,6 +942,7 @@ loadBalancerIP
localExecutor
localexecutor
localhost
+LocalObjectReference
localstack
lodash
logfile
@@ -1145,6 +1149,7 @@ Pem
pem
performant
permalinks
+persistentVolume
personalizations
pformat
PgBouncer
@@ -1442,6 +1447,7 @@ Spark
sparkappinfo
sparkApplication
sparkcmd
+SparkKubernetesOperator
SparkPi
SparkR
sparkr
@@ -1639,6 +1645,7 @@ TZ
tz
tzinfo
UA
+UI
ui
uid
ukey
diff --git
a/tests/providers/cncf/kubernetes/operators/spark_application_template.yaml
b/tests/providers/cncf/kubernetes/operators/spark_application_template.yaml
new file mode 100644
index 0000000000..ea3fcea2e7
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/operators/spark_application_template.yaml
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+---
+spark:
+ apiVersion: sparkoperator.k8s.io/v1beta2
+ kind: SparkApplication
+ metadata:
+ name: default_yaml_template
+ spec:
+ type: Python
+ image: 'gcr.io/spark-operator/spark:v2.4.5'
+ pythonVersion: "3"
+ mode: cluster
+ mainApplicationFile: 'local:///opt/test.py'
+ sparkVersion: 3.0.0
+ successfulRunHistoryLimit: 1
+ restartPolicy:
+ type: Never
+ imagePullPolicy: Always
+ hadoopConf: {}
+ imagePullSecrets: ''
+ dynamicAllocation:
+ enabled: false
+ initialExecutors: 1
+ minExecutors: 1
+ maxExecutors: 1
+ labels: {}
+ driver:
+ serviceAccount: default
+ container_resources:
+ gpu:
+ name: null
+ quantity: 0
+ cpu:
+ request: 1
+ limit: 1200m
+ memory:
+ limit: 512m
+ executor:
+ instances: 1
+ container_resources:
+ gpu:
+ name: null
+ quantity: 0
+ cpu:
+ request: 1
+ limit: null
+ memory:
+ limit: 512m
+kubernetes:
+ # example:
+ # env_vars:
+ # - name: TEST_NAME
+ # value: TEST_VALUE
+ env_vars: []
+
+ # example:
+ # env_from:
+ # - name: test
+ # valueFrom:
+ # secretKeyRef:
+ # name: mongo-secret
+ # key: mongo-password
+ env_from: []
+
+ # example:
https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/
+ # affinity:
+ # nodeAffinity:
+ # requiredDuringSchedulingIgnoredDuringExecution:
+ # nodeSelectorTerms:
+ # - matchExpressions:
+ # - key: beta.kubernetes.io/instance-type
+ # operator: In
+ # values:
+ # - r5.xlarge
+ affinity:
+ nodeAffinity: {}
+ podAffinity: {}
+ podAntiAffinity: {}
+
+ # example:
+ # node_selector:
+ # karpenter.sh/provisioner-name: spark-group
+ node_selector: {}
+
+ # example:
https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/
+ # type: list
+ # tolerations:
+ # - key: "key1"
+ # operator: "Equal"
+ # value: "value1"
+ # effect: "NoSchedule"
+ tolerations: []
+
+ # example:
+ # config_map_mounts:
+ # snowflake-default: /mnt/tmp
+ config_map_mounts: {}
+
+ # example:
+ # volumeMounts:
+ # - name: config
+ # mountPath: /path
+ volume_mounts: []
+
+ # https://kubernetes.io/docs/concepts/storage/volumes/
+ # example:
+ # volumes:
+ # - name: config
+ # persistentVolumeClaim:
+ # claimName: claim-name
+ volumes: []
+
+ # read config map into an env variable
+ # example:
+ # from_env_config_map:
+ # - configmap_1
+ # - configmap_2
+ from_env_config_map: []
+
+ # load secret into an env variable
+ # example:
+ # from_env_secret:
+ # - secret_1
+ # - secret_2
+ from_env_secret: []
+
+ image_pull_secrets: ''
+ in_cluster: true
+ conn_id: kubernetes_default
+ kube_config_file: null
+ cluster_context: null
diff --git
a/tests/providers/cncf/kubernetes/operators/spark_application_test.json
b/tests/providers/cncf/kubernetes/operators/spark_application_test.json
new file mode 100644
index 0000000000..fefdb6bde5
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/operators/spark_application_test.json
@@ -0,0 +1,58 @@
+{
+ "apiVersion":"sparkoperator.k8s.io/v1beta2",
+ "kind":"SparkApplication",
+ "metadata":{
+ "name":"default_jsonss",
+ "namespace":"default"
+ },
+ "spec":{
+ "type":"Scala",
+ "mode":"cluster",
+ "image":"gcr.io/spark-operator/spark:v2.4.5",
+ "imagePullPolicy":"Always",
+ "mainClass":"org.apache.spark.examples.SparkPi",
+
"mainApplicationFile":"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar",
+ "sparkVersion":"2.4.5",
+ "restartPolicy":{
+ "type":"Never"
+ },
+ "volumes":[
+ {
+ "name":"test-volume",
+ "hostPath":{
+ "path":"/tmp",
+ "type":"Directory"
+ }
+ }
+ ],
+ "driver":{
+ "cores":1,
+ "coreLimit":"1200m",
+ "memory":"512m",
+ "labels":{
+ "version":"2.4.5"
+ },
+ "serviceAccount":"spark",
+ "volumeMounts":[
+ {
+ "name":"test-volume",
+ "mountPath":"/tmp"
+ }
+ ]
+ },
+ "executor":{
+ "cores":1,
+ "instances":1,
+ "memory":"512m",
+ "labels":{
+ "version":"2.4.5"
+ },
+ "volumeMounts":[
+ {
+ "name":"test-volume",
+ "mountPath":"/tmp"
+ }
+ ]
+ }
+ }
+}
diff --git
a/tests/providers/cncf/kubernetes/operators/spark_application_test.yaml
b/tests/providers/cncf/kubernetes/operators/spark_application_test.yaml
new file mode 100644
index 0000000000..426e600585
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/operators/spark_application_test.yaml
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+---
+apiVersion: "sparkoperator.k8s.io/v1beta2"
+kind: SparkApplication
+metadata:
+ name: default_yaml
+ namespace: default
+spec:
+ type: Scala
+ mode: cluster
+ image: "gcr.io/spark-operator/spark:v2.4.5"
+ imagePullPolicy: Always
+ mainClass: org.apache.spark.examples.SparkPi
+ mainApplicationFile:
"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar"
+ sparkVersion: "2.4.5"
+ restartPolicy:
+ type: Never
+ volumes:
+ - name: "test-volume"
+ hostPath:
+ path: "/tmp"
+ type: Directory
+ driver:
+ cores: 1
+ coreLimit: "1200m"
+ memory: "512m"
+ labels:
+ version: 2.4.5
+ serviceAccount: spark
+ volumeMounts:
+ - name: "test-volume"
+ mountPath: "/tmp"
+ executor:
+ cores: 1
+ instances: 1
+ memory: "512m"
+ labels:
+ version: 2.4.5
+ volumeMounts:
+ - name: "test-volume"
+ mountPath: "/tmp"
diff --git
a/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
b/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/operators/test_custom_object_launcher.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 76369bfd50..d900ef78b4 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -1,3 +1,4 @@
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,21 +17,29 @@
# under the License.
from __future__ import annotations
-from datetime import datetime
-from unittest.mock import MagicMock, patch
+import copy
+import json
+from os.path import join
+from pathlib import Path
+from unittest import mock
+from unittest.mock import patch
-import pytest
-from dateutil import tz
+import pendulum
+import yaml
+from kubernetes.client import models as k8s
-from airflow.exceptions import AirflowException
+from airflow import DAG
+from airflow.models import Connection, DagRun, TaskInstance
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import
SparkKubernetesOperator
+from airflow.utils import db, timezone
+from airflow.utils.types import DagRunType
@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
def test_spark_kubernetes_operator(mock_kubernetes_hook):
operator = SparkKubernetesOperator(
task_id="task_id",
- application_file="application_file",
+ application_file=join(Path(__file__).parent,
"spark_application_test.yaml"),
kubernetes_conn_id="kubernetes_conn_id",
in_cluster=True,
cluster_context="cluster_context",
@@ -45,7 +54,7 @@ def test_spark_kubernetes_operator(mock_kubernetes_hook):
def test_spark_kubernetes_operator_hook(mock_kubernetes_hook):
operator = SparkKubernetesOperator(
task_id="task_id",
- application_file="application_file",
+ application_file=join(Path(__file__).parent,
"spark_application_test.yaml"),
kubernetes_conn_id="kubernetes_conn_id",
in_cluster=True,
cluster_context="cluster_context",
@@ -60,146 +69,422 @@ def
test_spark_kubernetes_operator_hook(mock_kubernetes_hook):
)
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_execute_with_watch(mock_kubernetes_hook, mock_load_body_to_dict,
mock_stream):
- mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
-
- mock_kubernetes_hook.return_value.create_custom_object.return_value = {
- "metadata": {"name": "spark-app", "creationTimestamp":
"2022-01-01T00:00:00Z"}
- }
- mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
- object_mock = MagicMock()
- object_mock.reason = "SparkDriverRunning"
- object_mock.last_timestamp = datetime(2022, 1, 1, 23, 59, 59,
tzinfo=tz.tzutc())
- mock_stream.side_effect = [[{"object": object_mock}], []]
-
- op = SparkKubernetesOperator(task_id="task_id",
application_file="application_file", watch=True)
- operator_output = op.execute({})
-
-
mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
- group="sparkoperator.k8s.io",
- version="v1beta2",
- plural="sparkapplications",
- body={"metadata": {"name": "spark-app"}},
- namespace="default",
+TEST_K8S_DICT = {
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "kind": "SparkApplication",
+ "metadata": {"name": "default_yaml_template", "namespace": "default"},
+ "spec": {
+ "driver": {
+ "coreLimit": "1200m",
+ "cores": 1,
+ "labels": {},
+ "memory": "365m",
+ "nodeSelector": {},
+ "serviceAccount": "default",
+ "volumeMounts": [],
+ "env": [],
+ "envFrom": [],
+ "tolerations": [],
+ "affinity": {"nodeAffinity": {}, "podAffinity": {},
"podAntiAffinity": {}},
+ },
+ "executor": {
+ "cores": 1,
+ "instances": 1,
+ "labels": {},
+ "env": [],
+ "envFrom": [],
+ "memory": "365m",
+ "nodeSelector": {},
+ "volumeMounts": [],
+ "tolerations": [],
+ "affinity": {"nodeAffinity": {}, "podAffinity": {},
"podAntiAffinity": {}},
+ },
+ "hadoopConf": {},
+ "dynamicAllocation": {"enabled": False, "initialExecutors": 1,
"maxExecutors": 1, "minExecutors": 1},
+ "image": "gcr.io/spark-operator/spark:v2.4.5",
+ "imagePullPolicy": "Always",
+ "mainApplicationFile": "local:///opt/test.py",
+ "mode": "cluster",
+ "restartPolicy": {"type": "Never"},
+ "sparkVersion": "3.0.0",
+ "successfulRunHistoryLimit": 1,
+ "pythonVersion": "3",
+ "type": "Python",
+ "imagePullSecrets": "",
+ "labels": {},
+ "volumes": [],
+ },
+}
+
+TEST_APPLICATION_DICT = {
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "kind": "SparkApplication",
+ "metadata": {"name": "default_yaml", "namespace": "default"},
+ "spec": {
+ "driver": {
+ "coreLimit": "1200m",
+ "cores": 1,
+ "labels": {"version": "2.4.5"},
+ "memory": "512m",
+ "serviceAccount": "spark",
+ "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}],
+ },
+ "executor": {
+ "cores": 1,
+ "instances": 1,
+ "labels": {"version": "2.4.5"},
+ "memory": "512m",
+ "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}],
+ },
+ "image": "gcr.io/spark-operator/spark:v2.4.5",
+ "imagePullPolicy": "Always",
+ "mainApplicationFile":
"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar",
+ "mainClass": "org.apache.spark.examples.SparkPi",
+ "mode": "cluster",
+ "restartPolicy": {"type": "Never"},
+ "sparkVersion": "2.4.5",
+ "type": "Scala",
+ "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"},
"name": "test-volume"}],
+ },
+}
+
+
+def create_context(task):
+ dag = DAG(dag_id="dag")
+ tzinfo = pendulum.timezone("Europe/Amsterdam")
+ execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ execution_date=execution_date,
+ run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
)
-
- assert mock_stream.call_count == 2
- mock_stream.assert_any_call(
- mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_event,
- namespace="default",
- watch=True,
-
field_selector="involvedObject.kind=SparkApplication,involvedObject.name=spark-app",
- )
- mock_stream.assert_any_call(
-
mock_kubernetes_hook.return_value.core_v1_client.read_namespaced_pod_log,
- name="spark-app-driver",
- namespace="default",
- timestamps=True,
- )
-
- assert operator_output == {"metadata": {"name": "spark-app",
"creationTimestamp": "2022-01-01T00:00:00Z"}}
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.on_kill")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_raise_exception_when_job_fails(
- mock_kubernetes_hook, mock_load_body_to_dict, mock_stream, mock_on_kill
-):
- mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
-
- mock_kubernetes_hook.return_value.create_custom_object.return_value = {
- "metadata": {"name": "spark-app", "creationTimestamp":
"2022-01-01T00:00:00Z"}
- }
- mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
- object_mock = MagicMock()
- object_mock.reason = "SparkApplicationFailed"
- object_mock.message = "spark-app submission failed"
- object_mock.last_timestamp = datetime(2022, 1, 1, 23, 59, 59,
tzinfo=tz.tzutc())
-
- mock_stream.side_effect = [[{"object": object_mock}], []]
- op = SparkKubernetesOperator(task_id="task_id",
application_file="application_file", watch=True)
- with pytest.raises(AirflowException, match="spark-app submission failed"):
- op.execute({})
-
- assert mock_on_kill.has_called_once()
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_execute_without_watch(mock_kubernetes_hook, mock_load_body_to_dict):
- mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
-
- mock_kubernetes_hook.return_value.create_custom_object.return_value = {
- "metadata": {"name": "spark-app", "creationTimestamp":
"2022-01-01T00:00:00Z"}
+ task_instance = TaskInstance(task=task)
+ task_instance.dag_run = dag_run
+ task_instance.dag_id = dag.dag_id
+ task_instance.xcom_push = mock.Mock()
+ return {
+ "dag": dag,
+ "run_id": dag_run.run_id,
+ "task": task,
+ "ti": task_instance,
+ "task_instance": task_instance,
}
- mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
- op = SparkKubernetesOperator(task_id="task_id",
application_file="application_file")
- operator_output = op.execute({})
-
-
mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
- group="sparkoperator.k8s.io",
- version="v1beta2",
- plural="sparkapplications",
- body={"metadata": {"name": "spark-app"}},
- namespace="default",
- )
- assert operator_output == {"metadata": {"name": "spark-app",
"creationTimestamp": "2022-01-01T00:00:00Z"}}
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict):
- mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
- mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
- op = SparkKubernetesOperator(task_id="task_id",
application_file="application_file")
- op.on_kill()
-
mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with(
- group="sparkoperator.k8s.io",
- version="v1beta2",
- plural="sparkapplications",
- namespace="default",
- name="spark-app",
- )
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_execute_with_application_file_dict(mock_kubernetes_hook):
- op = SparkKubernetesOperator(task_id="task_id",
application_file={"metadata": {"name": "spark-app"}})
- mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
- op.execute({})
-
-
mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
- group="sparkoperator.k8s.io",
- version="v1beta2",
- plural="sparkapplications",
- body={"metadata": {"name": "spark-app"}},
- namespace="default",
- )
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_on_kill_with_application_file_dict(mock_kubernetes_hook):
- op = SparkKubernetesOperator(task_id="task_id",
application_file={"metadata": {"name": "spark-app"}})
- mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
- op.on_kill()
-
-
mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with(
- group="sparkoperator.k8s.io",
- version="v1beta2",
- plural="sparkapplications",
- name="spark-app",
- namespace="default",
- )
+@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
+@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_start")
+@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client")
+@patch(
+
"airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.create_job_name"
+)  # job name is set per-test via mock_create_job_name.return_value
+@patch("airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.cleanup")
+@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object_status")
+@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
+class TestSparkKubernetesOperator:
+ def setUp(self):
+ db.merge_conn(
+ Connection(conn_id="kubernetes_default_kube_config",
conn_type="kubernetes", extra=json.dumps({}))
+ )
+ db.merge_conn(
+ Connection(
+ conn_id="kubernetes_with_namespace",
+ conn_type="kubernetes",
+ extra=json.dumps({"extra__kubernetes__namespace":
"mock_namespace"}),
+ )
+ )
+ args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2,
1)}
+ self.dag = DAG("test_dag_id", default_args=args)
+
+ def execute_operator(self, task_name, mock_create_job_name, job_spec):
+ mock_create_job_name.return_value = task_name
+ op = SparkKubernetesOperator(
+ template_spec=job_spec,
+ kubernetes_conn_id="kubernetes_default_kube_config",
+ task_id=task_name,
+ )
+ context = create_context(op)
+ op.execute(context)
+ return op
+
+ def test_create_application_from_yaml_json(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "default_yaml"
+ mock_create_job_name.return_value = task_name
+ op = SparkKubernetesOperator(
+ application_file=join(Path(__file__).parent,
"spark_application_test.yaml"),
+ kubernetes_conn_id="kubernetes_default_kube_config",
+ task_id=task_name,
+ )
+ context = create_context(op)
+ op.execute(context)
+ TEST_APPLICATION_DICT["metadata"]["name"] = task_name
+ mock_create_namespaced_crd.assert_called_with(
+ body=TEST_APPLICATION_DICT,
+ group="sparkoperator.k8s.io",
+ namespace="default",
+ plural="sparkapplications",
+ version="v1beta2",
+ )
+
+ task_name = "default_json"
+ mock_create_job_name.return_value = task_name
+ op = SparkKubernetesOperator(
+ application_file=join(Path(__file__).parent,
"spark_application_test.json"),
+ kubernetes_conn_id="kubernetes_default_kube_config",
+ task_id=task_name,
+ )
+ context = create_context(op)
+ op.execute(context)
+ TEST_APPLICATION_DICT["metadata"]["name"] = task_name
+ mock_create_namespaced_crd.assert_called_with(
+ body=TEST_APPLICATION_DICT,
+ group="sparkoperator.k8s.io",
+ namespace="default",
+ plural="sparkapplications",
+ version="v1beta2",
+ )
+
+ def test_new_template_from_yaml(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "default_yaml_template"
+ mock_create_job_name.return_value = task_name
+ op = SparkKubernetesOperator(
+ application_file=join(Path(__file__).parent,
"spark_application_template.yaml"),
+ kubernetes_conn_id="kubernetes_default_kube_config",
+ task_id=task_name,
+ )
+ context = create_context(op)
+ op.execute(context)
+ TEST_K8S_DICT["metadata"]["name"] = task_name
+ mock_create_namespaced_crd.assert_called_with(
+ body=TEST_K8S_DICT,
+ group="sparkoperator.k8s.io",
+ namespace="default",
+ plural="sparkapplications",
+ version="v1beta2",
+ )
+
+ def test_template_spec(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "default_yaml_template"
+ job_spec = yaml.safe_load(open(join(Path(__file__).parent,
"spark_application_template.yaml")))
+ self.execute_operator(task_name, mock_create_job_name,
job_spec=job_spec)
+
+ TEST_K8S_DICT["metadata"]["name"] = task_name
+ mock_create_namespaced_crd.assert_called_with(
+ body=TEST_K8S_DICT,
+ group="sparkoperator.k8s.io",
+ namespace="default",
+ plural="sparkapplications",
+ version="v1beta2",
+ )
+
+ def test_env(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "default_env"
+ job_spec = yaml.safe_load(open(join(Path(__file__).parent,
"spark_application_template.yaml")))
+ # test env vars
+ job_spec["kubernetes"]["env_vars"] = {"TEST_ENV_1": "VALUE1"}
+
+ # test env from
+ env_from = [
+
k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name="env-direct-configmap")),
+
k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name="env-direct-secret")),
+ ]
+ job_spec["kubernetes"]["env_from"] = copy.deepcopy(env_from)
+
+ # test from_env_config_map
+ job_spec["kubernetes"]["from_env_config_map"] = ["env-from-configmap"]
+ job_spec["kubernetes"]["from_env_secret"] = ["env-from-secret"]
+
+ op = self.execute_operator(task_name, mock_create_job_name,
job_spec=job_spec)
+ assert op.launcher.body["spec"]["driver"]["env"] == [
+ k8s.V1EnvVar(name="TEST_ENV_1", value="VALUE1"),
+ ]
+ assert op.launcher.body["spec"]["executor"]["env"] == [
+ k8s.V1EnvVar(name="TEST_ENV_1", value="VALUE1"),
+ ]
+
+ env_from = env_from + [
+
k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name="env-from-configmap")),
+
k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name="env-from-secret")),
+ ]
+ assert op.launcher.body["spec"]["driver"]["envFrom"] == env_from
+ assert op.launcher.body["spec"]["executor"]["envFrom"] == env_from
+
+ def test_volume(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "default_volume"
+ job_spec = yaml.safe_load(open(join(Path(__file__).parent,
"spark_application_template.yaml")))
+ volumes = [
+ k8s.V1Volume(
+ name="test-pvc",
+
persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name="test-pvc"),
+ ),
+ k8s.V1Volume(
+ name="test-configmap-mount",
+
config_map=k8s.V1ConfigMapVolumeSource(name="test-configmap-mount"),
+ ),
+ ]
+ volume_mounts = [
+ k8s.V1VolumeMount(mount_path="/pvc-path", name="test-pvc"),
+ k8s.V1VolumeMount(mount_path="/configmap-path",
name="test-configmap-mount"),
+ ]
+ job_spec["kubernetes"]["volumes"] = copy.deepcopy(volumes)
+ job_spec["kubernetes"]["volume_mounts"] = copy.deepcopy(volume_mounts)
+ job_spec["kubernetes"]["config_map_mounts"] =
{"test-configmap-mounts-field": "/cm-path"}
+ op = self.execute_operator(task_name, mock_create_job_name,
job_spec=job_spec)
+
+ assert op.launcher.body["spec"]["volumes"] == volumes + [
+ k8s.V1Volume(
+ name="test-configmap-mounts-field",
+
config_map=k8s.V1ConfigMapVolumeSource(name="test-configmap-mounts-field"),
+ )
+ ]
+ volume_mounts = volume_mounts + [
+ k8s.V1VolumeMount(mount_path="/cm-path",
name="test-configmap-mounts-field")
+ ]
+ assert op.launcher.body["spec"]["driver"]["volumeMounts"] ==
volume_mounts
+ assert op.launcher.body["spec"]["executor"]["volumeMounts"] ==
volume_mounts
+
+ def test_pull_secret(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "test_pull_secret"
+ job_spec = yaml.safe_load(open(join(Path(__file__).parent,
"spark_application_template.yaml")))
+ job_spec["kubernetes"]["image_pull_secrets"] = "secret1,secret2"
+ op = self.execute_operator(task_name, mock_create_job_name,
job_spec=job_spec)
+
+ exp_secrets = [k8s.V1LocalObjectReference(name=secret) for secret in
["secret1", "secret2"]]
+ assert op.launcher.body["spec"]["imagePullSecrets"] == exp_secrets
+
+ def test_affinity(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ task_name = "test_affinity"
+ job_spec = yaml.safe_load(open(join(Path(__file__).parent,
"spark_application_template.yaml")))
+ job_spec["kubernetes"]["affinity"] = k8s.V1Affinity(
+ node_affinity=k8s.V1NodeAffinity(
+
required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
+ node_selector_terms=[
+ k8s.V1NodeSelectorTerm(
+ match_expressions=[
+ k8s.V1NodeSelectorRequirement(
+ key="beta.kubernetes.io/instance-type",
+ operator="In",
+ values=["r5.xlarge"],
+ )
+ ]
+ )
+ ]
+ )
+ )
+ )
+
+ op = self.execute_operator(task_name, mock_create_job_name,
job_spec=job_spec)
+ affinity = k8s.V1Affinity(
+ node_affinity=k8s.V1NodeAffinity(
+
required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
+ node_selector_terms=[
+ k8s.V1NodeSelectorTerm(
+ match_expressions=[
+ k8s.V1NodeSelectorRequirement(
+ key="beta.kubernetes.io/instance-type",
+ operator="In",
+ values=["r5.xlarge"],
+ )
+ ]
+ )
+ ]
+ )
+ )
+ )
+ assert op.launcher.body["spec"]["driver"]["affinity"] == affinity
+ assert op.launcher.body["spec"]["executor"]["affinity"] == affinity
+
+ def test_toleration(
+ self,
+ mock_create_namespaced_crd,
+ mock_get_namespaced_custom_object_status,
+ mock_cleanup,
+ mock_create_job_name,
+ mock_get_kube_client,
+ mock_create_pod,
+ mock_await_pod_start,
+ mock_await_pod_completion,
+ ):
+ toleration = k8s.V1Toleration(
+ key="dedicated",
+ operator="Equal",
+ value="test",
+ effect="NoSchedule",
+ )
+ task_name = "test_tolerations"
+ job_spec = yaml.safe_load(open(join(Path(__file__).parent,
"spark_application_template.yaml")))
+ job_spec["kubernetes"]["tolerations"] = [toleration]
+ op = self.execute_operator(task_name, mock_create_job_name,
job_spec=job_spec)
+
+ assert op.launcher.body["spec"]["driver"]["tolerations"] ==
[toleration]
+ assert op.launcher.body["spec"]["executor"]["tolerations"] ==
[toleration]
diff --git a/tests/providers/cncf/kubernetes/resource_convert/__init__.py
b/tests/providers/cncf/kubernetes/resource_convert/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/resource_convert/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/cncf/kubernetes/resource_convert/test_configmap.py
b/tests/providers/cncf/kubernetes/resource_convert/test_configmap.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/resource_convert/test_configmap.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/tests/providers/cncf/kubernetes/resource_convert/test_env_variable.py
b/tests/providers/cncf/kubernetes/resource_convert/test_env_variable.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/resource_convert/test_env_variable.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/cncf/kubernetes/resource_convert/test_secret.py
b/tests/providers/cncf/kubernetes/resource_convert/test_secret.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/resource_convert/test_secret.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/cncf/kubernetes/example_spark_kubernetes.py
b/tests/system/providers/cncf/kubernetes/example_spark_kubernetes.py
index 03fd0a34e5..00d2a2361f 100644
--- a/tests/system/providers/cncf/kubernetes/example_spark_kubernetes.py
+++ b/tests/system/providers/cncf/kubernetes/example_spark_kubernetes.py
@@ -27,7 +27,9 @@ https://github.com/GoogleCloudPlatform/spark-on-k8s-operator
from __future__ import annotations
import os
+import pathlib
from datetime import datetime, timedelta
+from os.path import join
# [START import_module]
# The DAG object; we'll need this to instantiate a DAG
@@ -55,10 +57,11 @@ with DAG(
catchup=False,
) as dag:
# [START SparkKubernetesOperator_DAG]
+ pi_example_path = pathlib.Path(__file__).parent.resolve()
t1 = SparkKubernetesOperator(
task_id="spark_pi_submit",
namespace="default",
- application_file="example_spark_kubernetes_spark_pi.yaml",
+ application_file=join(pi_example_path,
"example_spark_kubernetes_spark_pi.yaml"),
do_xcom_push=True,
dag=dag,
)
diff --git a/tests/system/providers/cncf/kubernetes/spark_job_template.yaml
b/tests/system/providers/cncf/kubernetes/spark_job_template.yaml
new file mode 100644
index 0000000000..6ab1b0082c
--- /dev/null
+++ b/tests/system/providers/cncf/kubernetes/spark_job_template.yaml
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+---
+spark:
+ apiVersion: sparkoperator.k8s.io/v1beta2
+ version: v1beta2
+ kind: SparkApplication
+ apiGroup: sparkoperator.k8s.io
+ metadata:
+ namespace: default
+ spec:
+ type: Python
+ pythonVersion: "3"
+ mode: cluster
+ sparkVersion: 3.0.0
+ successfulRunHistoryLimit: 1
+ restartPolicy:
+ type: Never
+ imagePullPolicy: Always
+ hadoopConf: {}
+ imagePullSecrets: []
+ dynamicAllocation:
+ enabled: false
+ initialExecutors: 1
+ minExecutors: 1
+ maxExecutors: 1
+ labels: {}
+ driver:
+ serviceAccount: default
+ container_resources:
+ gpu:
+ name: null
+ quantity: 0
+ cpu:
+ request: null
+ limit: null
+ memory:
+ request: null
+ limit: null
+ executor:
+ instances: 1
+ container_resources:
+ gpu:
+ name: null
+ quantity: 0
+ cpu:
+ request: null
+ limit: null
+ memory:
+ request: null
+ limit: null
+kubernetes:
+ # example:
+ # env_vars:
+ # - name: TEST_NAME
+ # value: TEST_VALUE
+ env_vars: []
+
+ # example (envFrom syntax — references whole ConfigMaps/Secrets):
+ # env_from:
+ # - configMapRef:
+ #     name: env-direct-configmap
+ # - secretRef:
+ #     name: env-direct-secret
+ env_from: []
+
+ # example:
+ # node_selector:
+ # karpenter.sh/provisioner-name: spark-group
+ node_selector: {}
+
+ # example:
https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/
+ # affinity:
+ # nodeAffinity:
+ # requiredDuringSchedulingIgnoredDuringExecution:
+ # nodeSelectorTerms:
+ # - matchExpressions:
+ # - key: beta.kubernetes.io/instance-type
+ # operator: In
+ # values:
+ # - r5.xlarge
+ affinity:
+ nodeAffinity: {}
+ podAffinity: {}
+ podAntiAffinity: {}
+
+ # example:
https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/
+ # type: list
+ # tolerations:
+ # - key: "key1"
+ # operator: "Equal"
+ # value: "value1"
+ # effect: "NoSchedule"
+ tolerations: []
+
+ # example:
+ # config_map_mounts:
+ # snowflake-default: /mnt/tmp
+ config_map_mounts: {}
+
+ # example:
+ # volumeMounts:
+ # - name: config
+ # mountPath: /path
+ volume_mounts: []
+
+ # https://kubernetes.io/docs/concepts/storage/volumes/
+ # example:
+ # volumes:
+ # - name: config
+ # persistentVolumeClaim:
+ # claimName: claim-name
+ volumes: []
+
+ # read config maps into env variables
+ # example:
+ # from_env_config_map:
+ # - configmap_1
+ # - configmap_2
+ from_env_config_map: []
+
+ # load secrets into env variables
+ # example:
+ # from_env_secret:
+ # - secret_1
+ # - secret_2
+ from_env_secret: []
+
+ image_pull_secrets: ''
+ in_cluster: true
+ conn_id: kubernetes_default
+ kube_config_file: null
+ cluster_context: null