This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6150d28323 Add Deferrable Databricks operators (#19736)
6150d28323 is described below

commit 6150d283234b48f86362fd4da856e282dd91ebb4
Author: Eugene Karimov <[email protected]>
AuthorDate: Sun May 22 16:22:49 2022 +0200

    Add Deferrable Databricks operators (#19736)
---
 airflow/providers/databricks/hooks/databricks.py   |  29 ++
 .../providers/databricks/hooks/databricks_base.py  | 237 +++++++++++++-
 .../providers/databricks/operators/databricks.py   | 106 +++++--
 airflow/providers/databricks/triggers/__init__.py  |  17 +
 .../providers/databricks/triggers/databricks.py    |  77 +++++
 airflow/providers/databricks/utils/__init__.py     |  16 +
 airflow/providers/databricks/utils/databricks.py   |  69 ++++
 .../operators/run_now.rst                          |   7 +
 .../operators/submit_run.rst                       |   7 +
 setup.py                                           |   2 +
 .../providers/databricks/hooks/test_databricks.py  | 352 ++++++++++++++++++++-
 .../databricks/operators/test_databricks.py        | 239 +++++++++++---
 tests/providers/databricks/triggers/__init__.py    |  17 +
 .../databricks/triggers/test_databricks.py         | 153 +++++++++
 tests/providers/databricks/utils/__init__.py       |  16 +
 tests/providers/databricks/utils/databricks.py     |  62 ++++
 16 files changed, 1313 insertions(+), 93 deletions(-)
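For orientation, a minimal, hypothetical sketch of how the new deferrable operator might be wired into a DAG once this provider change is installed; the DAG id, schedule, cluster spec, and notebook path below are placeholders and are not part of this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunDeferrableOperator

    with DAG(
        dag_id='example_databricks_deferrable',   # placeholder DAG id
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,
    ) as dag:
        submit_run = DatabricksSubmitRunDeferrableOperator(
            task_id='submit_run',
            databricks_conn_id='databricks_default',
            new_cluster={'spark_version': '9.1.x-scala2.12', 'node_type_id': 'i3.xlarge', 'num_workers': 2},
            notebook_task={'notebook_path': '/Shared/example'},
            # While the run is in progress the task defers to DatabricksExecutionTrigger,
            # freeing the worker slot until the triggerer reports a terminal state.
            polling_period_seconds=30,
        )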

diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 7911660412..400bbe8955 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -25,6 +25,7 @@ operators talk to the
 or the ``api/2.1/jobs/runs/submit``
 `endpoint <https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit>`_.
 """
+import json
 from typing import Any, Dict, List, Optional
 
 from requests import exceptions as requests_exceptions
@@ -92,6 +93,13 @@ class RunState:
     def __repr__(self) -> str:
         return str(self.__dict__)
 
+    def to_json(self) -> str:
+        return json.dumps(self.__dict__)
+
+    @classmethod
+    def from_json(cls, data: str) -> 'RunState':
+        return RunState(**json.loads(data))
+
 
 class DatabricksHook(BaseDatabricksHook):
     """
@@ -198,6 +206,16 @@ class DatabricksHook(BaseDatabricksHook):
         response = self._do_api_call(GET_RUN_ENDPOINT, json)
         return response['run_page_url']
 
+    async def a_get_run_page_url(self, run_id: int) -> str:
+        """
+        Async version of `get_run_page_url()`.
+        :param run_id: id of the run
+        :return: URL of the run page
+        """
+        json = {'run_id': run_id}
+        response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
+        return response['run_page_url']
+
     def get_job_id(self, run_id: int) -> int:
         """
         Retrieves job_id from run_id.
@@ -229,6 +247,17 @@ class DatabricksHook(BaseDatabricksHook):
         state = response['state']
         return RunState(**state)
 
+    async def a_get_run_state(self, run_id: int) -> RunState:
+        """
+        Async version of `get_run_state()`.
+        :param run_id: id of the run
+        :return: state of the run
+        """
+        json = {'run_id': run_id}
+        response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
+        state = response['state']
+        return RunState(**state)
+
     def get_run_state_str(self, run_id: int) -> str:
         """
         Return the string representation of RunState.
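A quick sketch of the JSON round-trip these new methods enable: the trigger serializes a RunState into the event payload and the operator rebuilds it on the other side (illustrative only, not part of this commit):

    from airflow.providers.databricks.hooks.databricks import RunState

    # Serialize the state for transport inside a TriggerEvent payload, then rebuild it.
    state = RunState('TERMINATED', 'SUCCESS', '')
    payload = state.to_json()
    assert RunState.from_json(payload) == state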
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 6e0f1b44d8..5b18dad930 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -28,11 +28,19 @@ import time
 from typing import Any, Dict, Optional, Tuple
 from urllib.parse import urlparse
 
+import aiohttp
 import requests
 from requests import PreparedRequest, exceptions as requests_exceptions
 from requests.auth import AuthBase, HTTPBasicAuth
 from requests.exceptions import JSONDecodeError
-from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import (
+    AsyncRetrying,
+    RetryError,
+    Retrying,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from airflow import __version__
 from airflow.exceptions import AirflowException
@@ -135,6 +143,14 @@ class BaseDatabricksHook(BaseHook):
 
         return host
 
+    async def __aenter__(self):
+        self._session = aiohttp.ClientSession()
+        return self
+
+    async def __aexit__(self, *err):
+        await self._session.close()
+        self._session = None
+
     @staticmethod
     def _parse_host(host: str) -> str:
         """
@@ -169,6 +185,13 @@ class BaseDatabricksHook(BaseHook):
         """
         return Retrying(**self.retry_args)
 
+    def _a_get_retry_object(self) -> AsyncRetrying:
+        """
+        Instantiates an async retry object
+        :return: instance of AsyncRetrying class
+        """
+        return AsyncRetrying(**self.retry_args)
+
     def _get_aad_token(self, resource: str) -> str:
         """
         Function to get AAD token for given resource. Supports managed identity or service principal auth
@@ -234,6 +257,72 @@ class BaseDatabricksHook(BaseHook):
 
         return token
 
+    async def _a_get_aad_token(self, resource: str) -> str:
+        """
+        Async version of `_get_aad_token()`.
+        :param resource: resource to issue token to
+        :return: AAD token, or raise an exception
+        """
+        aad_token = self.aad_tokens.get(resource)
+        if aad_token and self._is_aad_token_valid(aad_token):
+            return aad_token['token']
+
+        self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+                        params = {
+                            "api-version": "2018-02-01",
+                            "resource": resource,
+                        }
+                        async with self._session.get(
+                            url=AZURE_METADATA_SERVICE_TOKEN_URL,
+                            params=params,
+                            headers={**USER_AGENT_HEADER, "Metadata": "true"},
+                            timeout=self.aad_timeout_seconds,
+                        ) as resp:
+                            resp.raise_for_status()
+                            jsn = await resp.json()
+                    else:
+                        tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
+                        data = {
+                            "grant_type": "client_credentials",
+                            "client_id": self.databricks_conn.login,
+                            "resource": resource,
+                            "client_secret": self.databricks_conn.password,
+                        }
+                        azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
+                            "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
+                        )
+                        async with self._session.post(
+                            url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
+                            data=data,
+                            headers={
+                                **USER_AGENT_HEADER,
+                                'Content-Type': 'application/x-www-form-urlencoded',
+                            },
+                            timeout=self.aad_timeout_seconds,
+                        ) as resp:
+                            resp.raise_for_status()
+                            jsn = await resp.json()
+                    if (
+                        'access_token' not in jsn
+                        or jsn.get('token_type') != 'Bearer'
+                        or 'expires_on' not in jsn
+                    ):
+                        raise AirflowException(f"Can't get necessary data from 
AAD token: {jsn}")
+
+                    token = jsn['access_token']
+                    self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
+                    break
+        except RetryError:
+            raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
+        except aiohttp.ClientResponseError as err:
+            raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')
+
+        return token
+
     def _get_aad_headers(self) -> dict:
         """
         Fills AAD headers if necessary (SPN is outside of the workspace)
@@ -248,6 +337,20 @@ class BaseDatabricksHook(BaseHook):
             headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
         return headers
 
+    async def _a_get_aad_headers(self) -> dict:
+        """
+        Async version of `_get_aad_headers()`.
+        :return: dictionary with filled AAD headers
+        """
+        headers = {}
+        if 'azure_resource_id' in self.databricks_conn.extra_dejson:
+            mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
+            headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
+                'azure_resource_id'
+            ]
+            headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
+        return headers
+
     @staticmethod
     def _is_aad_token_valid(aad_token: dict) -> bool:
         """
@@ -281,6 +384,23 @@ class BaseDatabricksHook(BaseHook):
         except (requests_exceptions.RequestException, ValueError) as e:
             raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
 
+    async def _a_check_azure_metadata_service(self):
+        """Async version of `_check_azure_metadata_service()`."""
+        try:
+            async with self._session.get(
+                url=AZURE_METADATA_SERVICE_INSTANCE_URL,
+                params={"api-version": "2021-02-01"},
+                headers={"Metadata": "true"},
+                timeout=2,
+            ) as resp:
+                jsn = await resp.json()
+            if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
+                raise AirflowException(
+                    f"Was able to fetch some metadata, but it doesn't look 
like Azure Metadata: {jsn}"
+                )
+        except (requests_exceptions.RequestException, ValueError) as e:
+            raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+
     def _get_token(self, raise_error: bool = False) -> Optional[str]:
         if 'token' in self.databricks_conn.extra_dejson:
             self.log.info(
@@ -304,6 +424,29 @@ class BaseDatabricksHook(BaseHook):
 
         return None
 
+    async def _a_get_token(self, raise_error: bool = False) -> Optional[str]:
+        if 'token' in self.databricks_conn.extra_dejson:
+            self.log.info(
+                'Using token auth. For security reasons, please set token in Password field instead of extra'
+            )
+            return self.databricks_conn.extra_dejson["token"]
+        elif not self.databricks_conn.login and self.databricks_conn.password:
+            self.log.info('Using token auth.')
+            return self.databricks_conn.password
+        elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
+            if self.databricks_conn.login == "" or self.databricks_conn.password == "":
+                raise AirflowException("Azure SPN credentials aren't provided")
+            self.log.info('Using AAD Token for SPN.')
+            return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+            self.log.info('Using AAD Token for managed identity.')
+            await self._a_check_azure_metadata_service()
+            return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif raise_error:
+            raise AirflowException('Token authentication isn\'t configured')
+
+        return None
+
     def _log_request_error(self, attempt_num: int, error: str) -> None:
         self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)
 
@@ -374,6 +517,55 @@ class BaseDatabricksHook(BaseHook):
             else:
                 raise e
 
+    async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None):
+        """
+        Async version of `_do_api_call()`.
+        :param endpoint_info: Tuple of method and endpoint
+        :param json: Parameters for this API call.
+        :return: If the API call returns an OK status code,
+            this function returns the response in JSON. Otherwise, it throws an AirflowException.
+        """
+        method, endpoint = endpoint_info
+
+        url = f'https://{self.host}/{endpoint}'
+
+        aad_headers = await self._a_get_aad_headers()
+        headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
+
+        auth: aiohttp.BasicAuth
+        token = await self._a_get_token()
+        if token:
+            auth = BearerAuth(token)
+        else:
+            self.log.info('Using basic auth.')
+            auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
+
+        request_func: Any
+        if method == 'GET':
+            request_func = self._session.get
+        elif method == 'POST':
+            request_func = self._session.post
+        elif method == 'PATCH':
+            request_func = self._session.patch
+        else:
+            raise AirflowException('Unexpected HTTP Method: ' + method)
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with request_func(
+                        url,
+                        json=json,
+                        auth=auth,
+                        headers={**headers, **USER_AGENT_HEADER},
+                        timeout=self.timeout_seconds,
+                    ) as response:
+                        response.raise_for_status()
+                        return await response.json()
+        except RetryError:
+            raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
+        except aiohttp.ClientResponseError as err:
+            raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')
+
     @staticmethod
     def _get_error_code(exception: BaseException) -> str:
         if isinstance(exception, requests_exceptions.HTTPError):
@@ -387,19 +579,25 @@ class BaseDatabricksHook(BaseHook):
 
     @staticmethod
     def _retryable_error(exception: BaseException) -> bool:
-        if not isinstance(exception, requests_exceptions.RequestException):
-            return False
-        return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
-            exception.response is not None
-            and (
-                exception.response.status_code >= 500
-                or exception.response.status_code == 429
-                or (
-                    exception.response.status_code == 400
-                    and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
+        if isinstance(exception, requests_exceptions.RequestException):
+            if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
+                exception.response is not None
+                and (
+                    exception.response.status_code >= 500
+                    or exception.response.status_code == 429
+                    or (
+                        exception.response.status_code == 400
+                        and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
+                    )
                 )
-            )
-        )
+            ):
+                return True
+
+        if isinstance(exception, aiohttp.ClientResponseError):
+            if exception.status >= 500 or exception.status == 429:
+                return True
+
+        return False
 
 
 class _TokenAuth(AuthBase):
@@ -414,3 +612,16 @@ class _TokenAuth(AuthBase):
     def __call__(self, r: PreparedRequest) -> PreparedRequest:
         r.headers['Authorization'] = 'Bearer ' + self.token
         return r
+
+
+class BearerAuth(aiohttp.BasicAuth):
+    """aiohttp only ships BasicAuth, for Bearer auth we need a subclass of 
BasicAuth."""
+
+    def __new__(cls, token: str) -> 'BearerAuth':
+        return super().__new__(cls, token)  # type: ignore
+
+    def __init__(self, token: str) -> None:
+        self.token = token
+
+    def encode(self) -> str:
+        return f'Bearer {self.token}'
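A minimal sketch of how the BearerAuth helper above can plug into an aiohttp request; the URL and token below are placeholders and this is not part of the commit:

    import asyncio

    import aiohttp

    from airflow.providers.databricks.hooks.databricks_base import BearerAuth


    async def fetch_run(url: str, token: str) -> dict:
        # aiohttp accepts BearerAuth wherever a BasicAuth instance is expected, but
        # encode() yields an 'Authorization: Bearer <token>' header instead of Basic credentials.
        async with aiohttp.ClientSession() as session:
            async with session.get(url, auth=BearerAuth(token)) as resp:
                resp.raise_for_status()
                return await resp.json()


    # Placeholder host, run id, and token:
    # asyncio.run(fetch_run('https://example.cloud.databricks.com/api/2.1/jobs/runs/get?run_id=1', 'token'))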
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 577e59a2c2..8af4474b13 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -19,51 +19,24 @@
 """This module contains Databricks operators."""
 
 import time
+from logging import Logger
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
+from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
+from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstanceKey
     from airflow.utils.context import Context
 
+DEFER_METHOD_NAME = 'execute_complete'
 XCOM_RUN_ID_KEY = 'run_id'
 XCOM_RUN_PAGE_URL_KEY = 'run_page_url'
 
 
-def _deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]:
-    """
-    Coerces content or all values of content if it is a dict to a string. The
-    function will throw if content contains non-string or non-numeric types.
-
-    The reason why we have this function is because the ``self.json`` field must be a
-    dict with only string values. This is because ``render_template`` will fail
-    for numerical values.
-    """
-    coerce = _deep_string_coerce
-    if isinstance(content, str):
-        return content
-    elif isinstance(
-        content,
-        (
-            int,
-            float,
-        ),
-    ):
-        # Databricks can tolerate either numeric or string types in the API backend.
-        return str(content)
-    elif isinstance(content, (list, tuple)):
-        return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)]
-    elif isinstance(content, dict):
-        return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())}
-    else:
-        param_type = type(content)
-        msg = f'Type {param_type} used for parameter {json_path} is not a number or a string'
-        raise AirflowException(msg)
-
-
 def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
     """
     Handles the Airflow + Databricks lifecycle logic for a Databricks operator
@@ -103,6 +76,47 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
         log.info('View run status, Spark UI, and logs at %s', run_page_url)
 
 
+def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
+    """
+    Handles the Airflow + Databricks lifecycle logic for deferrable Databricks operators
+
+    :param operator: Databricks async operator being handled
+    :param context: Airflow context
+    """
+    if operator.do_xcom_push and context is not None:
+        context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
+    log.info(f'Run submitted with run_id: {operator.run_id}')
+
+    run_page_url = hook.get_run_page_url(operator.run_id)
+    if operator.do_xcom_push and context is not None:
+        context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
+    log.info(f'View run status, Spark UI, and logs at {run_page_url}')
+
+    if operator.wait_for_termination:
+        operator.defer(
+            trigger=DatabricksExecutionTrigger(
+                run_id=operator.run_id,
+                databricks_conn_id=operator.databricks_conn_id,
+                polling_period_seconds=operator.polling_period_seconds,
+            ),
+            method_name=DEFER_METHOD_NAME,
+        )
+
+
+def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None:
+    validate_trigger_event(event)
+    run_state = RunState.from_json(event['run_state'])
+    run_page_url = event['run_page_url']
+    log.info(f'View run status, Spark UI, and logs at {run_page_url}')
+
+    if run_state.is_successful:
+        log.info('Job run completed successfully.')
+        return
+    else:
+        error_message = f'Job run failed with terminal state: {run_state}'
+        raise AirflowException(error_message)
+
+
 class DatabricksJobRunLink(BaseOperatorLink):
     """Constructs a link to monitor a Databricks Job Run."""
 
@@ -364,7 +378,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
         if git_source is not None:
             self.json['git_source'] = git_source
 
-        self.json = _deep_string_coerce(self.json)
+        self.json = deep_string_coerce(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
@@ -393,6 +407,18 @@ class DatabricksSubmitRunOperator(BaseOperator):
             self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id)
 
 
+class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
+    """Deferrable version of ``DatabricksSubmitRunOperator``"""
+
+    def execute(self, context):
+        hook = self._get_hook()
+        self.run_id = hook.submit_run(self.json)
+        _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
+
+    def execute_complete(self, context: Optional[dict], event: dict):
+        _handle_deferrable_databricks_operator_completion(event, self.log)
+
+
 class DatabricksRunNowOperator(BaseOperator):
     """
     Runs an existing Spark job run to Databricks using the
@@ -604,7 +630,7 @@ class DatabricksRunNowOperator(BaseOperator):
         if idempotency_token is not None:
             self.json['idempotency_token'] = idempotency_token
 
-        self.json = _deep_string_coerce(self.json)
+        self.json = deep_string_coerce(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
@@ -637,3 +663,15 @@ class DatabricksRunNowOperator(BaseOperator):
             )
         else:
             self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id)
+
+
+class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
+    """Deferrable version of ``DatabricksRunNowOperator``"""
+
+    def execute(self, context):
+        hook = self._get_hook()
+        self.run_id = hook.run_now(self.json)
+        _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
+
+    def execute_complete(self, context: Optional[dict], event: dict):
+        _handle_deferrable_databricks_operator_completion(event, self.log)
diff --git a/airflow/providers/databricks/triggers/__init__.py b/airflow/providers/databricks/triggers/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/airflow/providers/databricks/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
new file mode 100644
index 0000000000..5f50f5aff2
--- /dev/null
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import asyncio
+import logging
+from typing import Any, Dict, Tuple
+
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+
+try:
+    from airflow.triggers.base import BaseTrigger, TriggerEvent
+except ImportError:
+    logging.getLogger(__name__).warning(
+        'Deferrable Operators only work starting Airflow 2.2',
+        exc_info=True,
+    )
+    BaseTrigger = object  # type: ignore
+    TriggerEvent = None  # type: ignore
+
+
+class DatabricksExecutionTrigger(BaseTrigger):
+    """
+    The trigger handles the logic of async communication with the Databricks API.
+
+    :param run_id: id of the run
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
+    :param polling_period_seconds: Controls the rate of the poll for the result of this run.
+        By default, the trigger will poll every 30 seconds.
+    """
+
+    def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None:
+        super().__init__()
+        self.run_id = run_id
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.hook = DatabricksHook(databricks_conn_id)
+
+    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+        return (
+            'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger',
+            {
+                'run_id': self.run_id,
+                'databricks_conn_id': self.databricks_conn_id,
+                'polling_period_seconds': self.polling_period_seconds,
+            },
+        )
+
+    async def run(self):
+        async with self.hook:
+            run_page_url = await self.hook.a_get_run_page_url(self.run_id)
+            while True:
+                run_state = await self.hook.a_get_run_state(self.run_id)
+                if run_state.is_terminal:
+                    yield TriggerEvent(
+                        {
+                            'run_id': self.run_id,
+                            'run_state': run_state.to_json(),
+                            'run_page_url': run_page_url,
+                        }
+                    )
+                    break
+                else:
+                    await asyncio.sleep(self.polling_period_seconds)
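For a quick local check outside the triggerer process, a hedged sketch of draining the trigger's async generator by hand; it assumes a working 'databricks_default' connection and a placeholder run id, and is not part of the commit:

    import asyncio

    from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger


    async def wait_for_run(run_id: int) -> dict:
        trigger = DatabricksExecutionTrigger(
            run_id=run_id,
            databricks_conn_id='databricks_default',
            polling_period_seconds=10,
        )
        # run() polls a_get_run_state() until the run is terminal, then yields a single
        # TriggerEvent whose payload carries run_id, the serialized run_state, and run_page_url.
        async for event in trigger.run():
            return event.payload


    # print(asyncio.run(wait_for_run(1)))  # run id 1 is a placeholder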
diff --git a/airflow/providers/databricks/utils/__init__.py b/airflow/providers/databricks/utils/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/databricks/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py
new file mode 100644
index 0000000000..96935d8063
--- /dev/null
+++ b/airflow/providers/databricks/utils/databricks.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Union
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import RunState
+
+
+def deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]:
+    """
+    Coerces content or all values of content if it is a dict to a string. The
+    function will throw if content contains non-string or non-numeric types.
+    The reason why we have this function is because the ``self.json`` field must be a
+    dict with only string values. This is because ``render_template`` will fail
+    for numerical values.
+    """
+    coerce = deep_string_coerce
+    if isinstance(content, str):
+        return content
+    elif isinstance(
+        content,
+        (
+            int,
+            float,
+        ),
+    ):
+        # Databricks can tolerate either numeric or string types in the API backend.
+        return str(content)
+    elif isinstance(content, (list, tuple)):
+        return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)]
+    elif isinstance(content, dict):
+        return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())}
+    else:
+        param_type = type(content)
+        msg = f'Type {param_type} used for parameter {json_path} is not a number or a string'
+        raise AirflowException(msg)
+
+
+def validate_trigger_event(event: dict):
+    """
+    Validates correctness of the event
+    received from :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`
+    """
+    keys_to_check = ['run_id', 'run_page_url', 'run_state']
+    for key in keys_to_check:
+        if key not in event:
+            raise AirflowException(f'Could not find `{key}` in the event: {event}')
+
+    try:
+        RunState.from_json(event['run_state'])
+    except Exception:
+        raise AirflowException(f'Run state returned by the Trigger is incorrect: {event["run_state"]}')
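A small illustration of the coercion rules documented above (illustrative only, not part of the commit):

    from airflow.providers.databricks.utils.databricks import deep_string_coerce

    # Numbers are coerced to strings recursively, strings pass through unchanged,
    # and any other type raises AirflowException naming the offending JSON path.
    assert deep_string_coerce({'retries': 2, 'params': [1, 1.5, 'a']}) == {
        'retries': '2',
        'params': ['1', '1.5', 'a'],
    }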
diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst
index 8b2e6010ee..a4b00d9005 100644
--- a/docs/apache-airflow-providers-databricks/operators/run_now.rst
+++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst
@@ -45,3 +45,10 @@ All other parameters are optional and described in documentation for ``Databrick
 * ``python_named_parameters``
 * ``jar_params``
 * ``spark_submit_params``
+
+DatabricksRunNowDeferrableOperator
+==================================
+
+Deferrable version of the :class:`~airflow.providers.databricks.operators.DatabricksRunNowOperator` operator.
+
+It allows you to utilize Airflow workers more effectively using the `new functionality introduced in Airflow 2.2.0 <https://airflow.apache.org/docs/apache-airflow/2.2.0/concepts/deferring.html#triggering-deferral>`_.
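To complement the docs above, a hypothetical minimal usage sketch; the job id, connection id, and params are placeholders and are not part of the commit:

    from airflow.providers.databricks.operators.databricks import DatabricksRunNowDeferrableOperator

    run_now = DatabricksRunNowDeferrableOperator(
        task_id='run_now',
        databricks_conn_id='databricks_default',
        job_id=42,  # placeholder: id of an existing Databricks job
        notebook_params={'dry-run': 'true'},
        # The task defers to DatabricksExecutionTrigger while the run is in progress.
    )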
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index da71194da7..81f9dfd32f 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -75,3 +75,10 @@ You can also use named parameters to initialize the operator and run the job.
     :language: python
     :start-after: [START howto_operator_databricks_named]
     :end-before: [END howto_operator_databricks_named]
+
+DatabricksSubmitRunDeferrableOperator
+=====================================
+
+Deferrable version of the :class:`~airflow.providers.databricks.operators.DatabricksSubmitRunOperator` operator.
+
+It allows you to utilize Airflow workers more effectively using the `new functionality introduced in Airflow 2.2.0 <https://airflow.apache.org/docs/apache-airflow/2.2.0/concepts/deferring.html#triggering-deferral>`_.
diff --git a/setup.py b/setup.py
index 742ba57aa5..e7d5951694 100644
--- a/setup.py
+++ b/setup.py
@@ -268,6 +268,7 @@ dask = [
 databricks = [
     'requests>=2.26.0, <3',
     'databricks-sql-connector>=2.0.0, <3.0.0',
+    'aiohttp>=3.6.3, <4',
 ]
 datadog = [
     'datadog>=0.14.0',
@@ -609,6 +610,7 @@ mypy_dependencies = [
 
 # Dependencies needed for development only
 devel_only = [
+    'asynctest~=0.13',
     'aws_xray_sdk',
     'beautifulsoup4>=4.7.1',
     'black',
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 5a93ed7b42..2997984f7b 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -19,10 +19,11 @@
 
 import itertools
 import json
+import sys
 import time
 import unittest
-from unittest import mock
 
+import aiohttp
 import pytest
 import tenacity
 from requests import exceptions as requests_exceptions
@@ -31,7 +32,12 @@ from requests.auth import HTTPBasicAuth
 from airflow import __version__
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
-from airflow.providers.databricks.hooks.databricks import SUBMIT_RUN_ENDPOINT, DatabricksHook, RunState
+from airflow.providers.databricks.hooks.databricks import (
+    GET_RUN_ENDPOINT,
+    SUBMIT_RUN_ENDPOINT,
+    DatabricksHook,
+    RunState,
+)
 from airflow.providers.databricks.hooks.databricks_base import (
     AZURE_DEFAULT_AD_ENDPOINT,
     AZURE_MANAGEMENT_ENDPOINT,
@@ -39,9 +45,17 @@ from airflow.providers.databricks.hooks.databricks_base import (
     AZURE_TOKEN_SERVICE_URL,
     DEFAULT_DATABRICKS_SCOPE,
     TOKEN_REFRESH_LEAD_TIME,
+    BearerAuth,
 )
 from airflow.utils.session import provide_session
 
+if sys.version_info < (3, 8):
+    from asynctest import mock
+    from asynctest.mock import CoroutineMock as AsyncMock
+else:
+    from unittest import mock
+    from unittest.mock import AsyncMock
+
 TASK_ID = 'databricks-operator'
 DEFAULT_CONN_ID = 'databricks_default'
 NOTEBOOK_TASK = {'notebook_path': '/test'}
@@ -172,6 +186,7 @@ def list_jobs_endpoint(host):
 def create_valid_response_mock(content):
     response = mock.MagicMock()
     response.json.return_value = content
+    response.__aenter__.return_value.json = AsyncMock(return_value=content)
     return response
 
 
@@ -785,6 +800,18 @@ class TestRunState(unittest.TestCase):
         run_state = RunState('TERMINATED', 'SUCCESS', '')
         assert run_state.is_successful
 
+    def test_to_json(self):
+        run_state = RunState('TERMINATED', 'SUCCESS', '')
+        expected = json.dumps(
+            {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS', 'state_message': ''}
+        )
+        assert expected == run_state.to_json()
+
+    def test_from_json(self):
+        state = {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS', 'state_message': ''}
+        expected = RunState('TERMINATED', 'SUCCESS', '')
+        assert expected == RunState.from_json(json.dumps(state))
+
 
 def create_aad_token_for_resource(resource: str) -> dict:
     return {
@@ -976,3 +1003,324 @@ class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
         args = mock_requests.post.call_args
         kwargs = args[1]
         assert kwargs['auth'].token == TOKEN
+
+
+class TestDatabricksHookAsyncMethods:
+    """
+    Tests for async functionality of DatabricksHook.
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.host = HOST
+        conn.login = LOGIN
+        conn.password = PASSWORD
+        conn.extra = None
+        session.commit()
+
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @pytest.mark.asyncio
+    async def test_init_async_session(self):
+        async with self.hook:
+            assert isinstance(self.hook._session, aiohttp.ClientSession)
+        assert self.hook._session is None
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_do_api_call_retries_with_retryable_error(self, mock_get):
+        mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=500)
+        with mock.patch.object(self.hook.log, 'error') as mock_errors:
+            async with self.hook:
+                with pytest.raises(AirflowException):
+                    await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+                assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_do_api_call_does_not_retry_with_non_retryable_error(self, 
mock_get):
+        mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=400)
+        with mock.patch.object(self.hook.log, 'error') as mock_errors:
+            async with self.hook:
+                with pytest.raises(AirflowException):
+                    await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+                mock_errors.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_do_api_call_succeeds_after_retrying(self, mock_get):
+        mock_get.side_effect = [
+            aiohttp.ClientResponseError(None, None, status=500),
+            create_valid_response_mock({'run_id': '1'}),
+        ]
+        with mock.patch.object(self.hook.log, 'error') as mock_errors:
+            async with self.hook:
+                response = await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+                assert mock_errors.call_count == 1
+                assert response == {'run_id': '1'}
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_do_api_call_waits_between_retries(self, mock_get):
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+        mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=500)
+        with mock.patch.object(self.hook.log, 'error') as mock_errors:
+            async with self.hook:
+                with pytest.raises(AirflowException):
+                    await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
+                assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.patch')
+    async def test_do_api_call_patch(self, mock_patch):
+        mock_patch.return_value.__aenter__.return_value.json = AsyncMock(
+            return_value={'cluster_name': 'new_name'}
+        )
+        data = {'cluster_name': 'new_name'}
+        async with self.hook:
+            patched_cluster_name = await self.hook._a_do_api_call(('PATCH', 'api/2.1/jobs/runs/submit'), data)
+
+        assert patched_cluster_name['cluster_name'] == 'new_name'
+        mock_patch.assert_called_once_with(
+            submit_run_endpoint(HOST),
+            json={'cluster_name': 'new_name'},
+            auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_get_run_page_url(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+        async with self.hook:
+            run_page_url = await self.hook.a_get_run_page_url(RUN_ID)
+
+        assert run_page_url == RUN_PAGE_URL
+        mock_get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_get_run_state(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+        async with self.hook:
+            run_state = await self.hook.a_get_run_state(RUN_ID)
+
+        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+        mock_get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+
+class TestDatabricksHookAsyncAadToken:
+    """
+    Tests for DatabricksHook using async methods when
+    auth is done with AAD token for SP as user inside workspace.
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.login = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
+        conn.password = 'secret'
+        conn.extra = json.dumps(
+            {
+                'host': HOST,
+                'azure_tenant_id': '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d',
+            }
+        )
+        session.commit()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post')
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json = AsyncMock(
+            return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+        async with self.hook:
+            run_state = await self.hook.a_get_run_state(RUN_ID)
+
+        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+        mock_get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=BearerAuth(TOKEN),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+
+class TestDatabricksHookAsyncAadTokenOtherClouds:
+    """
+    Tests for DatabricksHook using async methods when auth is done with AAD token
+    for SP as user inside workspace and using non-global Azure cloud (China, GovCloud, Germany)
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d'
+        self.ad_endpoint = 'https://login.microsoftonline.de'
+        self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.login = self.client_id
+        conn.password = 'secret'
+        conn.extra = json.dumps(
+            {
+                'host': HOST,
+                'azure_tenant_id': self.tenant_id,
+                'azure_ad_endpoint': self.ad_endpoint,
+            }
+        )
+        session.commit()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post')
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json = AsyncMock(
+            return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+        async with self.hook:
+            run_state = await self.hook.a_get_run_state(RUN_ID)
+
+        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+
+        ad_call_args = mock_post.call_args_list[0]
+        assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id)
+        assert ad_call_args[1]['data']['client_id'] == self.client_id
+        assert ad_call_args[1]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE
+
+        mock_get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=BearerAuth(TOKEN),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+
+class TestDatabricksHookAsyncAadTokenSpOutside:
+    """
+    Tests for DatabricksHook using async methods when auth is done with AAD token for SP outside of workspace.
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d'
+        self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d'
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.login = self.client_id
+        conn.password = 'secret'
+        conn.host = HOST
+        conn.extra = json.dumps(
+            {
+                'azure_resource_id': '/Some/resource',
+                'azure_tenant_id': self.tenant_id,
+            }
+        )
+        session.commit()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post')
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json.side_effect = AsyncMock(
+            side_effect=[
+                create_aad_token_for_resource(AZURE_MANAGEMENT_ENDPOINT),
+                create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE),
+            ]
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+        async with self.hook:
+            run_state = await self.hook.a_get_run_state(RUN_ID)
+
+        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+
+        ad_call_args = mock_post.call_args_list[0]
+        assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format(
+            AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id
+        )
+        assert ad_call_args[1]['data']['client_id'] == self.client_id
+        assert ad_call_args[1]['data']['resource'] == AZURE_MANAGEMENT_ENDPOINT
+
+        ad_call_args = mock_post.call_args_list[1]
+        assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format(
+            AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id
+        )
+        assert ad_call_args[1]['data']['client_id'] == self.client_id
+        assert ad_call_args[1]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE
+
+        mock_get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=BearerAuth(TOKEN),
+            headers={
+                **USER_AGENT_HEADER,
+                'X-Databricks-Azure-Workspace-Resource-Id': '/Some/resource',
+                'X-Databricks-Azure-SP-Management-Token': TOKEN,
+            },
+            timeout=self.hook.timeout_seconds,
+        )
+
+
+class TestDatabricksHookAsyncAadTokenManagedIdentity:
+    """
+    Tests for DatabricksHook using async methods when
+    auth is done with AAD leveraging Managed Identity authentication
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.host = HOST
+        conn.extra = json.dumps(
+            {
+                'use_azure_managed_identity': True,
+            }
+        )
+        session.commit()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get')
+    async def test_get_run_state(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json.side_effect = AsyncMock(
+            side_effect=[
+                {'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}},
+                create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE),
+                GET_RUN_RESPONSE,
+            ]
+        )
+
+        async with self.hook:
+            run_state = await self.hook.a_get_run_state(RUN_ID)
+
+        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+
+        ad_call_args = mock_get.call_args_list[0]
+        assert ad_call_args[1]['url'] == AZURE_METADATA_SERVICE_INSTANCE_URL
+        assert ad_call_args[1]['params']['api-version'] > '2018-02-01'
+        assert ad_call_args[1]['headers']['Metadata'] == 'true'
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 99782234e0..895beb18e5 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -22,14 +22,17 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import DAG
 from airflow.providers.databricks.hooks.databricks import RunState
-from airflow.providers.databricks.operators import databricks as databricks_operator
 from airflow.providers.databricks.operators.databricks import (
+    DatabricksRunNowDeferrableOperator,
     DatabricksRunNowOperator,
+    DatabricksSubmitRunDeferrableOperator,
     DatabricksSubmitRunOperator,
 )
+from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
+from airflow.providers.databricks.utils import databricks as utils
 
 DATE = '2017-04-20'
 TASK_ID = 'databricks-operator'
@@ -46,6 +49,7 @@ NEW_CLUSTER = {'spark_version': '2.0.x-scala2.10', 'node_type_id': 'development-
 EXISTING_CLUSTER_ID = 'existing-cluster-id'
 RUN_NAME = 'run-name'
 RUN_ID = 1
+RUN_PAGE_URL = 'run-page-url'
 JOB_ID = "42"
 JOB_NAME = "job-name"
 NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": 
"1457570074236"}
@@ -56,26 +60,6 @@ PYTHON_PARAMS = ["john doe", "35"]
 SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"]
 
 
-class TestDatabricksOperatorSharedFunctions(unittest.TestCase):
-    def test_deep_string_coerce(self):
-        test_json = {
-            'test_int': 1,
-            'test_float': 1.0,
-            'test_dict': {'key': 'value'},
-            'test_list': [1, 1.0, 'a', 'b'],
-            'test_tuple': (1, 1.0, 'a', 'b'),
-        }
-
-        expected = {
-            'test_int': '1',
-            'test_float': '1.0',
-            'test_dict': {'key': 'value'},
-            'test_list': ['1', '1.0', 'a', 'b'],
-            'test_tuple': ['1', '1.0', 'a', 'b'],
-        }
-        assert databricks_operator._deep_string_coerce(test_json) == expected
-
-
 class TestDatabricksSubmitRunOperator(unittest.TestCase):
     def test_init_with_notebook_task_named_parameters(self):
         """
@@ -84,7 +68,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, 
notebook_task=NOTEBOOK_TASK
         )
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 
'run_name': TASK_ID}
         )
 
@@ -97,7 +81,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, 
spark_python_task=SPARK_PYTHON_TASK
         )
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'spark_python_task': 
SPARK_PYTHON_TASK, 'run_name': TASK_ID}
         )
 
@@ -110,7 +94,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, 
spark_submit_task=SPARK_SUBMIT_TASK
         )
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'spark_submit_task': 
SPARK_SUBMIT_TASK, 'run_name': TASK_ID}
         )
 
@@ -122,7 +106,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         """
         json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 
'run_name': TASK_ID}
         )
         assert expected == op.json
@@ -130,7 +114,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
     def test_init_with_tasks(self):
         tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": 
NOTEBOOK_TASK}]
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
-        expected = databricks_operator._deep_string_coerce({'run_name': 
TASK_ID, "tasks": tasks})
+        expected = utils.deep_string_coerce({'run_name': TASK_ID, "tasks": 
tasks})
         assert expected == op.json
 
     def test_init_with_specified_run_name(self):
@@ -139,7 +123,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         """
         json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME}
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME}
         )
         assert expected == op.json
@@ -151,7 +135,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         pipeline_task = {"pipeline_id": "test-dlt"}
         json = {'new_cluster': NEW_CLUSTER, 'run_name': RUN_NAME, "pipeline_task": pipeline_task}
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, "pipeline_task": pipeline_task, 
'run_name': RUN_NAME}
         )
         assert expected == op.json
@@ -168,7 +152,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
             'notebook_task': NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'new_cluster': override_new_cluster,
                 'notebook_task': NOTEBOOK_TASK,
@@ -185,7 +169,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         dag = DAG('test', start_date=datetime.now())
         op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
         op.render_template_fields(context={'ds': DATE})
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'new_cluster': NEW_CLUSTER,
                 'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
@@ -238,7 +222,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -270,7 +254,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         with pytest.raises(AirflowException):
             op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'new_cluster': NEW_CLUSTER,
                 'notebook_task': NOTEBOOK_TASK,
@@ -317,7 +301,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -345,7 +329,7 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -360,13 +344,98 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
         db_mock.get_run_state.assert_not_called()
 
 
+class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase):
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_execute_task_deferred(self, db_mock_class):
+        """
+        Test the execute function in the case where the run is deferred to the trigger.
+        """
+        run = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+        }
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(None)
+        self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger))
+        self.assertEqual(exc.value.method_name, 'execute_complete')
+
+        expected = utils.deep_string_coerce(
+            {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+        )
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        self.assertEqual(RUN_ID, op.run_id)
+
+    def test_execute_complete_success(self):
+        """
+        Test `execute_complete` function in case the Trigger has returned a successful completion event.
+        """
+        run = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+        }
+        event = {
+            'run_id': RUN_ID,
+            'run_page_url': RUN_PAGE_URL,
+            'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
+        }
+
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+        self.assertIsNone(op.execute_complete(context=None, event=event))
+
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_execute_complete_failure(self, db_mock_class):
+        """
+        Test `execute_complete` function in case the Trigger has returned a failure completion event.
+        """
+        run_state_failed = RunState('TERMINATED', 'FAILED', '')
+        run = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+        }
+        event = {
+            'run_id': RUN_ID,
+            'run_page_url': RUN_PAGE_URL,
+            'run_state': run_state_failed.to_json(),
+        }
+
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+        with pytest.raises(AirflowException):
+            op.execute_complete(context=None, event=event)
+
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run_state.return_value = run_state_failed
+
+        with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
+            op.execute_complete(context=None, event=event)
+
+    def test_execute_complete_incorrect_event_validation_failure(self):
+        event = {}
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID)
+        with pytest.raises(AirflowException):
+            op.execute_complete(context=None, event=event)
+
+
 class TestDatabricksRunNowOperator(unittest.TestCase):
     def test_init_with_named_parameters(self):
         """
         Test the initializer with the named parameters.
         """
         op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
-        expected = databricks_operator._deep_string_coerce({'job_id': 42})
+        expected = utils.deep_string_coerce({'job_id': 42})
 
         assert expected == op.json
 
@@ -383,7 +452,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
         }
         op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'jar_params': JAR_PARAMS,
@@ -415,7 +484,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
             spark_submit_params=SPARK_SUBMIT_PARAMS,
         )
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': override_notebook_params,
                 'jar_params': override_jar_params,
@@ -433,7 +502,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
         dag = DAG('test', start_date=datetime.now())
         op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json)
         op.render_template_fields(context={'ds': DATE})
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'jar_params': RENDERED_TEMPLATED_JAR_PARAMS,
@@ -465,7 +534,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'notebook_task': NOTEBOOK_TASK,
@@ -499,7 +568,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
         with pytest.raises(AirflowException):
             op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'notebook_task': NOTEBOOK_TASK,
@@ -540,7 +609,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'notebook_task': NOTEBOOK_TASK,
@@ -570,7 +639,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'notebook_task': NOTEBOOK_TASK,
@@ -614,7 +683,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
 
         op.execute(None)
 
-        expected = databricks_operator._deep_string_coerce(
+        expected = utils.deep_string_coerce(
             {
                 'notebook_params': NOTEBOOK_PARAMS,
                 'notebook_task': NOTEBOOK_TASK,
@@ -647,3 +716,85 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
             op.execute(None)
 
         db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
+
+
+class TestDatabricksRunNowDeferrableOperator(unittest.TestCase):
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_execute_task_deferred(self, db_mock_class):
+        """
+        Test the execute function in the case where the run is deferred to the trigger.
+        """
+        run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(None)
+        self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger))
+        self.assertEqual(exc.value.method_name, 'execute_complete')
+
+        expected = utils.deep_string_coerce(
+            {
+                'notebook_params': NOTEBOOK_PARAMS,
+                'notebook_task': NOTEBOOK_TASK,
+                'jar_params': JAR_PARAMS,
+                'job_id': JOB_ID,
+            }
+        )
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+        )
+
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        self.assertEqual(RUN_ID, op.run_id)
+
+    def test_execute_complete_success(self):
+        """
+        Test `execute_complete` function in case the Trigger has returned a successful completion event.
+        """
+        run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+        event = {
+            'run_id': RUN_ID,
+            'run_page_url': RUN_PAGE_URL,
+            'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
+        }
+
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        self.assertIsNone(op.execute_complete(context=None, event=event))
+
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_execute_complete_failure(self, db_mock_class):
+        """
+        Test `execute_complete` function in case the Trigger has returned a failure completion event.
+        """
+        run_state_failed = RunState('TERMINATED', 'FAILED', '')
+        run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+        event = {
+            'run_id': RUN_ID,
+            'run_page_url': RUN_PAGE_URL,
+            'run_state': run_state_failed.to_json(),
+        }
+
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        with pytest.raises(AirflowException):
+            op.execute_complete(context=None, event=event)
+
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run_state.return_value = run_state_failed
+
+        with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
+            op.execute_complete(context=None, event=event)
+
+    def test_execute_complete_incorrect_event_validation_failure(self):
+        event = {}
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID)
+        with pytest.raises(AirflowException):
+            op.execute_complete(context=None, event=event)
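
The deferrable operator tests above exercise the same `json` contract as the blocking operators; only the waiting strategy changes, since the task defers to DatabricksExecutionTrigger instead of polling in the worker. For orientation, a minimal DAG sketch using the new operator; the connection id matches the tests' DEFAULT_CONN_ID, while the cluster spec and notebook path are placeholders and not part of this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunDeferrableOperator

    with DAG('databricks_deferrable_example', start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        # Same payload shape as DatabricksSubmitRunOperator; completion is awaited in the triggerer.
        submit_run = DatabricksSubmitRunDeferrableOperator(
            task_id='submit_run',
            databricks_conn_id='databricks_default',
            json={
                'new_cluster': {'spark_version': '9.1.x-scala2.12', 'num_workers': 2},  # placeholder cluster
                'notebook_task': {'notebook_path': '/Users/someone/example-notebook'},  # placeholder notebook
            },
        )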
diff --git a/tests/providers/databricks/triggers/__init__.py b/tests/providers/databricks/triggers/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/providers/databricks/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py
new file mode 100644
index 0000000000..cecbed1138
--- /dev/null
+++ b/tests/providers/databricks/triggers/test_databricks.py
@@ -0,0 +1,153 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import sys
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.databricks.hooks.databricks import RunState
+from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.session import provide_session
+
+if sys.version_info < (3, 8):
+    from asynctest import mock
+else:
+    from unittest import mock
+
+DEFAULT_CONN_ID = 'databricks_default'
+HOST = 'xx.cloud.databricks.com'
+LOGIN = 'login'
+PASSWORD = 'password'
+POLLING_INTERVAL_SECONDS = 30
+RETRY_DELAY = 10
+RETRY_LIMIT = 3
+RUN_ID = 1
+JOB_ID = 42
+RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
+
+RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+
+LIFE_CYCLE_STATE_PENDING = 'PENDING'
+LIFE_CYCLE_STATE_TERMINATED = 'TERMINATED'
+
+STATE_MESSAGE = 'Waiting for cluster'
+
+GET_RUN_RESPONSE_PENDING = {
+    'job_id': JOB_ID,
+    'run_page_url': RUN_PAGE_URL,
+    'state': {
+        'life_cycle_state': LIFE_CYCLE_STATE_PENDING,
+        'state_message': STATE_MESSAGE,
+        'result_state': None,
+    },
+}
+GET_RUN_RESPONSE_TERMINATED = {
+    'job_id': JOB_ID,
+    'run_page_url': RUN_PAGE_URL,
+    'state': {
+        'life_cycle_state': LIFE_CYCLE_STATE_TERMINATED,
+        'state_message': None,
+        'result_state': 'SUCCESS',
+    },
+}
+
+
+class TestDatabricksExecutionTrigger:
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.host = HOST
+        conn.login = LOGIN
+        conn.password = PASSWORD
+        conn.extra = None
+        session.commit()
+
+        self.trigger = DatabricksExecutionTrigger(
+            run_id=RUN_ID,
+            databricks_conn_id=DEFAULT_CONN_ID,
+            polling_period_seconds=POLLING_INTERVAL_SECONDS,
+        )
+
+    def test_serialize(self):
+        assert self.trigger.serialize() == (
+            'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger',
+            {
+                'run_id': RUN_ID,
+                'databricks_conn_id': DEFAULT_CONN_ID,
+                'polling_period_seconds': POLLING_INTERVAL_SECONDS,
+            },
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url')
+    @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state')
+    async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
+        mock_get_run_page_url.return_value = RUN_PAGE_URL
+        mock_get_run_state.return_value = RunState(
+            life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
+            state_message='',
+            result_state='SUCCESS',
+        )
+
+        trigger_event = self.trigger.run()
+        async for event in trigger_event:
+            assert event == TriggerEvent(
+                {
+                    'run_id': RUN_ID,
+                    'run_state': RunState(
+                        life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message='', result_state='SUCCESS'
+                    ).to_json(),
+                    'run_page_url': RUN_PAGE_URL,
+                }
+            )
+
+    @pytest.mark.asyncio
+    @mock.patch('airflow.providers.databricks.triggers.databricks.asyncio.sleep')
+    @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url')
+    @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state')
+    async def test_sleep_between_retries(self, mock_get_run_state, mock_get_run_page_url, mock_sleep):
+        mock_get_run_page_url.return_value = RUN_PAGE_URL
+        mock_get_run_state.side_effect = [
+            RunState(
+                life_cycle_state=LIFE_CYCLE_STATE_PENDING,
+                state_message='',
+                result_state='',
+            ),
+            RunState(
+                life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
+                state_message='',
+                result_state='SUCCESS',
+            ),
+        ]
+
+        trigger_event = self.trigger.run()
+        async for event in trigger_event:
+            assert event == TriggerEvent(
+                {
+                    'run_id': RUN_ID,
+                    'run_state': RunState(
+                        life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message='', result_state='SUCCESS'
+                    ).to_json(),
+                    'run_page_url': RUN_PAGE_URL,
+                }
+            )
+            mock_sleep.assert_called_once()
+            mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
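
The two trigger tests above pin down the polling contract: keep calling `a_get_run_state()` until the run reaches a terminal life-cycle state, sleep `polling_period_seconds` between attempts, then emit a single TriggerEvent carrying the run id, the JSON-serialized run state, and the run page URL. A rough async-generator sketch of that behaviour; the actual implementation lives in airflow/providers/databricks/triggers/databricks.py and may differ in details, including how terminal states are detected:

    import asyncio

    from airflow.providers.databricks.hooks.databricks import DatabricksHook
    from airflow.triggers.base import TriggerEvent

    async def poll_databricks_run(run_id, databricks_conn_id, polling_period_seconds):
        # Sketch of the loop the tests assert against: poll the run state, sleep between
        # polls, and emit exactly one event once the run has finished.
        hook = DatabricksHook(databricks_conn_id)
        while True:
            run_state = await hook.a_get_run_state(run_id)
            if run_state.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'):  # assumed terminal set
                yield TriggerEvent(
                    {
                        'run_id': run_id,
                        'run_state': run_state.to_json(),
                        'run_page_url': await hook.a_get_run_page_url(run_id),
                    }
                )
                return
            await asyncio.sleep(polling_period_seconds)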
diff --git a/tests/providers/databricks/utils/__init__.py b/tests/providers/databricks/utils/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/databricks/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/databricks/utils/databricks.py b/tests/providers/databricks/utils/databricks.py
new file mode 100644
index 0000000000..d450a19ceb
--- /dev/null
+++ b/tests/providers/databricks/utils/databricks.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import RunState
+from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event
+
+RUN_ID = 1
+RUN_PAGE_URL = 'run-page-url'
+
+
+class TestDatabricksOperatorSharedFunctions(unittest.TestCase):
+    def test_deep_string_coerce(self):
+        test_json = {
+            'test_int': 1,
+            'test_float': 1.0,
+            'test_dict': {'key': 'value'},
+            'test_list': [1, 1.0, 'a', 'b'],
+            'test_tuple': (1, 1.0, 'a', 'b'),
+        }
+
+        expected = {
+            'test_int': '1',
+            'test_float': '1.0',
+            'test_dict': {'key': 'value'},
+            'test_list': ['1', '1.0', 'a', 'b'],
+            'test_tuple': ['1', '1.0', 'a', 'b'],
+        }
+        assert deep_string_coerce(test_json) == expected
+
+    def test_validate_trigger_event_success(self):
+        event = {
+            'run_id': RUN_ID,
+            'run_page_url': RUN_PAGE_URL,
+            'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
+        }
+        self.assertIsNone(validate_trigger_event(event))
+
+    def test_validate_trigger_event_failure(self):
+        event = {}
+        with pytest.raises(AirflowException):
+            validate_trigger_event(event)
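
Together these assertions describe the two shared helpers: deep_string_coerce() stringifies scalars recursively while preserving nesting (tuples come back as lists), and validate_trigger_event() raises AirflowException when a completion event is missing a required key. A sketch consistent with the tests; the shipped implementation in airflow/providers/databricks/utils/databricks.py may differ in naming and error messages:

    from airflow.exceptions import AirflowException

    def coerce_sketch(content):
        # Strings pass through, numbers become strings, containers are walked recursively.
        if isinstance(content, str):
            return content
        if isinstance(content, (int, float)):
            return str(content)
        if isinstance(content, (list, tuple)):
            return [coerce_sketch(e) for e in content]
        if isinstance(content, dict):
            return {k: coerce_sketch(v) for k, v in content.items()}
        return content

    def validate_event_sketch(event):
        # The happy-path test passes a dict with all three keys; an empty dict must raise.
        for key in ('run_id', 'run_page_url', 'run_state'):
            if key not in event:
                raise AirflowException(f'Trigger event is missing required key: {key}')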
