bolkedebruin commented on code in PR #22253:
URL: https://github.com/apache/airflow/pull/22253#discussion_r1403221124


##########
airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py:
##########
@@ -17,173 +17,271 @@
 # under the License.
 from __future__ import annotations
 
-import datetime
+import json
+import re
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any
 
-from kubernetes.client import ApiException
-from kubernetes.watch import Watch
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
+from airflow.providers.cncf.kubernetes import pod_generator
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, 
_load_body_to_dict
+from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import 
CustomObjectLauncher
+from airflow.providers.cncf.kubernetes.operators.pod import 
KubernetesPodOperator
+from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, 
PodGenerator
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
-    from kubernetes.client.models import CoreV1EventList
+    import jinja2
 
     from airflow.utils.context import Context
 
 
-class SparkKubernetesOperator(BaseOperator):
+class SparkKubernetesOperator(KubernetesPodOperator):
     """
     Creates sparkApplication object in kubernetes cluster.
 
     .. seealso::
         For more detail about Spark Application Object have a look at the 
reference:
-        
https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication
+        
https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.3.3-3.1.1/docs/api-docs.md#sparkapplication
 
-    :param application_file: Defines Kubernetes 'custom_resource_definition' 
of 'sparkApplication' as either a
-        path to a '.yaml' file, '.json' file, YAML string or JSON string.
+    :param application_file: filepath to kubernetes custom_resource_definition 
of sparkApplication
+    :param kubernetes_conn_id: the connection to Kubernetes cluster
+    :param image: Docker image you wish to launch. Defaults to hub.docker.com,
+    :param code_path: path to the spark code in image,
     :param namespace: kubernetes namespace to put sparkApplication
-    :param kubernetes_conn_id: The :ref:`kubernetes connection id 
<howto/connection:kubernetes>`
-        for the to Kubernetes cluster.
-    :param api_group: kubernetes api group of sparkApplication
-    :param api_version: kubernetes api version of sparkApplication
-    :param watch: whether to watch the job status and logs or not
+    :param cluster_context: context of the cluster
+    :param application_file: yaml file if passed
+    :param get_logs: get the stdout of the container as logs of the tasks.
+    :param do_xcom_push: If True, the content of the file
+        /airflow/xcom/return.json in the container will also be pushed to an
+        XCom when the container completes.
+    :param success_run_history_limit: Number of past successful runs of the 
application to keep.
+    :param delete_on_termination: What to do when the pod reaches its final
+        state, or the execution is interrupted. If True (default), delete the
+        pod; if False, leave the pod.
+    :param startup_timeout_seconds: timeout in seconds to startup the pod.
+    :param log_events_on_failure: Log the pod's events if a failure occurs
+    :param reattach_on_restart: if the scheduler dies while the pod is 
running, reattach and monitor
     """
 
-    template_fields: Sequence[str] = ("application_file", "namespace")
-    template_ext: Sequence[str] = (".yaml", ".yml", ".json")
+    template_fields = ["application_file", "namespace", "template_spec"]
+    template_fields_renderers = {"template_spec": "py"}
+    template_ext = ("yaml", "yml", "json")
     ui_color = "#f4a460"
 
     def __init__(
         self,
         *,
-        application_file: str,
-        namespace: str | None = None,
+        image: str | None = None,
+        code_path: str | None = None,
+        namespace: str = "default",
+        name: str = "default",
+        application_file: str | None = None,
+        template_spec=None,
+        get_logs: bool = True,
+        do_xcom_push: bool = False,
+        success_run_history_limit: int = 1,
+        startup_timeout_seconds=600,
+        log_events_on_failure: bool = False,
+        reattach_on_restart: bool = True,
+        delete_on_termination: bool = True,
         kubernetes_conn_id: str = "kubernetes_default",
-        api_group: str = "sparkoperator.k8s.io",
-        api_version: str = "v1beta2",
-        in_cluster: bool | None = None,
-        cluster_context: str | None = None,
-        config_file: str | None = None,
-        watch: bool = False,
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
-        self.namespace = namespace
-        self.kubernetes_conn_id = kubernetes_conn_id
-        self.api_group = api_group
-        self.api_version = api_version
-        self.plural = "sparkapplications"
+        if kwargs.get("xcom_push") is not None:
+            raise AirflowException("'xcom_push' was deprecated, use 
'do_xcom_push' instead")
+        super().__init__(name=name, **kwargs)
+        self.image = image
+        self.code_path = code_path
         self.application_file = application_file
-        self.in_cluster = in_cluster
-        self.cluster_context = cluster_context
-        self.config_file = config_file
-        self.watch = watch
+        self.template_spec = template_spec
+        self.name = self.create_job_name()
+        self.kubernetes_conn_id = kubernetes_conn_id
+        self.startup_timeout_seconds = startup_timeout_seconds
+        self.reattach_on_restart = reattach_on_restart
+        self.delete_on_termination = delete_on_termination
+        self.do_xcom_push = do_xcom_push
+        self.namespace = namespace
+        self.get_logs = get_logs
+        self.log_events_on_failure = log_events_on_failure
+        self.success_run_history_limit = success_run_history_limit
+        self.template_body = self.manage_template_specs()
+
+    def _render_nested_template_fields(
+        self,
+        content: Any,
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set,
+    ) -> None:
+        if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
+            seen_oids.add(id(content))
+            self._do_render_template_fields(content, ("value", "name"), 
context, jinja_env, seen_oids)
+            return
+
+        super()._render_nested_template_fields(content, context, jinja_env, 
seen_oids)
+
+    def manage_template_specs(self):
+        if self.application_file:
+            template_body = _load_body_to_dict(open(self.application_file))
+        elif self.template_spec:
+            template_body = self.template_spec
+        else:
+            raise AirflowException("either application_file or template_spec 
should be passed")
+        if "spark" not in template_body:
+            template_body = {"spark": template_body}
+        return template_body
+
+    def create_job_name(self):
+        initial_name = 
PodGenerator.make_unique_pod_id(self.task_id)[:MAX_LABEL_LEN]
+        return re.sub(r"[^a-z0-9-]+", "-", initial_name.lower())
+
+    @staticmethod
+    def _get_pod_identifying_label_string(labels) -> str:
+        filtered_labels = {label_id: label for label_id, label in 
labels.items() if label_id != "try_number"}
+        return ",".join([label_id + "=" + label for label_id, label in 
sorted(filtered_labels.items())])
+
+    @staticmethod
+    def create_labels_for_pod(context: dict | None = None, include_try_number: 
bool = True) -> dict:
+        """
+        Generate labels for the pod to track the pod in case of Operator crash.
+
+        :param include_try_number: add try number to labels
+        :param context: task context provided by airflow DAG
+        :return: dict.
+        """
+        if not context:
+            return {}
+
+        ti = context["ti"]
+        run_id = context["run_id"]
+
+        labels = {
+            "dag_id": ti.dag_id,
+            "task_id": ti.task_id,
+            "run_id": run_id,
+            "spark_kubernetes_operator": "True",
+            # 'execution_date': context['ts'],
+            # 'try_number': context['ti'].try_number,
+        }
+
+        # If running on Airflow 2.3+:
+        map_index = getattr(ti, "map_index", -1)
+        if map_index >= 0:
+            labels["map_index"] = map_index
+
+        if include_try_number:
+            labels.update(try_number=ti.try_number)
+
+        # In the case of sub dags this is just useful
+        if context["dag"].is_subdag:
+            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        # Ensure that label is valid for Kube,
+        # and if not truncate/remove invalid chars and replace with short hash.
+        for label_id, label in labels.items():
+            safe_label = pod_generator.make_safe_label_value(str(label))
+            labels[label_id] = safe_label
+        return labels
+
+    @cached_property
+    def pod_manager(self) -> PodManager:
+        return PodManager(kube_client=self.client)
+
+    @staticmethod
+    def _try_numbers_match(context, pod) -> bool:
+        return pod.metadata.labels["try_number"] == context["ti"].try_number
+
+    def find_spark_job(self, context):
+        labels = self.create_labels_for_pod(context, include_try_number=False)
+        label_selector = self._get_pod_identifying_label_string(labels) + 
",spark-role=driver"
+        pod_list = self.client.list_namespaced_pod(self.namespace, 
label_selector=label_selector).items
+
+        pod = None
+        if len(pod_list) > 1:  # and self.reattach_on_restart:
+            raise AirflowException(f"More than one pod running with labels: 
{label_selector}")
+        elif len(pod_list) == 1:
+            pod = pod_list[0]
+            self.log.info(
+                "Found matching driver pod %s with labels %s", 
pod.metadata.name, pod.metadata.labels
+            )
+            self.log.info("`try_number` of task_instance: %s", 
context["ti"].try_number)
+            self.log.info("`try_number` of pod: %s", 
pod.metadata.labels["try_number"])
+        return pod
+
+    def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) 
-> k8s.V1Pod:
+        if self.reattach_on_restart:
+            driver_pod = self.find_spark_job(context)
+            if driver_pod:
+                return driver_pod
+
+        driver_pod, spark_obj_spec = launcher.start_spark_job(
+            image=self.image, code_path=self.code_path, 
startup_timeout=self.startup_timeout_seconds
+        )
+        return driver_pod
+
+    def extract_xcom(self, pod):
+        """Retrieves xcom value and kills xcom sidecar container."""
+        result = self.pod_manager.extract_xcom(pod)
+        self.log.info("xcom result: \n%s", result)

Review Comment:
   I don't think this should be logged: the XCom result can contain sensitive data. Also, please do not format log messages with '\n'.



##########
airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py:
##########
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Launches Custom object."""
+from __future__ import annotations
+
+import time
+from copy import deepcopy
+from datetime import datetime as dt
+from functools import cached_property
+
+import tenacity
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
+from kubernetes.client.rest import ApiException
+
+from airflow.exceptions import AirflowException
+from airflow.providers.cncf.kubernetes.resource_convert.configmap import (
+    convert_configmap,
+    convert_configmap_to_volume,
+)
+from airflow.providers.cncf.kubernetes.resource_convert.env_variable import 
convert_env_vars
+from airflow.providers.cncf.kubernetes.resource_convert.secret import (
+    convert_image_pull_secrets,
+    convert_secret,
+)
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+def should_retry_start_spark_job(exception: BaseException) -> bool:
+    """Check if an Exception indicates a transient error and warrants 
retrying."""
+    if isinstance(exception, ApiException):
+        return exception.status == 409
+    return False
+
+
+class SparkJobSpec:
+    """Spark job spec."""
+
+    def __init__(self, **entries):
+        self.__dict__.update(entries)
+        self.validate()
+        self.update_resources()
+
+    def validate(self):
+        if self.spec.get("dynamicAllocation", {}).get("enabled"):
+            if not all(
+                [
+                    self.spec["dynamicAllocation"]["initialExecutors"],
+                    self.spec["dynamicAllocation"]["minExecutors"],
+                    self.spec["dynamicAllocation"]["maxExecutors"],
+                ]
+            ):
+                raise AirflowException("Make sure initial/min/max value for 
dynamic allocation is passed")
+
+    def update_resources(self):
+        if self.spec["driver"].get("container_resources"):
+            spark_resources = SparkResources(
+                self.spec["driver"].pop("container_resources"),
+                self.spec["executor"].pop("container_resources"),
+            )
+            self.spec["driver"].update(spark_resources.resources["driver"])
+            self.spec["executor"].update(spark_resources.resources["executor"])
+
+
+class KubernetesSpec:
+    """Spark kubernetes spec."""
+
+    def __init__(self, **entries):
+        self.__dict__.update(entries)
+        self.set_attribute()
+
+    def set_attribute(self):
+        self.env_vars = convert_env_vars(self.env_vars) if self.env_vars else 
[]
+        self.image_pull_secrets = (
+            convert_image_pull_secrets(self.image_pull_secrets) if 
self.image_pull_secrets else []
+        )
+        if self.config_map_mounts:
+            vols, vols_mounts = 
convert_configmap_to_volume(self.config_map_mounts)
+            self.volumes.extend(vols)
+            self.volume_mounts.extend(vols_mounts)
+        if self.from_env_config_map:
+            self.env_from.extend([convert_configmap(c_name) for c_name in 
self.from_env_config_map])
+        if self.from_env_secret:
+            self.env_from.extend([convert_secret(c) for c in 
self.from_env_secret])
+
+
+class SparkResources:
+    """spark resources."""
+
+    def __init__(
+        self,
+        driver: dict = {None: None},
+        executor: dict = {None: None},
+    ):
+        self.default = {
+            "gpu": {"name": None, "quantity": 0},
+            "cpu": {"request": None, "limit": None},
+            "memory": {"request": None, "limit": None},
+        }
+        self.driver = deepcopy(self.default)
+        self.executor = deepcopy(self.default)
+        if driver:
+            self.driver.update(driver)
+        if executor:
+            self.executor.update(executor)
+        self.convert_resources()
+
+    @property
+    def resources(self):
+        """Return job resources."""
+        return {"driver": self.driver_resources, "executor": 
self.executor_resources}
+
+    @property
+    def driver_resources(self):
+        """Return resources to use."""
+        driver = {}
+        if self.driver["cpu"].get("request"):
+            driver["cores"] = self.driver["cpu"]["request"]
+        if self.driver["cpu"].get("limit"):
+            driver["coreLimit"] = self.driver["cpu"]["limit"]
+        if self.driver["memory"].get("limit"):
+            driver["memory"] = self.driver["memory"]["limit"]
+        if self.driver["gpu"].get("name") and 
self.driver["gpu"].get("quantity"):
+            driver["gpu"] = {"name": self.driver["gpu"]["name"], "quantity": 
self.driver["gpu"]["quantity"]}
+        return driver
+
+    @property
+    def executor_resources(self):
+        """Return resources to use."""
+        executor = {}
+        if self.executor["cpu"].get("request"):
+            executor["cores"] = self.executor["cpu"]["request"]
+        if self.executor["cpu"].get("limit"):
+            executor["coreLimit"] = self.executor["cpu"]["limit"]
+        if self.executor["memory"].get("limit"):
+            executor["memory"] = self.executor["memory"]["limit"]
+        if self.executor["gpu"].get("name") and 
self.executor["gpu"].get("quantity"):
+            executor["gpu"] = {
+                "name": self.executor["gpu"]["name"],
+                "quantity": self.executor["gpu"]["quantity"],
+            }
+        return executor
+
+    def convert_resources(self):
+        if isinstance(self.driver["memory"].get("limit"), str):
+            if "G" in self.driver["memory"]["limit"] or "Gi" in 
self.driver["memory"]["limit"]:
+                self.driver["memory"]["limit"] = 
float(self.driver["memory"]["limit"].rstrip("Gi G")) * 1024
+            elif "m" in self.driver["memory"]["limit"]:
+                self.driver["memory"]["limit"] = 
float(self.driver["memory"]["limit"].rstrip("m"))
+            # Adjusting the memory value as operator adds 40% to the given 
value
+            self.driver["memory"]["limit"] = 
str(int(self.driver["memory"]["limit"] / 1.4)) + "m"
+
+        if isinstance(self.executor["memory"].get("limit"), str):
+            if "G" in self.executor["memory"]["limit"] or "Gi" in 
self.executor["memory"]["limit"]:
+                self.executor["memory"]["limit"] = (
+                    float(self.executor["memory"]["limit"].rstrip("Gi G")) * 
1024
+                )
+            elif "m" in self.executor["memory"]["limit"]:
+                self.executor["memory"]["limit"] = 
float(self.executor["memory"]["limit"].rstrip("m"))
+            # Adjusting the memory value as operator adds 40% to the given 
value
+            self.executor["memory"]["limit"] = 
str(int(self.executor["memory"]["limit"] / 1.4)) + "m"
+
+        if self.driver["cpu"].get("request"):
+            self.driver["cpu"]["request"] = 
int(float(self.driver["cpu"]["request"]))
+        if self.driver["cpu"].get("limit"):
+            self.driver["cpu"]["limit"] = str(self.driver["cpu"]["limit"])
+        if self.executor["cpu"].get("request"):
+            self.executor["cpu"]["request"] = 
int(float(self.executor["cpu"]["request"]))
+        if self.executor["cpu"].get("limit"):
+            self.executor["cpu"]["limit"] = str(self.executor["cpu"]["limit"])
+
+        if self.driver["gpu"].get("quantity"):
+            self.driver["gpu"]["quantity"] = 
int(float(self.driver["gpu"]["quantity"]))
+        if self.executor["gpu"].get("quantity"):
+            self.executor["gpu"]["quantity"] = 
int(float(self.executor["gpu"]["quantity"]))
+
+
+class CustomObjectStatus:
+    """Status of the PODs."""
+
+    SUBMITTED = "SUBMITTED"
+    RUNNING = "RUNNING"
+    FAILED = "FAILED"
+    SUCCEEDED = "SUCCEEDED"
+
+
+class CustomObjectLauncher(LoggingMixin):
+    """Launches PODS."""
+
+    def __init__(
+        self,
+        name: str | None,
+        namespace: str | None,
+        kube_client: CoreV1Api,
+        custom_obj_api: CustomObjectsApi,
+        template_body: str | None = None,
+    ):
+        """
+        Creates custom object launcher(sparkapplications crd).
+
+                :param kube_client: kubernetes client.

Review Comment:
   This `:param` line is wrongly indented; align it with the docstring summary line above it.



##########
airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py:
##########
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Launches Custom object."""
+from __future__ import annotations
+
+import time
+from copy import deepcopy
+from datetime import datetime as dt
+from functools import cached_property
+
+import tenacity
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
+from kubernetes.client.rest import ApiException
+
+from airflow.exceptions import AirflowException
+from airflow.providers.cncf.kubernetes.resource_convert.configmap import (
+    convert_configmap,
+    convert_configmap_to_volume,
+)
+from airflow.providers.cncf.kubernetes.resource_convert.env_variable import 
convert_env_vars
+from airflow.providers.cncf.kubernetes.resource_convert.secret import (
+    convert_image_pull_secrets,
+    convert_secret,
+)
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+def should_retry_start_spark_job(exception: BaseException) -> bool:
+    """Check if an Exception indicates a transient error and warrants 
retrying."""
+    if isinstance(exception, ApiException):
+        return exception.status == 409
+    return False
+
+
+class SparkJobSpec:
+    """Spark job spec."""
+
+    def __init__(self, **entries):
+        self.__dict__.update(entries)
+        self.validate()
+        self.update_resources()
+
+    def validate(self):
+        if self.spec.get("dynamicAllocation", {}).get("enabled"):
+            if not all(
+                [
+                    self.spec["dynamicAllocation"]["initialExecutors"],
+                    self.spec["dynamicAllocation"]["minExecutors"],
+                    self.spec["dynamicAllocation"]["maxExecutors"],
+                ]
+            ):
+                raise AirflowException("Make sure initial/min/max value for 
dynamic allocation is passed")
+
+    def update_resources(self):
+        if self.spec["driver"].get("container_resources"):
+            spark_resources = SparkResources(
+                self.spec["driver"].pop("container_resources"),
+                self.spec["executor"].pop("container_resources"),
+            )
+            self.spec["driver"].update(spark_resources.resources["driver"])
+            self.spec["executor"].update(spark_resources.resources["executor"])
+
+
+class KubernetesSpec:
+    """Spark kubernetes spec."""
+
+    def __init__(self, **entries):
+        self.__dict__.update(entries)
+        self.set_attribute()
+
+    def set_attribute(self):
+        self.env_vars = convert_env_vars(self.env_vars) if self.env_vars else 
[]
+        self.image_pull_secrets = (
+            convert_image_pull_secrets(self.image_pull_secrets) if 
self.image_pull_secrets else []
+        )
+        if self.config_map_mounts:
+            vols, vols_mounts = 
convert_configmap_to_volume(self.config_map_mounts)
+            self.volumes.extend(vols)
+            self.volume_mounts.extend(vols_mounts)
+        if self.from_env_config_map:
+            self.env_from.extend([convert_configmap(c_name) for c_name in 
self.from_env_config_map])
+        if self.from_env_secret:
+            self.env_from.extend([convert_secret(c) for c in 
self.from_env_secret])
+
+
+class SparkResources:
+    """spark resources."""
+
+    def __init__(
+        self,
+        driver: dict = {None: None},
+        executor: dict = {None: None},
+    ):
+        self.default = {
+            "gpu": {"name": None, "quantity": 0},
+            "cpu": {"request": None, "limit": None},
+            "memory": {"request": None, "limit": None},
+        }
+        self.driver = deepcopy(self.default)
+        self.executor = deepcopy(self.default)
+        if driver:
+            self.driver.update(driver)
+        if executor:
+            self.executor.update(executor)
+        self.convert_resources()
+
+    @property
+    def resources(self):
+        """Return job resources."""
+        return {"driver": self.driver_resources, "executor": 
self.executor_resources}
+
+    @property
+    def driver_resources(self):
+        """Return resources to use."""
+        driver = {}
+        if self.driver["cpu"].get("request"):
+            driver["cores"] = self.driver["cpu"]["request"]
+        if self.driver["cpu"].get("limit"):
+            driver["coreLimit"] = self.driver["cpu"]["limit"]
+        if self.driver["memory"].get("limit"):
+            driver["memory"] = self.driver["memory"]["limit"]
+        if self.driver["gpu"].get("name") and 
self.driver["gpu"].get("quantity"):
+            driver["gpu"] = {"name": self.driver["gpu"]["name"], "quantity": 
self.driver["gpu"]["quantity"]}
+        return driver
+
+    @property
+    def executor_resources(self):
+        """Return resources to use."""
+        executor = {}
+        if self.executor["cpu"].get("request"):
+            executor["cores"] = self.executor["cpu"]["request"]
+        if self.executor["cpu"].get("limit"):
+            executor["coreLimit"] = self.executor["cpu"]["limit"]
+        if self.executor["memory"].get("limit"):
+            executor["memory"] = self.executor["memory"]["limit"]
+        if self.executor["gpu"].get("name") and 
self.executor["gpu"].get("quantity"):
+            executor["gpu"] = {
+                "name": self.executor["gpu"]["name"],
+                "quantity": self.executor["gpu"]["quantity"],
+            }
+        return executor
+
+    def convert_resources(self):
+        if isinstance(self.driver["memory"].get("limit"), str):
+            if "G" in self.driver["memory"]["limit"] or "Gi" in 
self.driver["memory"]["limit"]:
+                self.driver["memory"]["limit"] = 
float(self.driver["memory"]["limit"].rstrip("Gi G")) * 1024
+            elif "m" in self.driver["memory"]["limit"]:
+                self.driver["memory"]["limit"] = 
float(self.driver["memory"]["limit"].rstrip("m"))
+            # Adjusting the memory value as operator adds 40% to the given 
value
+            self.driver["memory"]["limit"] = 
str(int(self.driver["memory"]["limit"] / 1.4)) + "m"
+
+        if isinstance(self.executor["memory"].get("limit"), str):
+            if "G" in self.executor["memory"]["limit"] or "Gi" in 
self.executor["memory"]["limit"]:
+                self.executor["memory"]["limit"] = (
+                    float(self.executor["memory"]["limit"].rstrip("Gi G")) * 
1024
+                )
+            elif "m" in self.executor["memory"]["limit"]:
+                self.executor["memory"]["limit"] = 
float(self.executor["memory"]["limit"].rstrip("m"))
+            # Adjusting the memory value as operator adds 40% to the given 
value
+            self.executor["memory"]["limit"] = 
str(int(self.executor["memory"]["limit"] / 1.4)) + "m"
+
+        if self.driver["cpu"].get("request"):
+            self.driver["cpu"]["request"] = 
int(float(self.driver["cpu"]["request"]))
+        if self.driver["cpu"].get("limit"):
+            self.driver["cpu"]["limit"] = str(self.driver["cpu"]["limit"])
+        if self.executor["cpu"].get("request"):
+            self.executor["cpu"]["request"] = 
int(float(self.executor["cpu"]["request"]))
+        if self.executor["cpu"].get("limit"):
+            self.executor["cpu"]["limit"] = str(self.executor["cpu"]["limit"])
+
+        if self.driver["gpu"].get("quantity"):
+            self.driver["gpu"]["quantity"] = 
int(float(self.driver["gpu"]["quantity"]))
+        if self.executor["gpu"].get("quantity"):
+            self.executor["gpu"]["quantity"] = 
int(float(self.executor["gpu"]["quantity"]))
+
+
+class CustomObjectStatus:
+    """Status of the PODs."""
+
+    SUBMITTED = "SUBMITTED"
+    RUNNING = "RUNNING"
+    FAILED = "FAILED"
+    SUCCEEDED = "SUCCEEDED"
+
+
+class CustomObjectLauncher(LoggingMixin):
+    """Launches PODS."""
+
+    def __init__(
+        self,
+        name: str | None,
+        namespace: str | None,
+        kube_client: CoreV1Api,
+        custom_obj_api: CustomObjectsApi,
+        template_body: str | None = None,
+    ):
+        """
+        Creates custom object launcher(sparkapplications crd).
+
+                :param kube_client: kubernetes client.
+        """
+        super().__init__()
+        self.name = name
+        self.namespace = namespace
+        self.template_body = template_body
+        self.body: dict = self.get_body()
+        self.kind = self.body["kind"]
+        self.plural = f"{self.kind.lower()}s"
+        if self.body.get("apiVersion"):
+            self.api_group, self.api_version = 
self.body["apiVersion"].split("/")
+        else:
+            self.api_group = self.body["apiGroup"]
+            self.api_version = self.body["version"]
+        self._client = kube_client
+        self.custom_obj_api = custom_obj_api
+        self.spark_obj_spec: dict = {}
+        self.pod_spec: k8s.V1Pod | None = None
+
+    @cached_property
+    def pod_manager(self) -> PodManager:
+        return PodManager(kube_client=self._client)
+
+    def get_body(self):
+        self.body: dict = SparkJobSpec(**self.template_body["spark"])
+        self.body.metadata = {"name": self.name, "namespace": self.namespace}
+        if self.template_body.get("kubernetes"):
+            k8s_spec: dict = KubernetesSpec(**self.template_body["kubernetes"])
+            self.body.spec["volumes"] = k8s_spec.volumes
+            if k8s_spec.image_pull_secrets:
+                self.body.spec["imagePullSecrets"] = 
k8s_spec.image_pull_secrets
+            for item in ["driver", "executor"]:
+                # Env List
+                self.body.spec[item]["env"] = k8s_spec.env_vars
+                self.body.spec[item]["envFrom"] = k8s_spec.env_from
+                # Volumes
+                self.body.spec[item]["volumeMounts"] = k8s_spec.volume_mounts
+                # Add affinity
+                self.body.spec[item]["affinity"] = k8s_spec.affinity
+                self.body.spec[item]["tolerations"] = k8s_spec.tolerations
+                # Labels
+                self.body.spec[item]["labels"] = self.body.spec["labels"]
+
+        return self.body.__dict__
+
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(3),
+        wait=tenacity.wait_random_exponential(),
+        reraise=True,
+        retry=tenacity.retry_if_exception(should_retry_start_spark_job),
+    )
+    def start_spark_job(self, image=None, code_path=None, startup_timeout: int 
= 600):
+        """
+        Launches the pod synchronously and waits for completion.
+
+        :param image: image name
+        :param code_path: path to the .py file for python and jar file for 
scala
+        :param startup_timeout: Timeout for startup of the pod (if pod is 
pending for too long, fails task)
+        :return:
+        """
+        try:
+            if image:
+                self.body["spec"]["image"] = image
+            if code_path:
+                self.body["spec"]["mainApplicationFile"] = code_path
+            self.log.debug("Spark Job Creation Request Submitted")
+            self.spark_obj_spec = 
self.custom_obj_api.create_namespaced_custom_object(
+                group=self.api_group,
+                version=self.api_version,
+                namespace=self.namespace,
+                plural=self.plural,
+                body=self.body,
+            )
+            self.log.debug("Spark Job Creation Response: %s", 
self.spark_obj_spec)
+
+            # Wait for the driver pod to come alive
+            self.pod_spec = k8s.V1Pod(
+                metadata=k8s.V1ObjectMeta(
+                    labels=self.spark_obj_spec["spec"]["driver"]["labels"],
+                    name=self.spark_obj_spec["metadata"]["name"] + "-driver",
+                    namespace=self.namespace,
+                )
+            )
+            curr_time = dt.now()
+            while self.spark_job_not_running(self.spark_obj_spec):
+                self.log.warning(
+                    "Spark job submitted but not yet started. job_id: %s",
+                    self.spark_obj_spec["metadata"]["name"],
+                )
+                self.check_pod_start_failure()
+                delta = dt.now() - curr_time
+                if delta.total_seconds() >= startup_timeout:
+                    pod_status = 
self.pod_manager.read_pod(self.pod_spec).status.container_statuses
+                    raise AirflowException(f"Job took too long to start. pod 
status: {pod_status}")
+                time.sleep(10)
+        except Exception as e:
+            self.log.exception("Exception when attempting to create spark job")
+            raise e
+
+        return self.pod_spec, self.spark_obj_spec
+
+    def spark_job_not_running(self, spark_obj_spec):
+        """Tests if spark_obj_spec has not started."""
+        spark_job_info = 
self.custom_obj_api.get_namespaced_custom_object_status(
+            group=self.api_group,
+            version=self.api_version,
+            namespace=self.namespace,
+            name=spark_obj_spec["metadata"]["name"],
+            plural=self.plural,
+        )
+        driver_state = spark_job_info.get("status", 
{}).get("applicationState", {}).get("state", "SUBMITTED")
+        if driver_state == CustomObjectStatus.FAILED:
+            err = spark_job_info.get("status", {}).get("applicationState", 
{}).get("errorMessage", "N/A")
+            try:
+                self.pod_manager.fetch_container_logs(
+                    pod=self.pod_spec, container_name="spark-kubernetes-driver"
+                )
+            except Exception:
+                pass
+            raise AirflowException(f"Spark Job Failed.\nSparkJob Error 
stack:\n{err}")
+        return driver_state == CustomObjectStatus.SUBMITTED
+
+    def check_pod_start_failure(self):
+        try:
+            waiting_status = (
+                
self.pod_manager.read_pod(self.pod_spec).status.container_statuses[0].state.waiting
+            )
+            waiting_reason = waiting_status.reason
+            waiting_message = waiting_status.message
+        except Exception:
+            return
+        if waiting_reason != "ContainerCreating":
+            raise AirflowException(f"Spark Job Failed.\nStatus: 
{waiting_reason}\nError: {waiting_message}")

Review Comment:
   Please do not use extra '\n' formatting in the exception message; it makes the message harder to parse.



##########
airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py:
##########
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Launches Custom object."""
+from __future__ import annotations
+
+import time
+from copy import deepcopy
+from datetime import datetime as dt
+from functools import cached_property
+
+import tenacity
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
+from kubernetes.client.rest import ApiException
+
+from airflow.exceptions import AirflowException
+from airflow.providers.cncf.kubernetes.resource_convert.configmap import (
+    convert_configmap,
+    convert_configmap_to_volume,
+)
+from airflow.providers.cncf.kubernetes.resource_convert.env_variable import 
convert_env_vars
+from airflow.providers.cncf.kubernetes.resource_convert.secret import (
+    convert_image_pull_secrets,
+    convert_secret,
+)
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+def should_retry_start_spark_job(exception: BaseException) -> bool:
+    """Check if an Exception indicates a transient error and warrants 
retrying."""
+    if isinstance(exception, ApiException):
+        return exception.status == 409
+    return False
+
+
+class SparkJobSpec:
+    """Spark job spec."""
+
+    def __init__(self, **entries):
+        self.__dict__.update(entries)
+        self.validate()
+        self.update_resources()
+
+    def validate(self):
+        if self.spec.get("dynamicAllocation", {}).get("enabled"):
+            if not all(
+                [
+                    self.spec["dynamicAllocation"]["initialExecutors"],
+                    self.spec["dynamicAllocation"]["minExecutors"],
+                    self.spec["dynamicAllocation"]["maxExecutors"],
+                ]
+            ):
+                raise AirflowException("Make sure initial/min/max value for 
dynamic allocation is passed")
+
+    def update_resources(self):
+        if self.spec["driver"].get("container_resources"):
+            spark_resources = SparkResources(
+                self.spec["driver"].pop("container_resources"),
+                self.spec["executor"].pop("container_resources"),
+            )
+            self.spec["driver"].update(spark_resources.resources["driver"])
+            self.spec["executor"].update(spark_resources.resources["executor"])
+
+
+class KubernetesSpec:
+    """Spark kubernetes spec."""
+
+    def __init__(self, **entries):
+        self.__dict__.update(entries)
+        self.set_attribute()
+
+    def set_attribute(self):
+        self.env_vars = convert_env_vars(self.env_vars) if self.env_vars else 
[]
+        self.image_pull_secrets = (
+            convert_image_pull_secrets(self.image_pull_secrets) if 
self.image_pull_secrets else []
+        )
+        if self.config_map_mounts:
+            vols, vols_mounts = 
convert_configmap_to_volume(self.config_map_mounts)
+            self.volumes.extend(vols)
+            self.volume_mounts.extend(vols_mounts)
+        if self.from_env_config_map:
+            self.env_from.extend([convert_configmap(c_name) for c_name in 
self.from_env_config_map])
+        if self.from_env_secret:
+            self.env_from.extend([convert_secret(c) for c in 
self.from_env_secret])
+
+
+class SparkResources:
+    """spark resources."""
+
+    def __init__(
+        self,
+        driver: dict = {None: None},
+        executor: dict = {None: None},
+    ):
+        self.default = {
+            "gpu": {"name": None, "quantity": 0},
+            "cpu": {"request": None, "limit": None},
+            "memory": {"request": None, "limit": None},
+        }
+        self.driver = deepcopy(self.default)
+        self.executor = deepcopy(self.default)
+        if driver:
+            self.driver.update(driver)
+        if executor:
+            self.executor.update(executor)
+        self.convert_resources()
+
+    @property
+    def resources(self):
+        """Return job resources."""
+        return {"driver": self.driver_resources, "executor": 
self.executor_resources}
+
+    @property
+    def driver_resources(self):
+        """Return resources to use."""
+        driver = {}
+        if self.driver["cpu"].get("request"):
+            driver["cores"] = self.driver["cpu"]["request"]
+        if self.driver["cpu"].get("limit"):
+            driver["coreLimit"] = self.driver["cpu"]["limit"]
+        if self.driver["memory"].get("limit"):
+            driver["memory"] = self.driver["memory"]["limit"]
+        if self.driver["gpu"].get("name") and 
self.driver["gpu"].get("quantity"):
+            driver["gpu"] = {"name": self.driver["gpu"]["name"], "quantity": 
self.driver["gpu"]["quantity"]}
+        return driver
+
+    @property
+    def executor_resources(self):
+        """Return resources to use."""
+        executor = {}
+        if self.executor["cpu"].get("request"):
+            executor["cores"] = self.executor["cpu"]["request"]
+        if self.executor["cpu"].get("limit"):
+            executor["coreLimit"] = self.executor["cpu"]["limit"]
+        if self.executor["memory"].get("limit"):
+            executor["memory"] = self.executor["memory"]["limit"]
+        if self.executor["gpu"].get("name") and 
self.executor["gpu"].get("quantity"):
+            executor["gpu"] = {
+                "name": self.executor["gpu"]["name"],
+                "quantity": self.executor["gpu"]["quantity"],
+            }
+        return executor
+
+    def convert_resources(self):
+        if isinstance(self.driver["memory"].get("limit"), str):
+            if "G" in self.driver["memory"]["limit"] or "Gi" in 
self.driver["memory"]["limit"]:
+                self.driver["memory"]["limit"] = 
float(self.driver["memory"]["limit"].rstrip("Gi G")) * 1024
+            elif "m" in self.driver["memory"]["limit"]:
+                self.driver["memory"]["limit"] = 
float(self.driver["memory"]["limit"].rstrip("m"))
+            # Adjusting the memory value as operator adds 40% to the given 
value
+            self.driver["memory"]["limit"] = 
str(int(self.driver["memory"]["limit"] / 1.4)) + "m"
+
+        if isinstance(self.executor["memory"].get("limit"), str):
+            if "G" in self.executor["memory"]["limit"] or "Gi" in 
self.executor["memory"]["limit"]:
+                self.executor["memory"]["limit"] = (
+                    float(self.executor["memory"]["limit"].rstrip("Gi G")) * 
1024
+                )
+            elif "m" in self.executor["memory"]["limit"]:
+                self.executor["memory"]["limit"] = 
float(self.executor["memory"]["limit"].rstrip("m"))
+            # Adjusting the memory value as operator adds 40% to the given 
value
+            self.executor["memory"]["limit"] = 
str(int(self.executor["memory"]["limit"] / 1.4)) + "m"
+
+        if self.driver["cpu"].get("request"):
+            self.driver["cpu"]["request"] = 
int(float(self.driver["cpu"]["request"]))
+        if self.driver["cpu"].get("limit"):
+            self.driver["cpu"]["limit"] = str(self.driver["cpu"]["limit"])
+        if self.executor["cpu"].get("request"):
+            self.executor["cpu"]["request"] = 
int(float(self.executor["cpu"]["request"]))
+        if self.executor["cpu"].get("limit"):
+            self.executor["cpu"]["limit"] = str(self.executor["cpu"]["limit"])
+
+        if self.driver["gpu"].get("quantity"):
+            self.driver["gpu"]["quantity"] = 
int(float(self.driver["gpu"]["quantity"]))
+        if self.executor["gpu"].get("quantity"):
+            self.executor["gpu"]["quantity"] = 
int(float(self.executor["gpu"]["quantity"]))
+
+
+class CustomObjectStatus:
+    """Status of the PODs."""
+
+    SUBMITTED = "SUBMITTED"
+    RUNNING = "RUNNING"
+    FAILED = "FAILED"
+    SUCCEEDED = "SUCCEEDED"
+
+
+class CustomObjectLauncher(LoggingMixin):
+    """Launches PODS."""
+
+    def __init__(
+        self,
+        name: str | None,
+        namespace: str | None,
+        kube_client: CoreV1Api,
+        custom_obj_api: CustomObjectsApi,
+        template_body: str | None = None,
+    ):
+        """
+        Creates custom object launcher(sparkapplications crd).
+
+                :param kube_client: kubernetes client.
+        """
+        super().__init__()
+        self.name = name
+        self.namespace = namespace
+        self.template_body = template_body
+        self.body: dict = self.get_body()
+        self.kind = self.body["kind"]
+        self.plural = f"{self.kind.lower()}s"
+        if self.body.get("apiVersion"):
+            self.api_group, self.api_version = 
self.body["apiVersion"].split("/")
+        else:
+            self.api_group = self.body["apiGroup"]
+            self.api_version = self.body["version"]
+        self._client = kube_client
+        self.custom_obj_api = custom_obj_api
+        self.spark_obj_spec: dict = {}
+        self.pod_spec: k8s.V1Pod | None = None
+
+    @cached_property
+    def pod_manager(self) -> PodManager:
+        return PodManager(kube_client=self._client)
+
+    def get_body(self):
+        self.body: dict = SparkJobSpec(**self.template_body["spark"])
+        self.body.metadata = {"name": self.name, "namespace": self.namespace}
+        if self.template_body.get("kubernetes"):
+            k8s_spec: dict = KubernetesSpec(**self.template_body["kubernetes"])
+            self.body.spec["volumes"] = k8s_spec.volumes
+            if k8s_spec.image_pull_secrets:
+                self.body.spec["imagePullSecrets"] = 
k8s_spec.image_pull_secrets
+            for item in ["driver", "executor"]:
+                # Env List
+                self.body.spec[item]["env"] = k8s_spec.env_vars
+                self.body.spec[item]["envFrom"] = k8s_spec.env_from
+                # Volumes
+                self.body.spec[item]["volumeMounts"] = k8s_spec.volume_mounts
+                # Add affinity
+                self.body.spec[item]["affinity"] = k8s_spec.affinity
+                self.body.spec[item]["tolerations"] = k8s_spec.tolerations
+                # Labels
+                self.body.spec[item]["labels"] = self.body.spec["labels"]
+
+        return self.body.__dict__
+
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(3),
+        wait=tenacity.wait_random_exponential(),
+        reraise=True,
+        retry=tenacity.retry_if_exception(should_retry_start_spark_job),
+    )
+    def start_spark_job(self, image=None, code_path=None, startup_timeout: int 
= 600):
+        """
+        Launches the pod synchronously and waits for completion.
+
+        :param image: image name
+        :param code_path: path to the .py file for python and jar file for 
scala
+        :param startup_timeout: Timeout for startup of the pod (if pod is 
pending for too long, fails task)
+        :return:
+        """
+        try:
+            if image:
+                self.body["spec"]["image"] = image
+            if code_path:
+                self.body["spec"]["mainApplicationFile"] = code_path
+            self.log.debug("Spark Job Creation Request Submitted")
+            self.spark_obj_spec = 
self.custom_obj_api.create_namespaced_custom_object(
+                group=self.api_group,
+                version=self.api_version,
+                namespace=self.namespace,
+                plural=self.plural,
+                body=self.body,
+            )
+            self.log.debug("Spark Job Creation Response: %s", 
self.spark_obj_spec)
+
+            # Wait for the driver pod to come alive
+            self.pod_spec = k8s.V1Pod(
+                metadata=k8s.V1ObjectMeta(
+                    labels=self.spark_obj_spec["spec"]["driver"]["labels"],
+                    name=self.spark_obj_spec["metadata"]["name"] + "-driver",
+                    namespace=self.namespace,
+                )
+            )
+            curr_time = dt.now()
+            while self.spark_job_not_running(self.spark_obj_spec):
+                self.log.warning(
+                    "Spark job submitted but not yet started. job_id: %s",
+                    self.spark_obj_spec["metadata"]["name"],
+                )
+                self.check_pod_start_failure()
+                delta = dt.now() - curr_time
+                if delta.total_seconds() >= startup_timeout:
+                    pod_status = 
self.pod_manager.read_pod(self.pod_spec).status.container_statuses
+                    raise AirflowException(f"Job took too long to start. pod 
status: {pod_status}")
+                time.sleep(10)
+        except Exception as e:
+            self.log.exception("Exception when attempting to create spark job")
+            raise e
+
+        return self.pod_spec, self.spark_obj_spec
+
+    def spark_job_not_running(self, spark_obj_spec):
+        """Tests if spark_obj_spec has not started."""
+        spark_job_info = 
self.custom_obj_api.get_namespaced_custom_object_status(
+            group=self.api_group,
+            version=self.api_version,
+            namespace=self.namespace,
+            name=spark_obj_spec["metadata"]["name"],
+            plural=self.plural,
+        )
+        driver_state = spark_job_info.get("status", 
{}).get("applicationState", {}).get("state", "SUBMITTED")
+        if driver_state == CustomObjectStatus.FAILED:
+            err = spark_job_info.get("status", {}).get("applicationState", 
{}).get("errorMessage", "N/A")
+            try:
+                self.pod_manager.fetch_container_logs(
+                    pod=self.pod_spec, container_name="spark-kubernetes-driver"
+                )
+            except Exception:
+                pass
+            raise AirflowException(f"Spark Job Failed.\nSparkJob Error 
stack:\n{err}")

Review Comment:
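   Please do not embed '\n' line breaks in the exception message; a multi-line message is harder to parse (e.g. when grepping logs). Keep the prefix on a single line, for example: `f"Spark Job Failed. SparkJob Error stack: {err}"`.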



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to