kaxil commented on a change in pull request #6090: [AIRFLOW-5470] Add Apache 
Livy REST operator
URL: https://github.com/apache/airflow/pull/6090#discussion_r371115807
 
 

 ##########
 File path: airflow/providers/apache/livy/hooks/livy_hook.py
 ##########
 @@ -0,0 +1,379 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Apache Livy hook.
+"""
+
+import json
+import re
+from enum import Enum
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.http_hook import HttpHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
class BatchState(Enum):
    """
    Batch session states.

    These mirror the ``state`` field of a batch session as reported by the
    Livy batches REST API.
    """
    NOT_STARTED = 'not_started'
    STARTING = 'starting'
    RUNNING = 'running'
    IDLE = 'idle'
    BUSY = 'busy'
    SHUTTING_DOWN = 'shutting_down'
    # The four states below are terminal (see LivyHook.TERMINAL_STATES):
    # once reached, the session will not transition further.
    ERROR = 'error'
    DEAD = 'dead'
    KILLED = 'killed'
    SUCCESS = 'success'
+
+
class LivyHook(HttpHook, LoggingMixin):
    """
    Hook for Apache Livy through the REST API.

    :param livy_conn_id: reference to a pre-defined Livy Connection.
    :type livy_conn_id: str

    .. seealso::
        For more details refer to the Apache Livy API reference:
        https://livy.apache.org/docs/latest/rest-api.html
    """

    # Batch states that cannot transition any further. Not referenced inside
    # this hook; presumably consumed by operators/sensors polling
    # get_batch_state() — TODO confirm against the callers.
    TERMINAL_STATES = {
        BatchState.SUCCESS,
        BatchState.DEAD,
        BatchState.KILLED,
        BatchState.ERROR,
    }

    # Default HTTP headers for every request; get_conn() copies this dict and
    # merges caller-supplied headers on top, so the class attribute itself is
    # never mutated.
    _def_headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json'
    }
+
+    def __init__(self, livy_conn_id='livy_default'):
+        super(LivyHook, self).__init__(http_conn_id=livy_conn_id)
+
+    def get_conn(self, headers=None):
+        """
+        Returns http session for use with requests
+
+        :param headers: additional headers to be passed through as a dictionary
+        :type headers: dict
+        :return: requests session
+        :rtype: requests.Session
+        """
+        tmp_headers = self._def_headers.copy()  # setting default headers
+        if headers:
+            tmp_headers.update(headers)
+        return super().get_conn(tmp_headers)
+
+    def run_method(self, method='GET', endpoint=None, data=None, headers=None, 
extra_options=None):
+        """
+        Wrapper for HttpHook, allows to change method on the same HttpHook
+
+        :param method: http method
+        :type method: str
+        :param endpoint: endpoint
+        :type endpoint: str
+        :param data: request payload
+        :type data: dict
+        :param headers: headers
+        :type headers: dict
+        :param extra_options: extra options
+        :type extra_options: dict
+        :return: http response
+        :rtype: requests.Response
+        """
+        if method not in ('GET', 'POST', 'PUT', 'DELETE', 'HEAD'):
+            raise AirflowException("Invalid http method '{}'".format(method))
+
+        back_method = self.method
+        self.method = method
+        try:
+            result = self.run(endpoint, data, headers, extra_options)
+        finally:
+            self.method = back_method
+        return result
+
+    def post_batch(self, *args, **kwargs):
+        """
+        Perform request to submit batch
+
+        :return: batch session id
+        :rtype: int
+        """
+        batch_submit_body = json.dumps(self.build_post_batch_body(*args, 
**kwargs))
+
+        if self.base_url is None:
+            # need to init self.base_url
+            self.get_conn()
+        self.log.info("Submitting job {} to {}".format(batch_submit_body, 
self.base_url))
+
+        response = self.run_method(
+            method='POST',
+            endpoint='/batches',
+            data=batch_submit_body
+        )
+        self.log.debug("Got response: {}".format(response.text))
+
+        if response.status_code != 201:
+            raise AirflowException("Could not submit batch. Status code: 
{}".format(response.status_code))
+
+        batch_id = self._parse_post_response(response.json())
+        if batch_id is None:
+            raise AirflowException("Unable to parse a batch session id")
+        self.log.info("Batch submitted with session id: {}".format(batch_id))
+
+        return batch_id
+
+    def get_batch(self, session_id):
+        """
+        Fetch info about the specified batch
+
+        :param session_id: identifier of the batch sessions
+        :type session_id: int
+        :return: response body
+        :rtype: dict
+        """
+        self._validate_session_id(session_id)
+
+        self.log.debug("Fetching info for batch session {}".format(session_id))
+        response = self.run_method(endpoint='/batches/{}'.format(session_id))
+
+        if response.status_code != 200:
+            self.log.warning("Got status code {} for session 
{}".format(response.status_code, session_id))
+            raise AirflowException("Unable to fetch batch with id: 
{}".format(session_id))
+
+        return response.json()
+
+    def get_batch_state(self, session_id):
+        """
+        Fetch the state of the specified batch
+
+        :param session_id: identifier of the batch sessions
+        :type session_id: Union[int, str]
+        :return: batch state
+        :rtype: str
+        """
+        self._validate_session_id(session_id)
+
+        self.log.debug("Fetching info for batch session {}".format(session_id))
+        response = 
self.run_method(endpoint='/batches/{}/state'.format(session_id))
+
+        if response.status_code != 200:
+            self.log.warning("Got status code {} for session 
{}".format(response.status_code, session_id))
+            raise AirflowException("Unable to fetch state for batch id: 
{}".format(session_id))
+
+        jresp = response.json()
+        if 'state' not in jresp:
+            raise AirflowException("Unable to get state for batch with id: 
{}".format(session_id))
+        return BatchState(jresp['state'])
+
+    def delete_batch(self, session_id):
+        """
+        Delete the specified batch
+
+        :param session_id: identifier of the batch sessions
+        :type session_id: int
+        :return: response body
+        :rtype: dict
+        """
+        self._validate_session_id(session_id)
+
+        self.log.info("Deleting batch session {}".format(session_id))
+        response = self.run_method(
+            method='DELETE',
+            endpoint='/batches/{}'.format(session_id)
+        )
+
+        if response.status_code != 200:
+            self.log.warning("Got status code {} for session 
{}".format(response.status_code, session_id))
+            raise AirflowException("Could not kill the batch with session id: 
{}".format(session_id))
+
+        return response.json()
+
    @staticmethod
    def _validate_session_id(session_id):
        """
        Validate that the session id represents an integer.

        :param session_id: session id
        :type session_id: Union[int, str]
        :raises AirflowException: when the value cannot be converted to int
        """
        # EAFP: anything int() accepts (int, numeric string, ...) passes.
        try:
            int(session_id)
        except (TypeError, ValueError):
            raise AirflowException("'session_id' must represent an integer")
+
+    @staticmethod
+    def _parse_post_response(response):
+        """
+        Parse batch response for batch id
+
+        :param response: response body
+        :type response: dict
+        :return: session id
+        :rtype: str
+        """
+        return response.get('id')
+
+    @staticmethod
+    def build_post_batch_body(
+        file,
+        args=None,
+        class_name=None,
+        jars=None,
+        py_files=None,
+        files=None,
+        archives=None,
+        name=None,
+        driver_memory=None,
+        driver_cores=None,
+        executor_memory=None,
+        executor_cores=None,
+        num_executors=None,
+        queue=None,
+        proxy_user=None,
+        conf=None
+    ):
+        """
+        Build the post batch request body.
+        For more information about the format refer to
+        .. seealso:: https://livy.apache.org/docs/latest/rest-api.html
+
+        :param file: Path of the file containing the application to execute 
(required).
+        :type file: str
+        :param proxy_user: User to impersonate when running the job.
+        :type proxy_user: str
+        :param class_name: Application Java/Spark main class string.
+        :type class_name: str
+        :param args: Command line arguments for the application s.
+        :type args: list
+        :param jars: jars to be used in this sessions.
+        :type jars: list
+        :param py_files: Python files to be used in this session.
+        :type py_files: list
+        :param files: files to be used in this session.
+        :type files: list
+        :param driver_memory: Amount of memory to use for the driver process  
string.
+        :type driver_memory: str
+        :param driver_cores: Number of cores to use for the driver process int.
+        :type driver_cores: str
+        :param executor_memory: Amount of memory to use per executor process  
string.
+        :type executor_memory: str
+        :param executor_cores: Number of cores to use for each executor  int.
+        :type executor_cores: str
+        :param num_executors: Number of executors to launch for this session  
int.
+        :type num_executors: str
+        :param archives: Archives to be used in this session.
+        :type archives: list
+        :param queue: The name of the YARN queue to which submitted string.
+        :type queue: str
+        :param name: The name of this session  string.
+        :type name: str
+        :param conf: Spark configuration properties.
+        :type conf: dict
+        :return: request body
+        :rtype: dict
+        """
+        # pylint: disable-msg=too-many-arguments
+
+        body = {'file': file}
+
+        if proxy_user:
+            body['proxyUser'] = proxy_user
+        if class_name:
+            body['className'] = class_name
+        if args and LivyHook._validate_list_of_stringables(args):
+            body['args'] = [str(val) for val in args]
+        if jars and LivyHook._validate_list_of_stringables(jars):
+            body['jars'] = jars
+        if py_files and LivyHook._validate_list_of_stringables(py_files):
+            body['pyFiles'] = py_files
+        if files and LivyHook._validate_list_of_stringables(files):
+            body['files'] = files
+        if driver_memory and 
LivyHook._validate_list_of_stringables(driver_memory):
+            body['driverMemory'] = driver_memory
+        if driver_cores:
+            body['driverCores'] = driver_cores
+        if executor_memory and LivyHook._validate_size_format(executor_memory):
+            body['executorMemory'] = executor_memory
+        if executor_cores:
+            body['executorCores'] = executor_cores
+        if num_executors:
+            body['numExecutors'] = num_executors
+        if archives and LivyHook._validate_size_format(archives):
+            body['archives'] = archives
+        if queue:
+            body['queue'] = queue
+        if name:
+            body['name'] = name
+        if conf and LivyHook._validate_extra_conf(conf):
+            body['conf'] = conf
+
+        return body
+
+    @staticmethod
+    def _validate_size_format(size):
+        """
+        Validate size format.
+
+        :param size: size value
+        :type size: str
+        :return: true if valid format
+        :rtype: bool
+        """
+        if size and not (isinstance(size, str) and re.match(r'^\d+[kmgt]b?$', 
size, re.IGNORECASE)):
+            raise AirflowException("Invalid java size format for 
string'{}'".format(size))
 
 Review comment:
   ```suggestion
               raise ValueError("Invalid java size format for 
string'{}'".format(size))
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to