Repository: incubator-airflow Updated Branches: refs/heads/master 5a6f18f1c -> 53ca50845
[AIRFLOW-1028] Databricks Operator for Airflow Add DatabricksSubmitRun Operator In this PR, we contribute a DatabricksSubmitRun operator and a Databricks hook. This operator enables easy integration of Airflow with Databricks. In addition to the operator, we have created a databricks_default connection, an example_dag using this DatabricksSubmitRunOperator, and matching documentation. Closes #2202 from andrewmchen/databricks-operator- squashed Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/53ca5084 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/53ca5084 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/53ca5084 Branch: refs/heads/master Commit: 53ca5084561fd5c13996609f2eda6baf717249b5 Parents: 5a6f18f Author: Andrew Chen <[email protected]> Authored: Thu Apr 6 08:30:01 2017 -0700 Committer: Arthur Wiedmer <[email protected]> Committed: Thu Apr 6 08:30:33 2017 -0700 ---------------------------------------------------------------------- .../example_dags/example_databricks_operator.py | 82 +++++++ airflow/contrib/hooks/databricks_hook.py | 202 +++++++++++++++++ .../contrib/operators/databricks_operator.py | 211 +++++++++++++++++ airflow/exceptions.py | 2 +- airflow/models.py | 1 + airflow/utils/db.py | 4 + docs/code.rst | 1 + docs/integration.rst | 13 ++ setup.py | 2 + tests/contrib/hooks/databricks_hook.py | 226 +++++++++++++++++++ tests/contrib/operators/databricks_operator.py | 185 +++++++++++++++ 11 files changed, 928 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/example_dags/example_databricks_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/example_dags/example_databricks_operator.py 
b/airflow/contrib/example_dags/example_databricks_operator.py new file mode 100644 index 0000000..abf6844 --- /dev/null +++ b/airflow/contrib/example_dags/example_databricks_operator.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import airflow + +from airflow import DAG +from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator + +# This is an example DAG which uses the DatabricksSubmitRunOperator. +# In this example, we create two tasks which execute sequentially. +# The first task is to run a notebook at the workspace path given by +# ``notebook_path``, and the second task is to run a JAR uploaded to DBFS. Both +# tasks use new clusters. +# +# Because we have set a downstream dependency on the notebook task, +# the spark jar task will NOT run until the notebook task completes +# successfully. +# +# A run is considered successful if it has a result_state of "SUCCESS".
+# For more information about the state of a run refer to +# https://docs.databricks.com/api/latest/jobs.html#runstate + +args = { + 'owner': 'airflow', + 'email': ['[email protected]'], + 'depends_on_past': False, + 'start_date': airflow.utils.dates.days_ago(2) +} + +dag = DAG( + dag_id='example_databricks_operator', default_args=args, + schedule_interval='@daily') + +new_cluster = { + 'spark_version': '2.1.0-db3-scala2.11', + 'node_type_id': 'r3.xlarge', + 'aws_attributes': { + 'availability': 'ON_DEMAND' + }, + 'num_workers': 8 +} + +notebook_task_params = { + 'new_cluster': new_cluster, + 'notebook_task': { + 'notebook_path': '/Users/[email protected]/PrepareData', + }, +} +# Example of using the JSON parameter to initialize the operator. +notebook_task = DatabricksSubmitRunOperator( + task_id='notebook_task', + dag=dag, + json=notebook_task_params) + +# Example of using the named parameters of DatabricksSubmitRunOperator +# to initialize the operator. +spark_jar_task = DatabricksSubmitRunOperator( + task_id='spark_jar_task', + dag=dag, + new_cluster=new_cluster, + spark_jar_task={ + 'main_class_name': 'com.example.ProcessData' + }, + libraries=[ + { + 'jar': 'dbfs:/lib/etl-0.1.jar' + } + ] +) + +notebook_task.set_downstream(spark_jar_task) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/hooks/databricks_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py new file mode 100644 index 0000000..0cd5d0f --- /dev/null +++ b/airflow/contrib/hooks/databricks_hook.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import requests + +from airflow import __version__ +from airflow.exceptions import AirflowException +from airflow.hooks.base_hook import BaseHook +from requests import exceptions as requests_exceptions + + +try: + from urllib import parse as urlparse +except ImportError: + import urlparse + + +SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit') +GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get') +CANCEL_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/cancel') +USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)} + + +class DatabricksHook(BaseHook): + """ + Interact with Databricks. + """ + def __init__( + self, + databricks_conn_id='databricks_default', + timeout_seconds=180, + retry_limit=3): + """ + :param databricks_conn_id: The name of the databricks connection to use. + :type databricks_conn_id: string + :param timeout_seconds: The amount of time in seconds the requests library + will wait before timing-out. + :type timeout_seconds: int + :param retry_limit: The number of times to retry the connection in case of + service outages. + :type retry_limit: int + """ + self.databricks_conn_id = databricks_conn_id + self.databricks_conn = self.get_connection(databricks_conn_id) + self.timeout_seconds = timeout_seconds + assert retry_limit >= 1, 'Retry limit must be greater than equal to 1' + self.retry_limit = retry_limit + + def _parse_host(self, host): + """ + The purpose of this function is to be robust to improper connections + settings provided by users, specifically in the host field. 
+ + + For example -- when users supply ``https://xx.cloud.databricks.com`` as the + host, we must strip out the protocol to get the host. + >>> h = DatabricksHook() + >>> assert h._parse_host('https://xx.cloud.databricks.com') == \ + 'xx.cloud.databricks.com' + + In the case where users supply the correct ``xx.cloud.databricks.com`` as the + host, this function is a no-op. + >>> assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com' + """ + urlparse_host = urlparse.urlparse(host).hostname + if urlparse_host: + # In this case, host = https://xx.cloud.databricks.com + return urlparse_host + else: + # In this case, host = xx.cloud.databricks.com + return host + + def _do_api_call(self, endpoint_info, json): + """ + Utility function to perform an API call with retries + :param endpoint_info: Tuple of method and endpoint + :type endpoint_info: (string, string) + :param json: Parameters for this API call. + :type json: dict + :return: If the api call returns a OK status code, + this function returns the response in JSON. Otherwise, + we throw an AirflowException. + :rtype: dict + """ + method, endpoint = endpoint_info + url = 'https://{host}/{endpoint}'.format( + host=self._parse_host(self.databricks_conn.host), + endpoint=endpoint) + auth = (self.databricks_conn.login, self.databricks_conn.password) + if method == 'GET': + request_func = requests.get + elif method == 'POST': + request_func = requests.post + else: + raise AirflowException('Unexpected HTTP Method: ' + method) + + for attempt_num in range(1, self.retry_limit+1): + try: + response = request_func( + url, + json=json, + auth=auth, + headers=USER_AGENT_HEADER, + timeout=self.timeout_seconds) + if response.status_code == requests.codes.ok: + return response.json() + else: + # In this case, the user probably made a mistake. + # Don't retry. 
+ raise AirflowException('Response: {0}, Status Code: {1}'.format( + response.content, response.status_code)) + except (requests_exceptions.ConnectionError, + requests_exceptions.Timeout) as e: + logging.error(('Attempt {0} API Request to Databricks failed ' + + 'with reason: {1}').format(attempt_num, e)) + raise AirflowException(('API requests to Databricks failed {} times. ' + + 'Giving up.').format(self.retry_limit)) + + def submit_run(self, json): + """ + Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint. + + :param json: The data used in the body of the request to the ``submit`` endpoint. + :type json: dict + :return: the run_id as a string + :rtype: string + """ + response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json) + return response['run_id'] + + def get_run_page_url(self, run_id): + json = {'run_id': run_id} + response = self._do_api_call(GET_RUN_ENDPOINT, json) + return response['run_page_url'] + + def get_run_state(self, run_id): + json = {'run_id': run_id} + response = self._do_api_call(GET_RUN_ENDPOINT, json) + state = response['state'] + life_cycle_state = state['life_cycle_state'] + # result_state may not be in the state if not terminal + result_state = state.get('result_state', None) + state_message = state['state_message'] + return RunState(life_cycle_state, result_state, state_message) + + def cancel_run(self, run_id): + json = {'run_id': run_id} + self._do_api_call(CANCEL_RUN_ENDPOINT, json) + + +RUN_LIFE_CYCLE_STATES = [ + 'PENDING', + 'RUNNING', + 'TERMINATING', + 'TERMINATED', + 'SKIPPED', + 'INTERNAL_ERROR' +] + + +class RunState: + """ + Utility class for the run state concept of Databricks runs. 
+ """ + def __init__(self, life_cycle_state, result_state, state_message): + self.life_cycle_state = life_cycle_state + self.result_state = result_state + self.state_message = state_message + + @property + def is_terminal(self): + if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES: + raise AirflowException(('Unexpected life cycle state: {}: If the state has ' + 'been introduced recently, please check the Databricks user ' + 'guide for troubleshooting information').format( + self.life_cycle_state)) + return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR') + + @property + def is_successful(self): + return self.result_state == 'SUCCESS' + + def __eq__(self, other): + return self.life_cycle_state == other.life_cycle_state and \ + self.result_state == other.result_state and \ + self.state_message == other.state_message + + def __repr__(self): + return str(self.__dict__) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/operators/databricks_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py new file mode 100644 index 0000000..46b1659 --- /dev/null +++ b/airflow/contrib/operators/databricks_operator.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import logging +import time + +from airflow.exceptions import AirflowException +from airflow.contrib.hooks.databricks_hook import DatabricksHook +from airflow.models import BaseOperator + +LINE_BREAK = ('-' * 80) + + +class DatabricksSubmitRunOperator(BaseOperator): + """ + Submits a Spark job run to Databricks using the + `api/2.0/jobs/runs/submit + <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_ + API endpoint. + + There are two ways to instantiate this operator. + + In the first way, you can take the JSON payload that you typically use + to call the ``api/2.0/jobs/runs/submit`` endpoint and pass it directly + to our ``DatabricksSubmitRunOperator`` through the ``json`` parameter. + For example :: + json = { + 'new_cluster': { + 'spark_version': '2.1.0-db3-scala2.11', + 'num_workers': 2 + }, + 'notebook_task': { + 'notebook_path': '/Users/[email protected]/PrepareData', + }, + } + notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json) + + Another way to accomplish the same thing is to use the named parameters + of the ``DatabricksSubmitRunOperator`` directly. Note that there is exactly + one named parameter for each top level parameter in the ``runs/submit`` + endpoint. In this method, your code would look like this: :: + new_cluster = { + 'spark_version': '2.1.0-db3-scala2.11', + 'num_workers': 2 + } + notebook_task = { + 'notebook_path': '/Users/[email protected]/PrepareData', + } + notebook_run = DatabricksSubmitRunOperator( + task_id='notebook_run', + new_cluster=new_cluster, + notebook_task=notebook_task) + + In the case where both the json parameter **AND** the named parameters + are provided, they will be merged together. If there are conflicts during the merge, + the named parameters will take precedence and override the top level ``json`` keys.
+ + Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are + - ``spark_jar_task`` + - ``notebook_task`` + - ``new_cluster`` + - ``existing_cluster_id`` + - ``libraries`` + - ``run_name`` + - ``timeout_seconds`` + + :param json: A JSON object containing API parameters which will be passed + directly to the ``api/2.0/jobs/runs/submit`` endpoint. The other named parameters + (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will + be merged with this json dictionary if they are provided. + If there are conflicts during the merge, the named parameters will + take precedence and override the top level json keys. + https://docs.databricks.com/api/latest/jobs.html#runs-submit + :type json: dict + :param spark_jar_task: The main class and parameters for the JAR task. Note that + the actual JAR is specified in the ``libraries``. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified. + https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask + :type spark_jar_task: dict + :param notebook_task: The notebook path and parameters for the notebook task. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified. + https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask + :type notebook_task: dict + :param new_cluster: Specs for a new cluster on which this task will be run. + *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified. + https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster + :type new_cluster: dict + :param existing_cluster_id: ID for existing cluster on which to run this task. + *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified. + :type existing_cluster_id: string + :param libraries: Libraries which this run will use. + https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary + :type libraries: list of dicts + :param run_name: The run name used for this task. 
+ By default this will be set to the Airflow ``task_id``. This ``task_id`` is a + required parameter of the superclass ``BaseOperator``. + :type run_name: string + :param timeout_seconds: The timeout for this run. By default a value of 0 is used + which means to have no timeout. + :type timeout_seconds: int32 + :param databricks_conn_id: The name of the Airflow connection to use. + By default and in the common case this will be ``databricks_default``. + :type databricks_conn_id: string + :param polling_period_seconds: Controls the rate at which we poll for the result of + this run. By default the operator will poll every 30 seconds. + :type polling_period_seconds: int + :param databricks_retry_limit: Number of times to retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :type databricks_retry_limit: int + """ + # Databricks brand color (blue) under white text + ui_color = '#1CB1C2' + ui_fgcolor = '#fff' + + def __init__( + self, + json=None, + spark_jar_task=None, + notebook_task=None, + new_cluster=None, + existing_cluster_id=None, + libraries=None, + run_name=None, + timeout_seconds=None, + databricks_conn_id='databricks_default', + polling_period_seconds=30, + databricks_retry_limit=3, + **kwargs): + """ + Creates a new ``DatabricksSubmitRunOperator``.
+ """ + super(DatabricksSubmitRunOperator, self).__init__(**kwargs) + self.json = json or {} + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + if spark_jar_task is not None: + self.json['spark_jar_task'] = spark_jar_task + if notebook_task is not None: + self.json['notebook_task'] = notebook_task + if new_cluster is not None: + self.json['new_cluster'] = new_cluster + if existing_cluster_id is not None: + self.json['existing_cluster_id'] = existing_cluster_id + if libraries is not None: + self.json['libraries'] = libraries + if run_name is not None: + self.json['run_name'] = run_name + if timeout_seconds is not None: + self.json['timeout_seconds'] = timeout_seconds + if 'run_name' not in self.json: + self.json['run_name'] = run_name or kwargs['task_id'] + + # This variable will be used in case our task gets killed. + self.run_id = None + + def _log_run_page_url(self, url): + logging.info('View run status, Spark UI, and logs at {}'.format(url)) + + def get_hook(self): + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit) + + def execute(self, context): + hook = self.get_hook() + self.run_id = hook.submit_run(self.json) + run_page_url = hook.get_run_page_url(self.run_id) + logging.info(LINE_BREAK) + logging.info('Run submitted with run_id: {}'.format(self.run_id)) + self._log_run_page_url(run_page_url) + logging.info(LINE_BREAK) + while True: + run_state = hook.get_run_state(self.run_id) + if run_state.is_terminal: + if run_state.is_successful: + logging.info('{} completed successfully.'.format( + self.task_id)) + self._log_run_page_url(run_page_url) + return + else: + error_message = '{t} failed with terminal state: {s}'.format( + t=self.task_id, + s=run_state) + raise AirflowException(error_message) + else: + logging.info('{t} in run state: {s}'.format(t=self.task_id, + s=run_state)) + 
self._log_run_page_url(run_page_url) + logging.info('Sleeping for {} seconds.'.format( + self.polling_period_seconds)) + time.sleep(self.polling_period_seconds) + + def on_kill(self): + hook = self.get_hook() + hook.cancel_run(self.run_id) + logging.info('Task: {t} with run_id: {r} was requested to be cancelled.'.format( + t=self.task_id, + r=self.run_id)) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/exceptions.py ---------------------------------------------------------------------- diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 2231208..90d3e22 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -22,7 +22,7 @@ class AirflowException(Exception): class AirflowConfigException(AirflowException): pass - + class AirflowSensorTimeout(AirflowException): pass http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index 95e2255..42b621d 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -543,6 +543,7 @@ class Connection(Base): ('jira', 'JIRA',), ('redis', 'Redis',), ('wasb', 'Azure Blob Storage'), + ('databricks', 'Databricks',), ] def __init__( http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/utils/db.py ---------------------------------------------------------------------- diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 7da9217..54254f6 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -249,6 +249,10 @@ def initdb(): ] } ''')) + merge_conn( + models.Connection( + conn_id='databricks_default', conn_type='databricks', + host='localhost')) # Known event types KET = models.KnownEventType http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/docs/code.rst ---------------------------------------------------------------------- diff --git a/docs/code.rst b/docs/code.rst index 
683e85f..c31061c 100644 --- a/docs/code.rst +++ b/docs/code.rst @@ -97,6 +97,7 @@ Community-contributed Operators .. autoclass:: airflow.contrib.operators.bigquery_operator.BigQueryOperator .. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator +.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator .. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator .. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator .. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/docs/integration.rst ---------------------------------------------------------------------- diff --git a/docs/integration.rst b/docs/integration.rst index 4a6b676..a6c9d7c 100644 --- a/docs/integration.rst +++ b/docs/integration.rst @@ -61,6 +61,19 @@ AWS: Amazon Webservices --- +.. _Databricks: + +Databricks +-------------------------- +`Databricks <https://databricks.com/>`_ has contributed an Airflow operator which enables +submitting runs to the Databricks platform. Internally the operator talks to the +``api/2.0/jobs/runs/submit`` `endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_. + +DatabricksSubmitRunOperator +'''''''''''''''''''''''''''' + +.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator + .. 
_GCP: GCP: Google Cloud Platform http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/setup.py ---------------------------------------------------------------------- diff --git a/setup.py b/setup.py index ea60dca..6691208 100644 --- a/setup.py +++ b/setup.py @@ -116,6 +116,7 @@ crypto = ['cryptography>=0.9.3'] dask = [ 'distributed>=1.15.2, <2' ] +databricks = ['requests>=2.5.1, <3'] datadog = ['datadog>=0.14.0'] doc = [ 'sphinx>=1.2.3', @@ -244,6 +245,7 @@ def do_setup(): 'cloudant': cloudant, 'crypto': crypto, 'dask': dask, + 'databricks': databricks, 'datadog': datadog, 'devel': devel_minreq, 'devel_hadoop': devel_hadoop, http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/tests/contrib/hooks/databricks_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/databricks_hook.py b/tests/contrib/hooks/databricks_hook.py new file mode 100644 index 0000000..6c789f9 --- /dev/null +++ b/tests/contrib/hooks/databricks_hook.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from airflow import __version__ +from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT +from airflow.exceptions import AirflowException +from airflow.models import Connection +from airflow.utils import db +from requests import exceptions as requests_exceptions + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +TASK_ID = 'databricks-operator' +DEFAULT_CONN_ID = 'databricks_default' +NOTEBOOK_TASK = { + 'notebook_path': '/test' +} +NEW_CLUSTER = { + 'spark_version': '2.0.x-scala2.10', + 'node_type_id': 'r3.xlarge', + 'num_workers': 1 +} +RUN_ID = 1 +HOST = 'xx.cloud.databricks.com' +HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com' +LOGIN = 'login' +PASSWORD = 'password' +USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)} +RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1' +LIFE_CYCLE_STATE = 'PENDING' +STATE_MESSAGE = 'Waiting for cluster' +GET_RUN_RESPONSE = { + 'run_page_url': RUN_PAGE_URL, + 'state': { + 'life_cycle_state': LIFE_CYCLE_STATE, + 'state_message': STATE_MESSAGE + } +} +RESULT_STATE = None + + +def submit_run_endpoint(host): + """ + Utility function to generate the submit run endpoint given the host. + """ + return 'https://{}/api/2.0/jobs/runs/submit'.format(host) + + +def get_run_endpoint(host): + """ + Utility function to generate the get run endpoint given the host. + """ + return 'https://{}/api/2.0/jobs/runs/get'.format(host) + +def cancel_run_endpoint(host): + """ + Utility function to generate the cancel run endpoint given the host. + """ + return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) + +class DatabricksHookTest(unittest.TestCase): + """ + Tests for DatabricksHook.
+ """ + @db.provide_session + def setUp(self, session=None): + conn = session.query(Connection) \ + .filter(Connection.conn_id == DEFAULT_CONN_ID) \ + .first() + conn.host = HOST + conn.login = LOGIN + conn.password = PASSWORD + session.commit() + + self.hook = DatabricksHook() + + def test_parse_host_with_proper_host(self): + host = self.hook._parse_host(HOST) + self.assertEquals(host, HOST) + + def test_parse_host_with_scheme(self): + host = self.hook._parse_host(HOST_WITH_SCHEME) + self.assertEquals(host, HOST) + + def test_init_bad_retry_limit(self): + with self.assertRaises(AssertionError): + DatabricksHook(retry_limit = 0) + + @mock.patch('airflow.contrib.hooks.databricks_hook.logging') + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_do_api_call_with_error_retry(self, mock_requests, mock_logging): + for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: + mock_requests.reset_mock() + mock_logging.reset_mock() + mock_requests.post.side_effect = exception() + + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_do_api_call_with_bad_status_code(self, mock_requests): + mock_requests.codes.ok = 200 + status_code_mock = mock.PropertyMock(return_value=500) + type(mock_requests.post.return_value).status_code = status_code_mock + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_submit_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {'run_id': '1'} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + json = { + 'notebook_task': NOTEBOOK_TASK, + 'new_cluster': 
NEW_CLUSTER + } + run_id = self.hook.submit_run(json) + + self.assertEquals(run_id, '1') + mock_requests.post.assert_called_once_with( + submit_run_endpoint(HOST), + json={ + 'notebook_task': NOTEBOOK_TASK, + 'new_cluster': NEW_CLUSTER, + }, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_get_run_page_url(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.get.return_value).status_code = status_code_mock + + run_page_url = self.hook.get_run_page_url(RUN_ID) + + self.assertEquals(run_page_url, RUN_PAGE_URL) + mock_requests.get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_get_run_state(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.get.return_value).status_code = status_code_mock + + run_state = self.hook.get_run_state(RUN_ID) + + self.assertEquals(run_state, RunState( + LIFE_CYCLE_STATE, + RESULT_STATE, + STATE_MESSAGE)) + mock_requests.get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_cancel_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + 
self.hook.cancel_run(RUN_ID) + + mock_requests.post.assert_called_once_with( + cancel_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + +class RunStateTest(unittest.TestCase): + def test_is_terminal_true(self): + terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] + for state in terminal_states: + run_state = RunState(state, '', '') + self.assertTrue(run_state.is_terminal) + + def test_is_terminal_false(self): + non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING'] + for state in non_terminal_states: + run_state = RunState(state, '', '') + self.assertFalse(run_state.is_terminal) + + def test_is_terminal_with_nonexistent_life_cycle_state(self): + run_state = RunState('blah', '', '') + with self.assertRaises(AirflowException): + run_state.is_terminal + + def test_is_successful(self): + run_state = RunState('TERMINATED', 'SUCCESS', '') + self.assertTrue(run_state.is_successful) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/tests/contrib/operators/databricks_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/databricks_operator.py b/tests/contrib/operators/databricks_operator.py new file mode 100644 index 0000000..aab47fa --- /dev/null +++ b/tests/contrib/operators/databricks_operator.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from airflow.contrib.hooks.databricks_hook import RunState
from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
from airflow.exceptions import AirflowException

try:
    from unittest import mock
except ImportError:
    try:
        import mock
    except ImportError:
        mock = None

if mock is None:
    # Several tests below are decorated with ``mock.patch``, and decorators are
    # evaluated when the class body executes. Without a mock library the module
    # would fail to import with "'NoneType' object has no attribute 'patch'";
    # skip the whole module with an explicit reason instead.
    raise unittest.SkipTest('mock package not present')

TASK_ID = 'databricks-operator'
DEFAULT_CONN_ID = 'databricks_default'
NOTEBOOK_TASK = {
    'notebook_path': '/test'
}
SPARK_JAR_TASK = {
    'main_class_name': 'com.databricks.Test'
}
NEW_CLUSTER = {
    'spark_version': '2.0.x-scala2.10',
    'node_type_id': 'development-node',
    'num_workers': 1
}
EXISTING_CLUSTER_ID = 'existing-cluster-id'
RUN_NAME = 'run-name'
# Run id the mocked hook hands back from ``submit_run``; the execution tests
# assert that the operator passes this same id to the follow-up hook calls.
RUN_ID = 1


class DatabricksSubmitRunOperatorTest(unittest.TestCase):
    """Unit tests for ``DatabricksSubmitRunOperator``.

    The initializer tests compare the operator's ``json`` attribute against the
    expected request payload. The execute/on_kill tests patch ``DatabricksHook``
    inside the operator module, so no real Databricks calls are made.
    """

    def test_init_with_named_parameters(self):
        """
        Test the initializer with the named parameters.
        """
        op = DatabricksSubmitRunOperator(
            task_id=TASK_ID,
            new_cluster=NEW_CLUSTER,
            notebook_task=NOTEBOOK_TASK)
        expected = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID
        }
        self.assertDictEqual(expected, op.json)

    def test_init_with_json(self):
        """
        Test the initializer with json data.
        """
        json = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
        expected = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID
        }
        self.assertDictEqual(expected, op.json)

    def test_init_with_specified_run_name(self):
        """
        Test the initializer with a specified run_name.
        """
        json = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': RUN_NAME
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
        expected = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': RUN_NAME
        }
        self.assertDictEqual(expected, op.json)

    def test_init_with_merging(self):
        """
        Test the initializer when json and other named parameters are both
        provided. The named parameters should override top level keys in the
        json dict.
        """
        override_new_cluster = {'workers': 999}
        json = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
        }
        op = DatabricksSubmitRunOperator(
            task_id=TASK_ID,
            json=json,
            new_cluster=override_new_cluster)
        expected = {
            'new_cluster': override_new_cluster,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID,
        }
        self.assertDictEqual(expected, op.json)

    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
    def test_exec_success(self, db_mock_class):
        """
        Test the execute function in case where the run is successful.
        """
        run = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
        db_mock = db_mock_class.return_value
        db_mock.submit_run.return_value = RUN_ID
        # The mocked state is already terminal, so this test expects exactly
        # one state query (asserted below).
        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')

        op.execute(None)

        expected = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID
        }
        db_mock_class.assert_called_once_with(
            DEFAULT_CONN_ID,
            retry_limit=op.databricks_retry_limit)
        db_mock.submit_run.assert_called_once_with(expected)
        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
        db_mock.get_run_state.assert_called_once_with(RUN_ID)
        self.assertEqual(RUN_ID, op.run_id)

    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
    def test_exec_failure(self, db_mock_class):
        """
        Test the execute function in case where the run failed.
        """
        run = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
        db_mock = db_mock_class.return_value
        db_mock.submit_run.return_value = RUN_ID
        # Terminal life-cycle state with a FAILED result: execute must raise.
        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')

        with self.assertRaises(AirflowException):
            op.execute(None)

        expected = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID,
        }
        db_mock_class.assert_called_once_with(
            DEFAULT_CONN_ID,
            retry_limit=op.databricks_retry_limit)
        db_mock.submit_run.assert_called_once_with(expected)
        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
        db_mock.get_run_state.assert_called_once_with(RUN_ID)
        self.assertEqual(RUN_ID, op.run_id)

    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
    def test_on_kill(self, db_mock_class):
        """
        Test that on_kill cancels the run whose id is stored in ``op.run_id``.
        """
        run = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
        db_mock = db_mock_class.return_value
        op.run_id = RUN_ID

        op.on_kill()

        db_mock.cancel_run.assert_called_once_with(RUN_ID)
