dstandish commented on code in PR #22253:
URL: https://github.com/apache/airflow/pull/22253#discussion_r997240324


##########
airflow/kubernetes/custom_object_launcher.py:
##########
@@ -0,0 +1,354 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Launches Custom object"""
+import sys
+import time
+from copy import deepcopy
+from datetime import datetime as dt
+from typing import Optional
+
+import tenacity
+import yaml
+from kubernetes import client, watch
+from kubernetes.client import models as k8s
+from kubernetes.client.rest import ApiException
+
+from airflow.exceptions import AirflowException
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if sys.version_info >= (3, 8):
+    from functools import cached_property
+else:
+    from cached_property import cached_property
+
+
def should_retry_start_spark_job(exception: BaseException) -> bool:
    """
    Decide whether a failed job submission should be retried.

    Only a Kubernetes ``ApiException`` with HTTP status 409 (conflict) is treated as
    retryable; every other exception is not.

    :param exception: the exception raised by the submission attempt
    :return: True if the submission should be retried
    """
    return isinstance(exception, ApiException) and exception.status == 409
+
+
class SparkResources:
    """
    Driver and executor resources for a Spark application, in spark-operator form.

    All values are passed as keyword arguments. Any argument that is omitted (or falsy)
    is left out of the generated resource dictionaries.

    :param driver_request_cpu: driver cores; converted to ``int`` (fractions are truncated)
    :param driver_limit_cpu: driver core limit; converted to ``str``
    :param driver_limit_memory: driver memory, e.g. ``'2G'``, ``'2Gi'``, ``'512m'`` or ``'512Mi'``
    :param driver_gpu_name: GPU resource name for the driver
    :param driver_gpu_quantity: number of GPUs for the driver; converted to ``int``
    :param executor_request_cpu: executor cores; converted to ``int`` (fractions are truncated)
    :param executor_limit_cpu: executor core limit; converted to ``str``
    :param executor_limit_memory: executor memory, same formats as ``driver_limit_memory``
    :param executor_gpu_name: GPU resource name for the executors
    :param executor_gpu_quantity: number of GPUs per executor; converted to ``int``
    """

    # The operator adds 40% to the given memory value, so the requested figure is
    # divided by this factor before being handed over.
    MEMORY_OVERHEAD_FACTOR = 1.4

    # Multipliers that convert a memory-string suffix to MiB. As before, binary and
    # decimal suffixes are treated alike ('G' == 'Gi' == 1024 MiB). Lowercase 'm'/'g'
    # follow the Spark/JVM memory-string convention (MiB/GiB).
    _MEMORY_UNITS_MIB = {
        'Ti': 1024.0 * 1024,
        'T': 1024.0 * 1024,
        't': 1024.0 * 1024,
        'Gi': 1024.0,
        'G': 1024.0,
        'g': 1024.0,
        'Mi': 1.0,
        'M': 1.0,
        'm': 1.0,
        'Ki': 1 / 1024,
        'K': 1 / 1024,
        'k': 1 / 1024,
    }

    def __init__(
        self,
        **kwargs,
    ):
        self.driver_request_cpu = kwargs.get('driver_request_cpu')
        self.driver_limit_cpu = kwargs.get('driver_limit_cpu')
        self.driver_limit_memory = kwargs.get('driver_limit_memory')
        self.executor_request_cpu = kwargs.get('executor_request_cpu')
        self.executor_limit_cpu = kwargs.get('executor_limit_cpu')
        self.executor_limit_memory = kwargs.get('executor_limit_memory')
        self.driver_gpu_name = kwargs.get('driver_gpu_name')
        self.driver_gpu_quantity = kwargs.get('driver_gpu_quantity')
        self.executor_gpu_name = kwargs.get('executor_gpu_name')
        self.executor_gpu_quantity = kwargs.get('executor_gpu_quantity')
        self.convert_resources()

    @property
    def resources(self):
        """Return job resources keyed by role (``'driver'`` and ``'executor'``)."""
        return {'driver': self.driver_resources, 'executor': self.executor_resources}

    @property
    def driver_resources(self):
        """Return the driver's resource dict (unset values are omitted)."""
        return self._build_resources(
            self.driver_request_cpu,
            self.driver_limit_cpu,
            self.driver_limit_memory,
            self.driver_gpu_name,
            self.driver_gpu_quantity,
        )

    @property
    def executor_resources(self):
        """Return the executor's resource dict (unset values are omitted)."""
        return self._build_resources(
            self.executor_request_cpu,
            self.executor_limit_cpu,
            self.executor_limit_memory,
            self.executor_gpu_name,
            self.executor_gpu_quantity,
        )

    @staticmethod
    def _build_resources(request_cpu, limit_cpu, limit_memory, gpu_name, gpu_quantity):
        """Assemble one role's resource dict, skipping falsy values."""
        resources = {}
        if request_cpu:
            resources['cores'] = request_cpu
        if limit_cpu:
            resources['coreLimit'] = limit_cpu
        if limit_memory:
            resources['memory'] = limit_memory
        # A GPU entry is only meaningful when both the name and the quantity are given.
        if gpu_name and gpu_quantity:
            resources['gpu'] = {'name': gpu_name, 'quantity': gpu_quantity}
        return resources

    @classmethod
    def _to_mebibytes(cls, memory):
        """
        Parse a memory string such as ``'2Gi'``, ``'2G'``, ``'1g'`` or ``'512m'`` into MiB.

        :raises ValueError: if the string is not a number followed by a known unit
        """
        text = memory.strip()
        # Try two-letter suffixes ('Gi') before the one-letter ones they end with ('G' is
        # not a suffix of '2Gi', but 'i'-less matching must not win for binary units).
        for suffix in sorted(cls._MEMORY_UNITS_MIB, key=len, reverse=True):
            if text.endswith(suffix):
                number = text[: -len(suffix)].strip()
                try:
                    return float(number) * cls._MEMORY_UNITS_MIB[suffix]
                except ValueError:
                    break
        raise ValueError(
            f"Unsupported memory value {memory!r}: expected a number followed by a unit "
            f"such as 'm', 'Mi', 'g', 'G' or 'Gi'"
        )

    @classmethod
    def _adjust_memory(cls, memory):
        """
        Convert a memory string into an ``'<N>m'`` string reduced for operator overhead.

        Non-string values (e.g. ``None`` or a number) and empty strings are returned unchanged.
        """
        if not isinstance(memory, str) or not memory:
            return memory
        return str(int(cls._to_mebibytes(memory) / cls.MEMORY_OVERHEAD_FACTOR)) + 'm'

    def convert_resources(self):
        """Normalise the raw keyword values into the types the operator spec uses."""
        self.driver_limit_memory = self._adjust_memory(self.driver_limit_memory)
        self.executor_limit_memory = self._adjust_memory(self.executor_limit_memory)

        # Cores are whole numbers; values such as '1.5' are truncated to 1.
        if self.driver_request_cpu:
            self.driver_request_cpu = int(float(self.driver_request_cpu))
        if self.driver_limit_cpu:
            self.driver_limit_cpu = str(self.driver_limit_cpu)
        if self.executor_request_cpu:
            self.executor_request_cpu = int(float(self.executor_request_cpu))
        if self.executor_limit_cpu:
            self.executor_limit_cpu = str(self.executor_limit_cpu)

        if self.driver_gpu_quantity:
            self.driver_gpu_quantity = int(float(self.driver_gpu_quantity))
        if self.executor_gpu_quantity:
            self.executor_gpu_quantity = int(float(self.executor_gpu_quantity))
+
+
class CustomObjectStatus:
    """Application states of the custom object, plus the skeleton spec used to build one."""

    # Values compared against the application state read from the object's status.
    SUBMITTED = 'SUBMITTED'
    RUNNING = 'RUNNING'
    FAILED = 'FAILED'
    SUCCEEDED = 'SUCCEEDED'

    # Minimal body for a new custom object: empty metadata, empty driver/executor
    # sections, and dynamic allocation switched off.
    INITIAL_SPEC = dict(
        metadata={},
        spec=dict(
            dynamicAllocation=dict(enabled=False),
            driver={},
            executor={},
        ),
    )
+
+
+class CustomObjectLauncher(LoggingMixin):
+    """Launches PODS"""
+
+    def __init__(
+        self,
+        kube_client: client.CoreV1Api,
+        custom_obj_api: client.CustomObjectsApi,
+        namespace: str = 'default',
+        api_group: str = 'sparkoperator.k8s.io',
+        api_version: str = 'v1beta2',
+        plural: str = 'sparkapplications',
+        kind: str = 'SparkApplication',
+        extract_xcom: bool = False,
+        application_file: Optional[str] = None,
+    ):
+        """
+        Creates the launcher.
+
+        :param kube_client: kubernetes client
+        :param extract_xcom: whether we should extract xcom
+        """
+        super().__init__()
+        self.namespace = namespace
+        self.api_group = api_group
+        self.api_version = api_version
+        self.plural = plural
+        self.kind = kind
+        self._client = kube_client
+        self.custom_obj_api = custom_obj_api
+        self._watch = watch.Watch()
+        self.extract_xcom = extract_xcom
+        self.spark_obj_spec: dict = {}
+        self.pod_spec: dict = {}
+        self.body: dict = {}
+        self.application_file = application_file
+
+    @cached_property
+    def pod_manager(self) -> PodManager:
+        return PodManager(kube_client=self._client)
+
+    @staticmethod
+    def _load_body(file):
+        # try:
+        #     base_body = yaml.safe_load(file)
+        # except Exception:
+        try:
+            with open(file) as data:
+                base_body = yaml.safe_load(data)
+        except yaml.YAMLError as e:
+            raise AirflowException(f"Exception when loading resource 
definition: {e}\n")
+        return base_body
+
+    def set_body(self, **kwargs):
+        if self.application_file:
+            self.body = self._load_body(self.application_file)
+        else:
+            self.body = self.get_body(
+                f'{self.api_group}/{self.api_version}', self.kind, 
CustomObjectStatus.INITIAL_SPEC, **kwargs
+            )
+
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(3),
+        wait=tenacity.wait_random_exponential(),
+        reraise=True,
+        retry=tenacity.retry_if_exception(should_retry_start_spark_job),
+    )
+    def start_spark_job(self, startup_timeout: int = 600):
+        """
+        Launches the pod synchronously and waits for completion.
+
+        :param startup_timeout: Timeout for startup of the pod (if pod is 
pending for too long, fails task)
+        :return:
+        """
+        try:
+            self.log.debug('Spark Job Creation Request Submitted')
+            self.spark_obj_spec = 
self.custom_obj_api.create_namespaced_custom_object(
+                group=self.api_group,
+                version=self.api_version,
+                namespace=self.namespace,
+                plural=self.plural,
+                body=self.body,
+            )
+            self.log.debug('Spark Job Creation Response: %s', 
self.spark_obj_spec)
+
+            # Wait for the driver pod to come alive
+            self.pod_spec = k8s.V1Pod(
+                metadata=k8s.V1ObjectMeta(
+                    labels=self.spark_obj_spec['spec']['driver']['labels'],
+                    name=self.spark_obj_spec['metadata']['name'] + '-driver',
+                    namespace=self.namespace,
+                )
+            )
+            curr_time = dt.now()
+            while self.spark_job_not_running(self.spark_obj_spec):
+                self.log.warning(
+                    "Spark job submitted but not yet started: %s", 
self.spark_obj_spec['metadata']['name']
+                )
+                delta = dt.now() - curr_time
+                if delta.total_seconds() >= startup_timeout:
+                    pod_status = 
self.pod_manager.read_pod(self.pod_spec).status.container_statuses
+                    raise AirflowException(f"Job took too long to start. pod 
status: {pod_status}")
+                time.sleep(2)
+        except Exception as e:
+            self.log.exception('Exception when attempting to create spark job: 
%s', self.body)
+            raise e
+
+        return self.pod_spec, self.spark_obj_spec
+
+    def spark_job_not_running(self, spark_obj_spec):
+        """Tests if spark_obj_spec has not started"""
+        spark_job_info = 
self.custom_obj_api.get_namespaced_custom_object_status(
+            group=self.api_group,
+            version=self.api_version,
+            namespace=self.namespace,
+            name=spark_obj_spec['metadata']['name'],
+            plural=self.plural,
+        )
+        driver_state = spark_job_info.get('status', 
{}).get('applicationState', {}).get('state', 'SUBMITTED')
+        if driver_state == CustomObjectStatus.FAILED:
+            err = spark_job_info.get('status', {}).get('applicationState', 
{}).get('errorMessage')
+            raise AirflowException(f"Spark Job Failed. Error stack:\n{err}")
+        return driver_state == CustomObjectStatus.SUBMITTED
+
+    def delete_spark_job(self, spark_job_name=None):
+        """Deletes spark job"""
+        spark_job_name = spark_job_name or self.spark_obj_spec.get('metadata', 
{}).get('name')
+        if not spark_job_name:
+            self.log.warning("Spark job not found: %s", spark_job_name)
+            return
+        try:
+            v1 = client.CustomObjectsApi()
+            v1.delete_namespaced_custom_object(
+                group=self.api_group,
+                version=self.api_version,
+                namespace=self.namespace,
+                plural=self.plural,
+                name=spark_job_name,
+            )
+        except ApiException as e:
+            # If the pod is already deleted
+            if e.status != 404:
+                raise
+
+    @staticmethod
+    def get_body(api_version, kind, initial_template, **kwargs):

Review Comment:
   e.g. "re the whole point" ... what about this...
   
   > Similar to the KubernetesOperator, we have added the logic to wait for a 
job after submission, manage error handling, retrieve logs from the driver pod, 
ability to delete a spark job. It also supports out-of-the-box Kubernetes 
functionalities such as handling volumes, config maps, secrets, etc.
   
   You know, it might be easier to review / less controversial to split this 
into 2 PRs...
   
   one: "add the retry handling etc etc a la KPO" 
   the other: "simplify the interface".  
   
   To me, it doesn't really seem like it _does_ simplify the interface; rather, it
seems more complicated. But certainly, if viewed separately from the other
changes, it would be easier to evaluate.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to