This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6150d28323 Add Deferrable Databricks operators (#19736)
6150d28323 is described below
commit 6150d283234b48f86362fd4da856e282dd91ebb4
Author: Eugene Karimov <[email protected]>
AuthorDate: Sun May 22 16:22:49 2022 +0200
Add Deferrable Databricks operators (#19736)
---
airflow/providers/databricks/hooks/databricks.py | 29 ++
.../providers/databricks/hooks/databricks_base.py | 237 +++++++++++++-
.../providers/databricks/operators/databricks.py | 106 +++++--
airflow/providers/databricks/triggers/__init__.py | 17 +
.../providers/databricks/triggers/databricks.py | 77 +++++
airflow/providers/databricks/utils/__init__.py | 16 +
airflow/providers/databricks/utils/databricks.py | 69 ++++
.../operators/run_now.rst | 7 +
.../operators/submit_run.rst | 7 +
setup.py | 2 +
.../providers/databricks/hooks/test_databricks.py | 352 ++++++++++++++++++++-
.../databricks/operators/test_databricks.py | 239 +++++++++++---
tests/providers/databricks/triggers/__init__.py | 17 +
.../databricks/triggers/test_databricks.py | 153 +++++++++
tests/providers/databricks/utils/__init__.py | 16 +
tests/providers/databricks/utils/databricks.py | 62 ++++
16 files changed, 1313 insertions(+), 93 deletions(-)
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 7911660412..400bbe8955 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -25,6 +25,7 @@ operators talk to the
or the ``api/2.1/jobs/runs/submit``
`endpoint
<https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit>`_.
"""
+import json
from typing import Any, Dict, List, Optional
from requests import exceptions as requests_exceptions
@@ -92,6 +93,13 @@ class RunState:
def __repr__(self) -> str:
return str(self.__dict__)
+ def to_json(self) -> str:
+ return json.dumps(self.__dict__)
+
+ @classmethod
+ def from_json(cls, data: str) -> 'RunState':
+ return RunState(**json.loads(data))
+
class DatabricksHook(BaseDatabricksHook):
"""
@@ -198,6 +206,16 @@ class DatabricksHook(BaseDatabricksHook):
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['run_page_url']
+ async def a_get_run_page_url(self, run_id: int) -> str:
+ """
+ Async version of `get_run_page_url()`.
+ :param run_id: id of the run
+ :return: URL of the run page
+ """
+ json = {'run_id': run_id}
+ response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
+ return response['run_page_url']
+
def get_job_id(self, run_id: int) -> int:
"""
Retrieves job_id from run_id.
@@ -229,6 +247,17 @@ class DatabricksHook(BaseDatabricksHook):
state = response['state']
return RunState(**state)
+ async def a_get_run_state(self, run_id: int) -> RunState:
+ """
+ Async version of `get_run_state()`.
+ :param run_id: id of the run
+ :return: state of the run
+ """
+ json = {'run_id': run_id}
+ response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
+ state = response['state']
+ return RunState(**state)
+
def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
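For context, a minimal sketch of the round-trip these new RunState JSON helpers enable (values are illustrative): the trigger serializes the terminal state into its event payload with to_json(), and the operator rebuilds it on the worker with from_json().

    from airflow.providers.databricks.hooks.databricks import RunState

    state = RunState('TERMINATED', 'SUCCESS', '')
    payload = state.to_json()   # JSON string carried inside the TriggerEvent
    assert RunState.from_json(payload) == state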
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 6e0f1b44d8..5b18dad930 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -28,11 +28,19 @@ import time
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse
+import aiohttp
import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
-from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import (
+ AsyncRetrying,
+ RetryError,
+ Retrying,
+ retry_if_exception,
+ stop_after_attempt,
+ wait_exponential,
+)
from airflow import __version__
from airflow.exceptions import AirflowException
@@ -135,6 +143,14 @@ class BaseDatabricksHook(BaseHook):
return host
+ async def __aenter__(self):
+ self._session = aiohttp.ClientSession()
+ return self
+
+ async def __aexit__(self, *err):
+ await self._session.close()
+ self._session = None
+
@staticmethod
def _parse_host(host: str) -> str:
"""
@@ -169,6 +185,13 @@ class BaseDatabricksHook(BaseHook):
"""
return Retrying(**self.retry_args)
+ def _a_get_retry_object(self) -> AsyncRetrying:
+ """
+ Instantiates an async retry object
+ :return: instance of AsyncRetrying class
+ """
+ return AsyncRetrying(**self.retry_args)
+
def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource. Supports managed
identity or service principal auth
@@ -234,6 +257,72 @@ class BaseDatabricksHook(BaseHook):
return token
+ async def _a_get_aad_token(self, resource: str) -> str:
+ """
+ Async version of `_get_aad_token()`.
+ :param resource: resource to issue token to
+ :return: AAD token, or raise an exception
+ """
+ aad_token = self.aad_tokens.get(resource)
+ if aad_token and self._is_aad_token_valid(aad_token):
+ return aad_token['token']
+
+        self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+ params = {
+ "api-version": "2018-02-01",
+ "resource": resource,
+ }
+ async with self._session.get(
+ url=AZURE_METADATA_SERVICE_TOKEN_URL,
+ params=params,
+ headers={**USER_AGENT_HEADER, "Metadata": "true"},
+ timeout=self.aad_timeout_seconds,
+ ) as resp:
+ resp.raise_for_status()
+ jsn = await resp.json()
+ else:
+                        tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
+ data = {
+ "grant_type": "client_credentials",
+ "client_id": self.databricks_conn.login,
+ "resource": resource,
+ "client_secret": self.databricks_conn.password,
+ }
+                        azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
+ "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
+ )
+ async with self._session.post(
+                        url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
+                        data=data,
+                        headers={
+                            **USER_AGENT_HEADER,
+                            'Content-Type': 'application/x-www-form-urlencoded',
+ },
+ timeout=self.aad_timeout_seconds,
+ ) as resp:
+ resp.raise_for_status()
+ jsn = await resp.json()
+ if (
+ 'access_token' not in jsn
+ or jsn.get('token_type') != 'Bearer'
+ or 'expires_on' not in jsn
+ ):
+                        raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
+
+                    token = jsn['access_token']
+                    self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
+                    break
+        except RetryError:
+            raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
+        except aiohttp.ClientResponseError as err:
+            raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')
+
+ return token
+
def _get_aad_headers(self) -> dict:
"""
Fills AAD headers if necessary (SPN is outside of the workspace)
@@ -248,6 +337,20 @@ class BaseDatabricksHook(BaseHook):
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
return headers
+ async def _a_get_aad_headers(self) -> dict:
+ """
+ Async version of `_get_aad_headers()`.
+ :return: dictionary with filled AAD headers
+ """
+ headers = {}
+ if 'azure_resource_id' in self.databricks_conn.extra_dejson:
+ mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
+            headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
+ 'azure_resource_id'
+ ]
+ headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
+ return headers
+
@staticmethod
def _is_aad_token_valid(aad_token: dict) -> bool:
"""
@@ -281,6 +384,23 @@ class BaseDatabricksHook(BaseHook):
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+ async def _a_check_azure_metadata_service(self):
+ """Async version of `_check_azure_metadata_service()`."""
+ try:
+ async with self._session.get(
+ url=AZURE_METADATA_SERVICE_INSTANCE_URL,
+ params={"api-version": "2021-02-01"},
+ headers={"Metadata": "true"},
+ timeout=2,
+ ) as resp:
+ jsn = await resp.json()
+ if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
+ raise AirflowException(
+                    f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
+ )
+ except (requests_exceptions.RequestException, ValueError) as e:
+ raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+
def _get_token(self, raise_error: bool = False) -> Optional[str]:
if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
@@ -304,6 +424,29 @@ class BaseDatabricksHook(BaseHook):
return None
+ async def _a_get_token(self, raise_error: bool = False) -> Optional[str]:
+ if 'token' in self.databricks_conn.extra_dejson:
+ self.log.info(
+                'Using token auth. For security reasons, please set token in Password field instead of extra'
+ )
+ return self.databricks_conn.extra_dejson["token"]
+ elif not self.databricks_conn.login and self.databricks_conn.password:
+ self.log.info('Using token auth.')
+ return self.databricks_conn.password
+ elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
+            if self.databricks_conn.login == "" or self.databricks_conn.password == "":
+                raise AirflowException("Azure SPN credentials aren't provided")
+            self.log.info('Using AAD Token for SPN.')
+            return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+ self.log.info('Using AAD Token for managed identity.')
+ await self._a_check_azure_metadata_service()
+ return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+ elif raise_error:
+ raise AirflowException('Token authentication isn\'t configured')
+
+ return None
+
def _log_request_error(self, attempt_num: int, error: str) -> None:
        self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)
@@ -374,6 +517,55 @@ class BaseDatabricksHook(BaseHook):
else:
raise e
+    async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None):
+        """
+        Async version of `_do_api_call()`.
+        :param endpoint_info: Tuple of method and endpoint
+        :param json: Parameters for this API call.
+        :return: If the api call returns a OK status code,
+            this function returns the response in JSON. Otherwise, throw an AirflowException.
+ """
+ method, endpoint = endpoint_info
+
+ url = f'https://{self.host}/{endpoint}'
+
+ aad_headers = await self._a_get_aad_headers()
+ headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
+
+ auth: aiohttp.BasicAuth
+ token = await self._a_get_token()
+ if token:
+ auth = BearerAuth(token)
+ else:
+ self.log.info('Using basic auth.')
+            auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
+
+ request_func: Any
+ if method == 'GET':
+ request_func = self._session.get
+ elif method == 'POST':
+ request_func = self._session.post
+ elif method == 'PATCH':
+ request_func = self._session.patch
+ else:
+ raise AirflowException('Unexpected HTTP Method: ' + method)
+ try:
+ async for attempt in self._a_get_retry_object():
+ with attempt:
+ async with request_func(
+ url,
+ json=json,
+ auth=auth,
+ headers={**headers, **USER_AGENT_HEADER},
+ timeout=self.timeout_seconds,
+ ) as response:
+ response.raise_for_status()
+ return await response.json()
+ except RetryError:
+            raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
+        except aiohttp.ClientResponseError as err:
+            raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')
+
@staticmethod
def _get_error_code(exception: BaseException) -> str:
if isinstance(exception, requests_exceptions.HTTPError):
@@ -387,19 +579,25 @@ class BaseDatabricksHook(BaseHook):
@staticmethod
def _retryable_error(exception: BaseException) -> bool:
- if not isinstance(exception, requests_exceptions.RequestException):
- return False
-        return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
- exception.response is not None
- and (
- exception.response.status_code >= 500
- or exception.response.status_code == 429
- or (
- exception.response.status_code == 400
-                    and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
+        if isinstance(exception, requests_exceptions.RequestException):
+            if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
+ exception.response is not None
+ and (
+ exception.response.status_code >= 500
+ or exception.response.status_code == 429
+ or (
+ exception.response.status_code == 400
+                        and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
+ )
)
- )
- )
+ ):
+ return True
+
+ if isinstance(exception, aiohttp.ClientResponseError):
+ if exception.status >= 500 or exception.status == 429:
+ return True
+
+ return False
class _TokenAuth(AuthBase):
@@ -414,3 +612,16 @@ class _TokenAuth(AuthBase):
def __call__(self, r: PreparedRequest) -> PreparedRequest:
r.headers['Authorization'] = 'Bearer ' + self.token
return r
+
+
+class BearerAuth(aiohttp.BasicAuth):
+    """aiohttp only ships BasicAuth, for Bearer auth we need a subclass of BasicAuth."""
+
+ def __new__(cls, token: str) -> 'BearerAuth':
+ return super().__new__(cls, token) # type: ignore
+
+ def __init__(self, token: str) -> None:
+ self.token = token
+
+ def encode(self) -> str:
+ return f'Bearer {self.token}'
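A minimal sketch of how this BearerAuth subclass is consumed (host, token, and run id below are placeholders): aiohttp accepts a BasicAuth instance for the auth= argument and calls encode() to build the Authorization header, so overriding encode() is all that is needed for Bearer tokens.

    import aiohttp

    from airflow.providers.databricks.hooks.databricks_base import BearerAuth

    async def get_run(host: str, token: str, run_id: int) -> dict:
        # aiohttp fills the Authorization header from auth.encode(),
        # which BearerAuth overrides to return 'Bearer <token>'.
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f'https://{host}/api/2.1/jobs/runs/get',
                params={'run_id': str(run_id)},
                auth=BearerAuth(token),
            ) as resp:
                resp.raise_for_status()
                return await resp.json()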
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 577e59a2c2..8af4474b13 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -19,51 +19,24 @@
"""This module contains Databricks operators."""
import time
+from logging import Logger
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
+from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
+from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.context import Context
+DEFER_METHOD_NAME = 'execute_complete'
XCOM_RUN_ID_KEY = 'run_id'
XCOM_RUN_PAGE_URL_KEY = 'run_page_url'
-def _deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]:
- """
- Coerces content or all values of content if it is a dict to a string. The
- function will throw if content contains non-string or non-numeric types.
-
-    The reason why we have this function is because the ``self.json`` field must be a
- dict with only string values. This is because ``render_template`` will fail
- for numerical values.
- """
- coerce = _deep_string_coerce
- if isinstance(content, str):
- return content
- elif isinstance(
- content,
- (
- int,
- float,
- ),
- ):
-        # Databricks can tolerate either numeric or string types in the API backend.
- return str(content)
- elif isinstance(content, (list, tuple)):
- return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)]
- elif isinstance(content, dict):
-        return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())}
- else:
- param_type = type(content)
-        msg = f'Type {param_type} used for parameter {json_path} is not a number or a string'
- raise AirflowException(msg)
-
-
def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
    """
    Handles the Airflow + Databricks lifecycle logic for a Databricks operator
@@ -103,6 +76,47 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
    log.info('View run status, Spark UI, and logs at %s', run_page_url)
+def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
+    """
+    Handles the Airflow + Databricks lifecycle logic for deferrable Databricks operators
+
+ :param operator: Databricks async operator being handled
+ :param context: Airflow context
+ """
+ if operator.do_xcom_push and context is not None:
+ context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
+ log.info(f'Run submitted with run_id: {operator.run_id}')
+
+ run_page_url = hook.get_run_page_url(operator.run_id)
+ if operator.do_xcom_push and context is not None:
+ context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
+ log.info(f'View run status, Spark UI, and logs at {run_page_url}')
+
+ if operator.wait_for_termination:
+ operator.defer(
+ trigger=DatabricksExecutionTrigger(
+ run_id=operator.run_id,
+ databricks_conn_id=operator.databricks_conn_id,
+ polling_period_seconds=operator.polling_period_seconds,
+ ),
+ method_name=DEFER_METHOD_NAME,
+ )
+
+
+def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None:
+ validate_trigger_event(event)
+ run_state = RunState.from_json(event['run_state'])
+ run_page_url = event['run_page_url']
+ log.info(f'View run status, Spark UI, and logs at {run_page_url}')
+
+ if run_state.is_successful:
+ log.info('Job run completed successfully.')
+ return
+ else:
+ error_message = f'Job run failed with terminal state: {run_state}'
+ raise AirflowException(error_message)
+
+
class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""
@@ -364,7 +378,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
if git_source is not None:
self.json['git_source'] = git_source
- self.json = _deep_string_coerce(self.json)
+ self.json = deep_string_coerce(self.json)
# This variable will be used in case our task gets killed.
self.run_id: Optional[int] = None
self.do_xcom_push = do_xcom_push
@@ -393,6 +407,18 @@ class DatabricksSubmitRunOperator(BaseOperator):
            self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id)
+class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
+ """Deferrable version of ``DatabricksSubmitRunOperator``"""
+
+ def execute(self, context):
+ hook = self._get_hook()
+ self.run_id = hook.submit_run(self.json)
+        _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
+
+ def execute_complete(self, context: Optional[dict], event: dict):
+ _handle_deferrable_databricks_operator_completion(event, self.log)
+
+
class DatabricksRunNowOperator(BaseOperator):
"""
Runs an existing Spark job run to Databricks using the
@@ -604,7 +630,7 @@ class DatabricksRunNowOperator(BaseOperator):
if idempotency_token is not None:
self.json['idempotency_token'] = idempotency_token
- self.json = _deep_string_coerce(self.json)
+ self.json = deep_string_coerce(self.json)
# This variable will be used in case our task gets killed.
self.run_id: Optional[int] = None
self.do_xcom_push = do_xcom_push
@@ -637,3 +663,15 @@ class DatabricksRunNowOperator(BaseOperator):
)
else:
            self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id)
+
+
+class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
+ """Deferrable version of ``DatabricksRunNowOperator``"""
+
+ def execute(self, context):
+ hook = self._get_hook()
+ self.run_id = hook.run_now(self.json)
+        _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
+
+ def execute_complete(self, context: Optional[dict], event: dict):
+ _handle_deferrable_databricks_operator_completion(event, self.log)
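A hedged usage sketch of the two new operators in a DAG (connection id, job id, cluster spec, and notebook path are placeholders): they take the same arguments as their non-deferrable counterparts but hand the wait over to the triggerer.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import (
        DatabricksRunNowDeferrableOperator,
        DatabricksSubmitRunDeferrableOperator,
    )

    with DAG('databricks_deferrable_example', start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        # Submits a one-off run, then defers; the triggerer polls Databricks
        # asynchronously until the run reaches a terminal state.
        submit = DatabricksSubmitRunDeferrableOperator(
            task_id='submit_run',
            databricks_conn_id='databricks_default',
            new_cluster={'spark_version': '10.4.x-scala2.12', 'node_type_id': 'i3.xlarge', 'num_workers': 1},
            notebook_task={'notebook_path': '/Shared/example'},
        )

        # Triggers an existing job by id and waits for it the same way.
        run_now = DatabricksRunNowDeferrableOperator(
            task_id='run_now',
            databricks_conn_id='databricks_default',
            job_id=42,
            polling_period_seconds=60,
        )

        submit >> run_now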
diff --git a/airflow/providers/databricks/triggers/__init__.py b/airflow/providers/databricks/triggers/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/airflow/providers/databricks/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
new file mode 100644
index 0000000000..5f50f5aff2
--- /dev/null
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import asyncio
+import logging
+from typing import Any, Dict, Tuple
+
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+
+try:
+ from airflow.triggers.base import BaseTrigger, TriggerEvent
+except ImportError:
+ logging.getLogger(__name__).warning(
+ 'Deferrable Operators only work starting Airflow 2.2',
+ exc_info=True,
+ )
+ BaseTrigger = object # type: ignore
+ TriggerEvent = None # type: ignore
+
+
+class DatabricksExecutionTrigger(BaseTrigger):
+ """
+    The trigger handles the logic of async communication with the Databricks API.
+
+    :param run_id: id of the run
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
+    :param polling_period_seconds: Controls the rate of the poll for the result of this run.
+        By default, the trigger will poll every 30 seconds.
+    """
+
+    def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None:
+ super().__init__()
+ self.run_id = run_id
+ self.databricks_conn_id = databricks_conn_id
+ self.polling_period_seconds = polling_period_seconds
+ self.hook = DatabricksHook(databricks_conn_id)
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ return (
+            'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger',
+ {
+ 'run_id': self.run_id,
+ 'databricks_conn_id': self.databricks_conn_id,
+ 'polling_period_seconds': self.polling_period_seconds,
+ },
+ )
+
+ async def run(self):
+ async with self.hook:
+ run_page_url = await self.hook.a_get_run_page_url(self.run_id)
+ while True:
+ run_state = await self.hook.a_get_run_state(self.run_id)
+ if run_state.is_terminal:
+ yield TriggerEvent(
+ {
+ 'run_id': self.run_id,
+ 'run_state': run_state.to_json(),
+ 'run_page_url': run_page_url,
+ }
+ )
+ break
+ else:
+ await asyncio.sleep(self.polling_period_seconds)
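For reference, a small sketch of the serialize() contract above (values are illustrative): the triggerer persists the classpath and keyword arguments returned by serialize() and rebuilds an equivalent trigger from them, which is why only JSON-serializable fields go into that dictionary.

    from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

    trigger = DatabricksExecutionTrigger(run_id=1, databricks_conn_id='databricks_default', polling_period_seconds=10)
    classpath, kwargs = trigger.serialize()
    # classpath points back at this class; the triggerer re-imports it and
    # re-instantiates the trigger with the saved kwargs.
    rebuilt = DatabricksExecutionTrigger(**kwargs)
    assert rebuilt.run_id == trigger.run_id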
diff --git a/airflow/providers/databricks/utils/__init__.py b/airflow/providers/databricks/utils/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/databricks/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py
new file mode 100644
index 0000000000..96935d8063
--- /dev/null
+++ b/airflow/providers/databricks/utils/databricks.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Union
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import RunState
+
+
+def deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]:
+ """
+ Coerces content or all values of content if it is a dict to a string. The
+ function will throw if content contains non-string or non-numeric types.
+    The reason why we have this function is because the ``self.json`` field must be a
+ dict with only string values. This is because ``render_template`` will fail
+ for numerical values.
+ """
+ coerce = deep_string_coerce
+ if isinstance(content, str):
+ return content
+ elif isinstance(
+ content,
+ (
+ int,
+ float,
+ ),
+ ):
+        # Databricks can tolerate either numeric or string types in the API backend.
+ return str(content)
+ elif isinstance(content, (list, tuple)):
+ return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)]
+ elif isinstance(content, dict):
+        return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())}
+ else:
+ param_type = type(content)
+        msg = f'Type {param_type} used for parameter {json_path} is not a number or a string'
+ raise AirflowException(msg)
+
+
+def validate_trigger_event(event: dict):
+ """
+ Validates correctness of the event
+    received from :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`
+ """
+ keys_to_check = ['run_id', 'run_page_url', 'run_state']
+ for key in keys_to_check:
+ if key not in event:
+            raise AirflowException(f'Could not find `{key}` in the event: {event}')
+
+ try:
+ RunState.from_json(event['run_state'])
+ except Exception:
+        raise AirflowException(f'Run state returned by the Trigger is incorrect: {event["run_state"]}')
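A quick illustration of deep_string_coerce (input values are arbitrary): numeric leaves are stringified at every depth while dict keys and container shapes are preserved, which is what keeps the templated ``json`` payload renderable.

    from airflow.providers.databricks.utils.databricks import deep_string_coerce

    payload = {
        'job_id': 42,
        'notebook_params': {'retries': 3, 'dry-run': 'true'},
        'jar_params': [1, 2.5, 'a'],
    }
    # Numbers become strings at every depth; keys and structure are untouched.
    assert deep_string_coerce(payload) == {
        'job_id': '42',
        'notebook_params': {'retries': '3', 'dry-run': 'true'},
        'jar_params': ['1', '2.5', 'a'],
    }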
diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst
index 8b2e6010ee..a4b00d9005 100644
--- a/docs/apache-airflow-providers-databricks/operators/run_now.rst
+++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst
@@ -45,3 +45,10 @@ All other parameters are optional and described in documentation for ``Databrick
* ``python_named_parameters``
* ``jar_params``
* ``spark_submit_params``
+
+DatabricksRunNowDeferrableOperator
+==================================
+
+Deferrable version of the :class:`~airflow.providers.databricks.operators.DatabricksRunNowOperator` operator.
+
+It allows you to utilize Airflow workers more effectively by using the `new functionality introduced in Airflow 2.2.0 <https://airflow.apache.org/docs/apache-airflow/2.2.0/concepts/deferring.html#triggering-deferral>`_.
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index da71194da7..81f9dfd32f 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -75,3 +75,10 @@ You can also use named parameters to initialize the operator and run the job.
:language: python
:start-after: [START howto_operator_databricks_named]
:end-before: [END howto_operator_databricks_named]
+
+DatabricksSubmitRunDeferrableOperator
+=====================================
+
+Deferrable version of the :class:`~airflow.providers.databricks.operators.DatabricksSubmitRunOperator` operator.
+
+It allows you to utilize Airflow workers more effectively by using the `new functionality introduced in Airflow 2.2.0 <https://airflow.apache.org/docs/apache-airflow/2.2.0/concepts/deferring.html#triggering-deferral>`_.
diff --git a/setup.py b/setup.py
index 742ba57aa5..e7d5951694 100644
--- a/setup.py
+++ b/setup.py
@@ -268,6 +268,7 @@ dask = [
databricks = [
'requests>=2.26.0, <3',
'databricks-sql-connector>=2.0.0, <3.0.0',
+ 'aiohttp>=3.6.3, <4',
]
datadog = [
'datadog>=0.14.0',
@@ -609,6 +610,7 @@ mypy_dependencies = [
# Dependencies needed for development only
devel_only = [
+ 'asynctest~=0.13',
'aws_xray_sdk',
'beautifulsoup4>=4.7.1',
'black',
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 5a93ed7b42..2997984f7b 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -19,10 +19,11 @@
import itertools
import json
+import sys
import time
import unittest
-from unittest import mock
+import aiohttp
import pytest
import tenacity
from requests import exceptions as requests_exceptions
@@ -31,7 +32,12 @@ from requests.auth import HTTPBasicAuth
from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.models import Connection
-from airflow.providers.databricks.hooks.databricks import SUBMIT_RUN_ENDPOINT, DatabricksHook, RunState
+from airflow.providers.databricks.hooks.databricks import (
+ GET_RUN_ENDPOINT,
+ SUBMIT_RUN_ENDPOINT,
+ DatabricksHook,
+ RunState,
+)
from airflow.providers.databricks.hooks.databricks_base import (
AZURE_DEFAULT_AD_ENDPOINT,
AZURE_MANAGEMENT_ENDPOINT,
@@ -39,9 +45,17 @@ from airflow.providers.databricks.hooks.databricks_base import (
AZURE_TOKEN_SERVICE_URL,
DEFAULT_DATABRICKS_SCOPE,
TOKEN_REFRESH_LEAD_TIME,
+ BearerAuth,
)
from airflow.utils.session import provide_session
+if sys.version_info < (3, 8):
+ from asynctest import mock
+ from asynctest.mock import CoroutineMock as AsyncMock
+else:
+ from unittest import mock
+ from unittest.mock import AsyncMock
+
TASK_ID = 'databricks-operator'
DEFAULT_CONN_ID = 'databricks_default'
NOTEBOOK_TASK = {'notebook_path': '/test'}
@@ -172,6 +186,7 @@ def list_jobs_endpoint(host):
def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
+ response.__aenter__.return_value.json = AsyncMock(return_value=content)
return response
@@ -785,6 +800,18 @@ class TestRunState(unittest.TestCase):
run_state = RunState('TERMINATED', 'SUCCESS', '')
assert run_state.is_successful
+ def test_to_json(self):
+ run_state = RunState('TERMINATED', 'SUCCESS', '')
+ expected = json.dumps(
+ {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS',
'state_message': ''}
+ )
+ assert expected == run_state.to_json()
+
+ def test_from_json(self):
+ state = {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS',
'state_message': ''}
+ expected = RunState('TERMINATED', 'SUCCESS', '')
+ assert expected == RunState.from_json(json.dumps(state))
+
def create_aad_token_for_resource(resource: str) -> dict:
return {
@@ -976,3 +1003,324 @@ class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
args = mock_requests.post.call_args
kwargs = args[1]
assert kwargs['auth'].token == TOKEN
+
+
+class TestDatabricksHookAsyncMethods:
+ """
+ Tests for async functionality of DatabricksHook.
+ """
+
+ @provide_session
+ def setup_method(self, method, session=None):
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.login = LOGIN
+ conn.password = PASSWORD
+ conn.extra = None
+ session.commit()
+
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ @pytest.mark.asyncio
+ async def test_init_async_session(self):
+ async with self.hook:
+ assert isinstance(self.hook._session, aiohttp.ClientSession)
+ assert self.hook._session is None
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+ async def test_do_api_call_retries_with_retryable_error(self, mock_get):
+        mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=500)
+ with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ async with self.hook:
+ with pytest.raises(AirflowException):
+ await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+ assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_get):
+        mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=400)
+ with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ async with self.hook:
+ with pytest.raises(AirflowException):
+ await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+ mock_errors.assert_not_called()
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+ async def test_do_api_call_succeeds_after_retrying(self, mock_get):
+ mock_get.side_effect = [
+ aiohttp.ClientResponseError(None, None, status=500),
+ create_valid_response_mock({'run_id': '1'}),
+ ]
+ with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ async with self.hook:
+ response = await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+ assert mock_errors.call_count == 1
+ assert response == {'run_id': '1'}
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_do_api_call_waits_between_retries(self, mock_get):
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+        mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=500)
+ with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ async with self.hook:
+ with pytest.raises(AirflowException):
+ await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+ assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.patch')
+ async def test_do_api_call_patch(self, mock_patch):
+ mock_patch.return_value.__aenter__.return_value.json = AsyncMock(
+ return_value={'cluster_name': 'new_name'}
+ )
+ data = {'cluster_name': 'new_name'}
+ async with self.hook:
+            patched_cluster_name = await self.hook._a_do_api_call(('PATCH', 'api/2.1/jobs/runs/submit'), data)
+
+ assert patched_cluster_name['cluster_name'] == 'new_name'
+ mock_patch.assert_called_once_with(
+ submit_run_endpoint(HOST),
+ json={'cluster_name': 'new_name'},
+ auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_get_run_page_url(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+ async with self.hook:
+ run_page_url = await self.hook.a_get_run_page_url(RUN_ID)
+
+ assert run_page_url == RUN_PAGE_URL
+ mock_get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_get_run_state(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+ async with self.hook:
+ run_state = await self.hook.a_get_run_state(RUN_ID)
+
+ assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE,
STATE_MESSAGE)
+ mock_get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+
+class TestDatabricksHookAsyncAadToken:
+ """
+ Tests for DatabricksHook using async methods when
+ auth is done with AAD token for SP as user inside workspace.
+ """
+
+ @provide_session
+ def setup_method(self, method, session=None):
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.login = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
+ conn.password = 'secret'
+ conn.extra = json.dumps(
+ {
+ 'host': HOST,
+ 'azure_tenant_id': '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d',
+ }
+ )
+ session.commit()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post')
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json = AsyncMock(
+            return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+ async with self.hook:
+ run_state = await self.hook.a_get_run_state(RUN_ID)
+
+ assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE,
STATE_MESSAGE)
+ mock_get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=BearerAuth(TOKEN),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+
+class TestDatabricksHookAsyncAadTokenOtherClouds:
+ """
+    Tests for DatabricksHook using async methods when auth is done with AAD token
+    for SP as user inside workspace and using non-global Azure cloud (China, GovCloud, Germany)
+ """
+
+ @provide_session
+ def setup_method(self, method, session=None):
+ self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d'
+ self.ad_endpoint = 'https://login.microsoftonline.de'
+ self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.login = self.client_id
+ conn.password = 'secret'
+ conn.extra = json.dumps(
+ {
+ 'host': HOST,
+ 'azure_tenant_id': self.tenant_id,
+ 'azure_ad_endpoint': self.ad_endpoint,
+ }
+ )
+ session.commit()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post')
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json = AsyncMock(
+            return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+ async with self.hook:
+ run_state = await self.hook.a_get_run_state(RUN_ID)
+
+ assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE,
STATE_MESSAGE)
+
+ ad_call_args = mock_post.call_args_list[0]
+ assert ad_call_args[1]['url'] ==
AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id)
+ assert ad_call_args[1]['data']['client_id'] == self.client_id
+ assert ad_call_args[1]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE
+
+ mock_get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=BearerAuth(TOKEN),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+
+class TestDatabricksHookAsyncAadTokenSpOutside:
+ """
+    Tests for DatabricksHook using async methods when auth is done with AAD token for SP outside of workspace.
+ """
+
+ @provide_session
+ def setup_method(self, method, session=None):
+ self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d'
+ self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.login = self.client_id
+ conn.password = 'secret'
+ conn.host = HOST
+ conn.extra = json.dumps(
+ {
+ 'azure_resource_id': '/Some/resource',
+ 'azure_tenant_id': self.tenant_id,
+ }
+ )
+ session.commit()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post')
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json.side_effect = AsyncMock(
+            side_effect=[
+                create_aad_token_for_resource(AZURE_MANAGEMENT_ENDPOINT),
+                create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE),
+            ]
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+ async with self.hook:
+ run_state = await self.hook.a_get_run_state(RUN_ID)
+
+ assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE,
STATE_MESSAGE)
+
+ ad_call_args = mock_post.call_args_list[0]
+ assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format(
+ AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id
+ )
+ assert ad_call_args[1]['data']['client_id'] == self.client_id
+ assert ad_call_args[1]['data']['resource'] == AZURE_MANAGEMENT_ENDPOINT
+
+ ad_call_args = mock_post.call_args_list[1]
+ assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format(
+ AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id
+ )
+ assert ad_call_args[1]['data']['client_id'] == self.client_id
+ assert ad_call_args[1]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE
+
+ mock_get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=BearerAuth(TOKEN),
+ headers={
+ **USER_AGENT_HEADER,
+ 'X-Databricks-Azure-Workspace-Resource-Id': '/Some/resource',
+ 'X-Databricks-Azure-SP-Management-Token': TOKEN,
+ },
+ timeout=self.hook.timeout_seconds,
+ )
+
+
+class TestDatabricksHookAsyncAadTokenManagedIdentity:
+ """
+ Tests for DatabricksHook using async methods when
+ auth is done with AAD leveraging Managed Identity authentication
+ """
+
+ @provide_session
+ def setup_method(self, method, session=None):
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.extra = json.dumps(
+ {
+ 'use_azure_managed_identity': True,
+ }
+ )
+        session.commit()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_get_run_state(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json.side_effect = AsyncMock(
+ side_effect=[
+ {'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}},
+ create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE),
+ GET_RUN_RESPONSE,
+ ]
+ )
+
+ async with self.hook:
+ run_state = await self.hook.a_get_run_state(RUN_ID)
+
+ assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE,
STATE_MESSAGE)
+
+ ad_call_args = mock_get.call_args_list[0]
+ assert ad_call_args[1]['url'] == AZURE_METADATA_SERVICE_INSTANCE_URL
+ assert ad_call_args[1]['params']['api-version'] > '2018-02-01'
+ assert ad_call_args[1]['headers']['Metadata'] == 'true'
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 99782234e0..895beb18e5 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -22,14 +22,17 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import DAG
from airflow.providers.databricks.hooks.databricks import RunState
-from airflow.providers.databricks.operators import databricks as databricks_operator
from airflow.providers.databricks.operators.databricks import (
+ DatabricksRunNowDeferrableOperator,
DatabricksRunNowOperator,
+ DatabricksSubmitRunDeferrableOperator,
DatabricksSubmitRunOperator,
)
+from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
+from airflow.providers.databricks.utils import databricks as utils
DATE = '2017-04-20'
TASK_ID = 'databricks-operator'
@@ -46,6 +49,7 @@ NEW_CLUSTER = {'spark_version': '2.0.x-scala2.10', 'node_type_id': 'development-
EXISTING_CLUSTER_ID = 'existing-cluster-id'
RUN_NAME = 'run-name'
RUN_ID = 1
+RUN_PAGE_URL = 'run-page-url'
JOB_ID = "42"
JOB_NAME = "job-name"
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
@@ -56,26 +60,6 @@ PYTHON_PARAMS = ["john doe", "35"]
SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"]
-class TestDatabricksOperatorSharedFunctions(unittest.TestCase):
- def test_deep_string_coerce(self):
- test_json = {
- 'test_int': 1,
- 'test_float': 1.0,
- 'test_dict': {'key': 'value'},
- 'test_list': [1, 1.0, 'a', 'b'],
- 'test_tuple': (1, 1.0, 'a', 'b'),
- }
-
- expected = {
- 'test_int': '1',
- 'test_float': '1.0',
- 'test_dict': {'key': 'value'},
- 'test_list': ['1', '1.0', 'a', 'b'],
- 'test_tuple': ['1', '1.0', 'a', 'b'],
- }
- assert databricks_operator._deep_string_coerce(test_json) == expected
-
-
class TestDatabricksSubmitRunOperator(unittest.TestCase):
def test_init_with_notebook_task_named_parameters(self):
"""
@@ -84,7 +68,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK
)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': TASK_ID}
)
@@ -97,7 +81,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
spark_python_task=SPARK_PYTHON_TASK
)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'spark_python_task':
SPARK_PYTHON_TASK, 'run_name': TASK_ID}
)
@@ -110,7 +94,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
spark_submit_task=SPARK_SUBMIT_TASK
)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'spark_submit_task':
SPARK_SUBMIT_TASK, 'run_name': TASK_ID}
)
@@ -122,7 +106,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
"""
json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': TASK_ID}
)
assert expected == op.json
@@ -130,7 +114,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
def test_init_with_tasks(self):
tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task":
NOTEBOOK_TASK}]
op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
- expected = databricks_operator._deep_string_coerce({'run_name':
TASK_ID, "tasks": tasks})
+ expected = utils.deep_string_coerce({'run_name': TASK_ID, "tasks":
tasks})
assert expected == op.json
def test_init_with_specified_run_name(self):
@@ -139,7 +123,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
"""
json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': RUN_NAME}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': RUN_NAME}
)
assert expected == op.json
@@ -151,7 +135,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
pipeline_task = {"pipeline_id": "test-dlt"}
json = {'new_cluster': NEW_CLUSTER, 'run_name': RUN_NAME,
"pipeline_task": pipeline_task}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, "pipeline_task": pipeline_task,
'run_name': RUN_NAME}
)
assert expected == op.json
@@ -168,7 +152,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
'notebook_task': NOTEBOOK_TASK,
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json,
new_cluster=override_new_cluster)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'new_cluster': override_new_cluster,
'notebook_task': NOTEBOOK_TASK,
@@ -185,7 +169,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
dag = DAG('test', start_date=datetime.now())
op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
op.render_template_fields(context={'ds': DATE})
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'new_cluster': NEW_CLUSTER,
'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
@@ -238,7 +222,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -270,7 +254,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
with pytest.raises(AirflowException):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'new_cluster': NEW_CLUSTER,
'notebook_task': NOTEBOOK_TASK,
@@ -317,7 +301,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -345,7 +329,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -360,13 +344,98 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
db_mock.get_run_state.assert_not_called()
+class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase):
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+ def test_execute_task_deferred(self, db_mock_class):
+ """
+ Test the execute function in case where the run is successful.
+ """
+ run = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS',
'')
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(None)
+ self.assertTrue(isinstance(exc.value.trigger,
DatabricksExecutionTrigger))
+ self.assertEqual(exc.value.method_name, 'execute_complete')
+
+ expected = utils.deep_string_coerce(
+ {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK,
'run_name': TASK_ID}
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ )
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ self.assertEqual(RUN_ID, op.run_id)
+
+ def test_execute_complete_success(self):
+ """
+ Test `execute_complete` function in case the Trigger has returned a
successful completion event.
+ """
+ run = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ event = {
+ 'run_id': RUN_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
+ }
+
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+ self.assertIsNone(op.execute_complete(context=None, event=event))
+
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_execute_complete_failure(self, db_mock_class):
+        """
+        Test `execute_complete` function in case the Trigger has returned a failure completion event.
+ """
+ run_state_failed = RunState('TERMINATED', 'FAILED', '')
+ run = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ event = {
+ 'run_id': RUN_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'run_state': run_state_failed.to_json(),
+ }
+
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+ with pytest.raises(AirflowException):
+ op.execute_complete(context=None, event=event)
+
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run_state.return_value = run_state_failed
+
+        with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
+ op.execute_complete(context=None, event=event)
+
+ def test_execute_complete_incorrect_event_validation_failure(self):
+ event = {}
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID)
+ with pytest.raises(AirflowException):
+ op.execute_complete(context=None, event=event)
+
+
class TestDatabricksRunNowOperator(unittest.TestCase):
def test_init_with_named_parameters(self):
"""
Test the initializer with the named parameters.
"""
op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
- expected = databricks_operator._deep_string_coerce({'job_id': 42})
+ expected = utils.deep_string_coerce({'job_id': 42})
assert expected == op.json
@@ -383,7 +452,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
}
op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'jar_params': JAR_PARAMS,
@@ -415,7 +484,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
spark_submit_params=SPARK_SUBMIT_PARAMS,
)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': override_notebook_params,
'jar_params': override_jar_params,
@@ -433,7 +502,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
dag = DAG('test', start_date=datetime.now())
op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID,
json=json)
op.render_template_fields(context={'ds': DATE})
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'jar_params': RENDERED_TEMPLATED_JAR_PARAMS,
@@ -465,7 +534,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'notebook_task': NOTEBOOK_TASK,
@@ -499,7 +568,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
with pytest.raises(AirflowException):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'notebook_task': NOTEBOOK_TASK,
@@ -540,7 +609,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'notebook_task': NOTEBOOK_TASK,
@@ -570,7 +639,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'notebook_task': NOTEBOOK_TASK,
@@ -614,7 +683,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
op.execute(None)
- expected = databricks_operator._deep_string_coerce(
+ expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'notebook_task': NOTEBOOK_TASK,
@@ -647,3 +716,85 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
op.execute(None)
db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
+
+
+class TestDatabricksRunNowDeferrableOperator(unittest.TestCase):
+ @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+ def test_execute_task_deferred(self, db_mock_class):
+ """
+ Test the execute function in the case where the task is deferred.
+ """
+ run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+ op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.run_now.return_value = 1
+ db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(None)
+ self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger))
+ self.assertEqual(exc.value.method_name, 'execute_complete')
+
+ expected = utils.deep_string_coerce(
+ {
+ 'notebook_params': NOTEBOOK_PARAMS,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'jar_params': JAR_PARAMS,
+ 'job_id': JOB_ID,
+ }
+ )
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ )
+
+ db_mock.run_now.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ self.assertEqual(RUN_ID, op.run_id)
+
+ def test_execute_complete_success(self):
+ """
+ Test the `execute_complete` function in the case where the Trigger has returned a successful completion event.
+ """
+ run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+ event = {
+ 'run_id': RUN_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
+ }
+
+ op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+ self.assertIsNone(op.execute_complete(context=None, event=event))
+
+ @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+ def test_execute_complete_failure(self, db_mock_class):
+ """
+ Test the `execute_complete` function in the case where the Trigger has returned a failure completion event.
+ """
+ run_state_failed = RunState('TERMINATED', 'FAILED', '')
+ run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+ event = {
+ 'run_id': RUN_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'run_state': run_state_failed.to_json(),
+ }
+
+ op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+ with pytest.raises(AirflowException):
+ op.execute_complete(context=None, event=event)
+
+ db_mock = db_mock_class.return_value
+ db_mock.run_now.return_value = 1
+ db_mock.get_run_state.return_value = run_state_failed
+
+ with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
+ op.execute_complete(context=None, event=event)
+
+ def test_execute_complete_incorrect_event_validation_failure(self):
+ event = {}
+ op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID)
+ with pytest.raises(AirflowException):
+ op.execute_complete(context=None, event=event)
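
Taken together, the execute_complete tests above pin down a simple contract: validate the event payload, then fail on any non-SUCCESS terminal state. A minimal sketch of that branching, inferred only from the assertions (the function name is illustrative; the committed logic lives in airflow/providers/databricks/operators/databricks.py):

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.utils.databricks import validate_trigger_event


def complete_from_trigger_event(event: dict) -> None:
    # Sketch of what execute_complete() is expected to do with a trigger event,
    # inferred from the tests above; not the committed implementation.
    validate_trigger_event(event)  # an empty or partial event raises AirflowException
    run_state = RunState.from_json(event['run_state'])
    if run_state.result_state != 'SUCCESS':
        raise AirflowException(f'Job run failed with terminal state: {run_state}')
    # On success there is nothing to return, matching assertIsNone(...) above.
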
diff --git a/tests/providers/databricks/triggers/__init__.py b/tests/providers/databricks/triggers/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/providers/databricks/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py
new file mode 100644
index 0000000000..cecbed1138
--- /dev/null
+++ b/tests/providers/databricks/triggers/test_databricks.py
@@ -0,0 +1,153 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import sys
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.databricks.hooks.databricks import RunState
+from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.session import provide_session
+
+if sys.version_info < (3, 8):
+ from asynctest import mock
+else:
+ from unittest import mock
+
+DEFAULT_CONN_ID = 'databricks_default'
+HOST = 'xx.cloud.databricks.com'
+LOGIN = 'login'
+PASSWORD = 'password'
+POLLING_INTERVAL_SECONDS = 30
+RETRY_DELAY = 10
+RETRY_LIMIT = 3
+RUN_ID = 1
+JOB_ID = 42
+RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
+
+RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+
+LIFE_CYCLE_STATE_PENDING = 'PENDING'
+LIFE_CYCLE_STATE_TERMINATED = 'TERMINATED'
+
+STATE_MESSAGE = 'Waiting for cluster'
+
+GET_RUN_RESPONSE_PENDING = {
+ 'job_id': JOB_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'state': {
+ 'life_cycle_state': LIFE_CYCLE_STATE_PENDING,
+ 'state_message': STATE_MESSAGE,
+ 'result_state': None,
+ },
+}
+GET_RUN_RESPONSE_TERMINATED = {
+ 'job_id': JOB_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'state': {
+ 'life_cycle_state': LIFE_CYCLE_STATE_TERMINATED,
+ 'state_message': None,
+ 'result_state': 'SUCCESS',
+ },
+}
+
+
+class TestDatabricksExecutionTrigger:
+ @provide_session
+ def setup_method(self, method, session=None):
+ conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.login = LOGIN
+ conn.password = PASSWORD
+ conn.extra = None
+ session.commit()
+
+ self.trigger = DatabricksExecutionTrigger(
+ run_id=RUN_ID,
+ databricks_conn_id=DEFAULT_CONN_ID,
+ polling_period_seconds=POLLING_INTERVAL_SECONDS,
+ )
+
+ def test_serialize(self):
+ assert self.trigger.serialize() == (
+ 'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger',
+ {
+ 'run_id': RUN_ID,
+ 'databricks_conn_id': DEFAULT_CONN_ID,
+ 'polling_period_seconds': POLLING_INTERVAL_SECONDS,
+ },
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url')
+ @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state')
+ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
+ mock_get_run_page_url.return_value = RUN_PAGE_URL
+ mock_get_run_state.return_value = RunState(
+ life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
+ state_message='',
+ result_state='SUCCESS',
+ )
+
+ trigger_event = self.trigger.run()
+ async for event in trigger_event:
+ assert event == TriggerEvent(
+ {
+ 'run_id': RUN_ID,
+ 'run_state': RunState(
+ life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message='', result_state='SUCCESS'
+ ).to_json(),
+ 'run_page_url': RUN_PAGE_URL,
+ }
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch('airflow.providers.databricks.triggers.databricks.asyncio.sleep')
+ @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url')
+ @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state')
+ async def test_sleep_between_retries(self, mock_get_run_state, mock_get_run_page_url, mock_sleep):
+ mock_get_run_page_url.return_value = RUN_PAGE_URL
+ mock_get_run_state.side_effect = [
+ RunState(
+ life_cycle_state=LIFE_CYCLE_STATE_PENDING,
+ state_message='',
+ result_state='',
+ ),
+ RunState(
+ life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
+ state_message='',
+ result_state='SUCCESS',
+ ),
+ ]
+
+ trigger_event = self.trigger.run()
+ async for event in trigger_event:
+ assert event == TriggerEvent(
+ {
+ 'run_id': RUN_ID,
+ 'run_state': RunState(
+ life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message='', result_state='SUCCESS'
+ ).to_json(),
+ 'run_page_url': RUN_PAGE_URL,
+ }
+ )
+ mock_sleep.assert_called_once()
+ mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
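
The two async tests above fix the shape of the trigger's run() loop: poll a_get_run_state() until the run reaches a terminal life-cycle state, sleeping polling_period_seconds between polls, then emit exactly one TriggerEvent. A minimal sketch under those assumptions (the class name is illustrative, and it assumes RunState exposes an is_terminal property; the committed trigger is in airflow/providers/databricks/triggers/databricks.py):

import asyncio

from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExecutionTriggerSketch(BaseTrigger):
    # Illustrative only; the committed class is DatabricksExecutionTrigger.
    def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30):
        super().__init__()
        self.run_id = run_id
        self.databricks_conn_id = databricks_conn_id
        self.polling_period_seconds = polling_period_seconds
        self.hook = DatabricksHook(databricks_conn_id)

    def serialize(self):
        # Matches what test_serialize above expects.
        return (
            'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger',
            {
                'run_id': self.run_id,
                'databricks_conn_id': self.databricks_conn_id,
                'polling_period_seconds': self.polling_period_seconds,
            },
        )

    async def run(self):
        while True:
            run_state = await self.hook.a_get_run_state(self.run_id)
            if run_state.is_terminal:
                run_page_url = await self.hook.a_get_run_page_url(self.run_id)
                # One event is emitted once the run has finished.
                yield TriggerEvent(
                    {
                        'run_id': self.run_id,
                        'run_state': run_state.to_json(),
                        'run_page_url': run_page_url,
                    }
                )
                return
            # Not terminal yet: wait before polling again (asserted via mock_sleep).
            await asyncio.sleep(self.polling_period_seconds)
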
diff --git a/tests/providers/databricks/utils/__init__.py b/tests/providers/databricks/utils/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/databricks/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/databricks/utils/databricks.py b/tests/providers/databricks/utils/databricks.py
new file mode 100644
index 0000000000..d450a19ceb
--- /dev/null
+++ b/tests/providers/databricks/utils/databricks.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import RunState
+from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event
+
+RUN_ID = 1
+RUN_PAGE_URL = 'run-page-url'
+
+
+class TestDatabricksOperatorSharedFunctions(unittest.TestCase):
+ def test_deep_string_coerce(self):
+ test_json = {
+ 'test_int': 1,
+ 'test_float': 1.0,
+ 'test_dict': {'key': 'value'},
+ 'test_list': [1, 1.0, 'a', 'b'],
+ 'test_tuple': (1, 1.0, 'a', 'b'),
+ }
+
+ expected = {
+ 'test_int': '1',
+ 'test_float': '1.0',
+ 'test_dict': {'key': 'value'},
+ 'test_list': ['1', '1.0', 'a', 'b'],
+ 'test_tuple': ['1', '1.0', 'a', 'b'],
+ }
+ assert deep_string_coerce(test_json) == expected
+
+ def test_validate_trigger_event_success(self):
+ event = {
+ 'run_id': RUN_ID,
+ 'run_page_url': RUN_PAGE_URL,
+ 'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
+ }
+ self.assertIsNone(validate_trigger_event(event))
+
+ def test_validate_trigger_event_failure(self):
+ event = {}
+ with pytest.raises(AirflowException):
+ validate_trigger_event(event)
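
These assertions describe the contract of the two helpers this commit moves into airflow/providers/databricks/utils/databricks.py: deep_string_coerce stringifies scalars and walks containers, and validate_trigger_event rejects events missing the expected keys. A minimal sketch consistent with the expectations above (the json_path parameter and error messages are illustrative, not the committed code):

from airflow.exceptions import AirflowException


def deep_string_coerce(content, json_path: str = 'json'):
    # Numbers become strings, tuples become lists, dicts and lists are walked recursively.
    if isinstance(content, str):
        return content
    if isinstance(content, (int, float)):
        return str(content)
    if isinstance(content, (list, tuple)):
        return [deep_string_coerce(item, f'{json_path}[{i}]') for i, item in enumerate(content)]
    if isinstance(content, dict):
        return {key: deep_string_coerce(value, f'{json_path}[{key}]') for key, value in content.items()}
    raise AirflowException(f'Type {type(content)} used for parameter {json_path} is not supported')


def validate_trigger_event(event: dict) -> None:
    # Every event emitted by DatabricksExecutionTrigger must carry these keys;
    # an empty event raises AirflowException.
    for key in ('run_id', 'run_page_url', 'run_state'):
        if key not in event:
            raise AirflowException(f'Could not find {key} in the event: {event}')
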