This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 12e9e2c  Databricks hook - retry on HTTP Status 429 as well (#21852)
12e9e2c is described below

commit 12e9e2c695f9ebb9d3dde9c0f7dfaa112654f0d6
Author: Alex Ott <alex...@gmail.com>
AuthorDate: Mon Mar 14 00:19:01 2022 +0100

    Databricks hook - retry on HTTP Status 429 as well (#21852)
    
    * Databricks hook - retry on HTTP Status 429 as well
    
    this fixes #21559
    
    * Reimplement retries using tenacity
    
    it now uses exponential backoff by default
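    
    For illustration only (not part of this commit): a hook with a custom retry
    policy might be created as in the sketch below. The ``stop`` and ``wait``
    values here are assumptions; only the ``retry_args`` keyword itself is
    introduced by this change.
    
        from tenacity import stop_after_attempt, wait_exponential
    
        from airflow.providers.databricks.hooks.databricks import DatabricksHook
    
        # Hypothetical policy: at most 5 attempts, exponential backoff capped at 30s.
        # Note: the hook overrides the 'retry' and 'after' keys of whatever dict is
        # passed in, so only keys such as 'stop' and 'wait' take effect here.
        custom_retry_args = dict(
            stop=stop_after_attempt(5),
            wait=wait_exponential(multiplier=1, min=1, max=30),
        )
        hook = DatabricksHook(retry_args=custom_retry_args)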
---
 airflow/providers/databricks/hooks/databricks.py   |   4 +-
 .../providers/databricks/hooks/databricks_base.py  | 196 +++++++++++----------
 .../providers/databricks/operators/databricks.py   |   8 +
 .../operators/run_now.rst                          |   2 +
 .../operators/submit_run.rst                       |   2 +
 .../providers/databricks/hooks/test_databricks.py  |  61 ++++---
 .../databricks/operators/test_databricks.py        |  45 ++++-
 7 files changed, 195 insertions(+), 123 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 20255e3..cdab4d0 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -100,6 +100,7 @@ class DatabricksHook(BaseDatabricksHook):
         service outages.
     :param retry_delay: The number of seconds to wait between retries (it
         might be a floating point number).
+    :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     """
 
     hook_name = 'Databricks'
@@ -110,8 +111,9 @@ class DatabricksHook(BaseDatabricksHook):
         timeout_seconds: int = 180,
         retry_limit: int = 3,
         retry_delay: float = 1.0,
+        retry_args: Optional[Dict[Any, Any]] = None,
     ) -> None:
-        super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay)
+        super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args)
 
     def run_now(self, json: dict) -> int:
         """
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 1a0d639..f545f26 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -22,15 +22,16 @@ This hook enable the submitting and running of jobs to the Databricks platform.
 operators talk to the ``api/2.0/jobs/runs/submit``
 `endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
 """
+import copy
 import sys
 import time
-from time import sleep
 from typing import Any, Dict, Optional, Tuple
 from urllib.parse import urlparse
 
 import requests
 from requests import PreparedRequest, exceptions as requests_exceptions
 from requests.auth import AuthBase, HTTPBasicAuth
+from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
 
 from airflow import __version__
 from airflow.exceptions import AirflowException
@@ -68,6 +69,7 @@ class BaseDatabricksHook(BaseHook):
         service outages.
     :param retry_delay: The number of seconds to wait between retries (it
         might be a floating point number).
+    :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     """
 
     conn_name_attr = 'databricks_conn_id'
@@ -89,17 +91,33 @@ class BaseDatabricksHook(BaseHook):
         timeout_seconds: int = 180,
         retry_limit: int = 3,
         retry_delay: float = 1.0,
+        retry_args: Optional[Dict[Any, Any]] = None,
     ) -> None:
         super().__init__()
         self.databricks_conn_id = databricks_conn_id
         self.timeout_seconds = timeout_seconds
         if retry_limit < 1:
-            raise ValueError('Retry limit must be greater than equal to 1')
+            raise ValueError('Retry limit must be greater than or equal to 1')
         self.retry_limit = retry_limit
         self.retry_delay = retry_delay
         self.aad_tokens: Dict[str, dict] = {}
         self.aad_timeout_seconds = 10
 
+        def my_after_func(retry_state):
+            self._log_request_error(retry_state.attempt_number, retry_state.outcome)
+
+        if retry_args:
+            self.retry_args = copy.copy(retry_args)
+            self.retry_args['retry'] = retry_if_exception(self._retryable_error)
+            self.retry_args['after'] = my_after_func
+        else:
+            self.retry_args = dict(
+                stop=stop_after_attempt(self.retry_limit),
+                wait=wait_exponential(min=self.retry_delay, max=(2 ** retry_limit)),
+                retry=retry_if_exception(self._retryable_error),
+                after=my_after_func,
+            )
+
     @cached_property
     def databricks_conn(self) -> Connection:
         return self.get_connection(self.databricks_conn_id)
@@ -143,6 +161,13 @@ class BaseDatabricksHook(BaseHook):
             # In this case, host = xx.cloud.databricks.com
             return host
 
+    def _get_retry_object(self) -> Retrying:
+        """
+        Instantiates a retry object
+        :return: instance of Retrying class
+        """
+        return Retrying(**self.retry_args)
+
     def _get_aad_token(self, resource: str) -> str:
         """
         Function to get AAD token for given resource. Supports managed identity or service principal auth
@@ -154,60 +179,59 @@ class BaseDatabricksHook(BaseHook):
             return aad_token['token']
 
         self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
-        attempt_num = 1
-        while True:
-            try:
-                if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
-                    params = {
-                        "api-version": "2018-02-01",
-                        "resource": resource,
-                    }
-                    resp = requests.get(
-                        AZURE_METADATA_SERVICE_TOKEN_URL,
-                        params=params,
-                        headers={**USER_AGENT_HEADER, "Metadata": "true"},
-                        timeout=self.aad_timeout_seconds,
-                    )
-                else:
-                    tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
-                    data = {
-                        "grant_type": "client_credentials",
-                        "client_id": self.databricks_conn.login,
-                        "resource": resource,
-                        "client_secret": self.databricks_conn.password,
-                    }
-                    azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
-                        "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
-                    )
-                    resp = requests.post(
-                        AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
-                        data=data,
-                        headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'},
-                        timeout=self.aad_timeout_seconds,
-                    )
-
-                resp.raise_for_status()
-                jsn = resp.json()
-                if 'access_token' not in jsn or jsn.get('token_type') != 'Bearer' or 'expires_on' not in jsn:
-                    raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
-
-                token = jsn['access_token']
-                self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
-
-                return token
-            except requests_exceptions.RequestException as e:
-                if not self._retryable_error(e):
-                    raise AirflowException(
-                        f'Response: {e.response.content}, Status Code: {e.response.status_code}'
-                    )
-
-                self._log_request_error(attempt_num, e.strerror)
-
-            if attempt_num == self.retry_limit:
-                raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
-
-            attempt_num += 1
-            sleep(self.retry_delay)
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+                        params = {
+                            "api-version": "2018-02-01",
+                            "resource": resource,
+                        }
+                        resp = requests.get(
+                            AZURE_METADATA_SERVICE_TOKEN_URL,
+                            params=params,
+                            headers={**USER_AGENT_HEADER, "Metadata": "true"},
+                            timeout=self.aad_timeout_seconds,
+                        )
+                    else:
+                        tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
+                        data = {
+                            "grant_type": "client_credentials",
+                            "client_id": self.databricks_conn.login,
+                            "resource": resource,
+                            "client_secret": self.databricks_conn.password,
+                        }
+                        azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
+                            "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
+                        )
+                        resp = requests.post(
+                            AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
+                            data=data,
+                            headers={
+                                **USER_AGENT_HEADER,
+                                'Content-Type': 'application/x-www-form-urlencoded',
+                            },
+                            timeout=self.aad_timeout_seconds,
+                        )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        'access_token' not in jsn
+                        or jsn.get('token_type') != 'Bearer'
+                        or 'expires_on' not in jsn
+                    ):
+                        raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
+
+                    token = jsn['access_token']
+                    self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
+                    break
+        except RetryError:
+            raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')
+
+        return token
 
     def _get_aad_headers(self) -> dict:
         """
@@ -279,14 +303,6 @@ class BaseDatabricksHook(BaseHook):
 
         return None
 
-    @staticmethod
-    def _retryable_error(exception) -> bool:
-        return (
-            isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout))
-            or exception.response is not None
-            and exception.response.status_code >= 500
-        )
-
     def _log_request_error(self, attempt_num: int, error: str) -> None:
         self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)
 
@@ -327,36 +343,32 @@ class BaseDatabricksHook(BaseHook):
         else:
             raise AirflowException('Unexpected HTTP Method: ' + method)
 
-        attempt_num = 1
-        while True:
-            try:
-                response = request_func(
-                    url,
-                    json=json if method in ('POST', 'PATCH') else None,
-                    params=json if method == 'GET' else None,
-                    auth=auth,
-                    headers=headers,
-                    timeout=self.timeout_seconds,
-                )
-                response.raise_for_status()
-                return response.json()
-            except requests_exceptions.RequestException as e:
-                if not self._retryable_error(e):
-                    # In this case, the user probably made a mistake.
-                    # Don't retry.
-                    raise AirflowException(
-                        f'Response: {e.response.content}, Status Code: {e.response.status_code}'
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    response = request_func(
+                        url,
+                        json=json if method in ('POST', 'PATCH') else None,
+                        params=json if method == 'GET' else None,
+                        auth=auth,
+                        headers=headers,
+                        timeout=self.timeout_seconds,
                     )
+                    response.raise_for_status()
+                    return response.json()
+        except RetryError:
+            raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')
 
-                self._log_request_error(attempt_num, str(e))
-
-            if attempt_num == self.retry_limit:
-                raise AirflowException(
-                    f'API requests to Databricks failed {self.retry_limit} times. Giving up.'
-                )
-
-            attempt_num += 1
-            sleep(self.retry_delay)
+    @staticmethod
+    def _retryable_error(exception: BaseException) -> bool:
+        if not isinstance(exception, requests_exceptions.RequestException):
+            return False
+        return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
+            exception.response is not None
+            and (exception.response.status_code >= 500 or exception.response.status_code == 429)
+        )
 
 
 class _TokenAuth(AuthBase):
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 818fb1d..ec0a9a0 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -245,6 +245,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
         unreachable. Its value must be greater than or equal to 1.
     :param databricks_retry_delay: Number of seconds to wait between retries (it
             might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     """
 
@@ -274,6 +275,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
         polling_period_seconds: int = 30,
         databricks_retry_limit: int = 3,
         databricks_retry_delay: int = 1,
+        databricks_retry_args: Optional[Dict[Any, Any]] = None,
         do_xcom_push: bool = False,
         idempotency_token: Optional[str] = None,
         access_control_list: Optional[List[Dict[str, str]]] = None,
@@ -287,6 +289,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
         self.polling_period_seconds = polling_period_seconds
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         if tasks is not None:
             self.json['tasks'] = tasks
@@ -327,6 +330,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
         )
 
     def execute(self, context: 'Context'):
@@ -484,6 +488,7 @@ class DatabricksRunNowOperator(BaseOperator):
         this run. By default the operator will poll every 30 seconds.
     :param databricks_retry_limit: Amount of times retry if the Databricks backend is
         unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     """
 
@@ -508,6 +513,7 @@ class DatabricksRunNowOperator(BaseOperator):
         polling_period_seconds: int = 30,
         databricks_retry_limit: int = 3,
         databricks_retry_delay: int = 1,
+        databricks_retry_args: Optional[Dict[Any, Any]] = None,
         do_xcom_push: bool = False,
         wait_for_termination: bool = True,
         **kwargs,
@@ -519,6 +525,7 @@ class DatabricksRunNowOperator(BaseOperator):
         self.polling_period_seconds = polling_period_seconds
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
 
         if job_id is not None:
@@ -546,6 +553,7 @@ class DatabricksRunNowOperator(BaseOperator):
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
         )
 
     def execute(self, context: 'Context'):
diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst
index e0c0a7b..62fb4fd 100644
--- a/docs/apache-airflow-providers-databricks/operators/run_now.rst
+++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst
@@ -61,5 +61,7 @@ Note that there is exactly one named parameter for each top level parameter in t
      - amount of times retry if the Databricks backend is unreachable
    * - databricks_retry_delay: decimal
      - number of seconds to wait between retries
+   * - databricks_retry_args: dict
+     - An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
    * - do_xcom_push: boolean
      - whether we should push run_id and run_page_url to xcom
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index 579a305..497f72f 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -70,6 +70,8 @@ one named parameter for each top level parameter in the ``runs/submit`` endpoint
      - amount of times retry if the Databricks backend is unreachable
    * - databricks_retry_delay: decimal
      - number of seconds to wait between retries
+   * - databricks_retry_args: dict
+     - An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
    * - do_xcom_push: boolean
      - whether we should push run_id and run_page_url to xcom
 
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 0948641..d1adba0 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -24,6 +24,7 @@ import unittest
 from unittest import mock
 
 import pytest
+import tenacity
 from requests import exceptions as requests_exceptions
 from requests.auth import HTTPBasicAuth
 
@@ -50,6 +51,11 @@ CLUSTER_ID = 'cluster_id'
 RUN_ID = 1
 JOB_ID = 42
 JOB_NAME = 'job-name'
+DEFAULT_RETRY_NUMBER = 3
+DEFAULT_RETRY_ARGS = dict(
+    wait=tenacity.wait_none(),
+    stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER),
+)
 HOST = 'xx.cloud.databricks.com'
 HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
 LOGIN = 'login'
@@ -228,6 +234,8 @@ class TestDatabricksHook(unittest.TestCase):
             DatabricksHook(retry_limit=0)
 
     def test_do_api_call_retries_with_retryable_error(self):
+        hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
         for exception in [
             requests_exceptions.ConnectionError,
             requests_exceptions.SSLError,
@@ -236,25 +244,41 @@ class TestDatabricksHook(unittest.TestCase):
             requests_exceptions.HTTPError,
         ]:
             with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
-                with mock.patch.object(self.hook.log, 'error') as mock_errors:
+                with mock.patch.object(hook.log, 'error') as mock_errors:
                     setup_mock_requests(mock_requests, exception)
 
                     with pytest.raises(AirflowException):
-                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+                        hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+                    assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
+
+    def test_do_api_call_retries_with_too_many_requests(self):
+        hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+        with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
+            with mock.patch.object(hook.log, 'error') as mock_errors:
+                setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=429)
+
+                with pytest.raises(AirflowException):
+                    hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
-                    assert mock_errors.call_count == self.hook.retry_limit
+                assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
 
     @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
     def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
+        hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
         setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400)
 
-        with mock.patch.object(self.hook.log, 'error') as mock_errors:
+        with mock.patch.object(hook.log, 'error') as mock_errors:
             with pytest.raises(AirflowException):
-                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+                hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
             mock_errors.assert_not_called()
 
     def test_do_api_call_succeeds_after_retrying(self):
+        hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
         for exception in [
             requests_exceptions.ConnectionError,
             requests_exceptions.SSLError,
@@ -263,20 +287,18 @@ class TestDatabricksHook(unittest.TestCase):
             requests_exceptions.HTTPError,
         ]:
             with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
-                with mock.patch.object(self.hook.log, 'error') as mock_errors:
+                with mock.patch.object(hook.log, 'error') as mock_errors:
                     setup_mock_requests(
                         mock_requests, exception, error_count=2, response_content={'run_id': '1'}
                     )
 
-                    response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+                    response = hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
                     assert mock_errors.call_count == 2
                     assert response == {'run_id': '1'}
 
-    @mock.patch('airflow.providers.databricks.hooks.databricks_base.sleep')
-    def test_do_api_call_waits_between_retries(self, mock_sleep):
-        retry_delay = 5
-        self.hook = DatabricksHook(retry_delay=retry_delay)
+    def test_do_api_call_custom_retry(self):
+        hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
 
         for exception in [
             requests_exceptions.ConnectionError,
@@ -286,16 +308,13 @@ class TestDatabricksHook(unittest.TestCase):
             requests_exceptions.HTTPError,
         ]:
             with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
-                with mock.patch.object(self.hook.log, 'error'):
-                    mock_sleep.reset_mock()
+                with mock.patch.object(hook.log, 'error') as mock_errors:
                     setup_mock_requests(mock_requests, exception)
 
                     with pytest.raises(AirflowException):
-                        self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+                        hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
-                    assert len(mock_sleep.mock_calls) == self.hook.retry_limit - 1
-                    calls = [mock.call(retry_delay), mock.call(retry_delay)]
-                    mock_sleep.assert_has_calls(calls)
+                    assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
 
     @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
     def test_do_api_call_patch(self, mock_requests):
@@ -796,7 +815,7 @@ class TestDatabricksHookAadToken(unittest.TestCase):
             }
         )
         session.commit()
-        self.hook = DatabricksHook()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
 
     @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
     def test_submit_run(self, mock_requests):
@@ -838,7 +857,7 @@ class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase):
             }
         )
         session.commit()
-        self.hook = DatabricksHook()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
 
     @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
     def test_submit_run(self, mock_requests):
@@ -883,7 +902,7 @@ class TestDatabricksHookAadTokenSpOutside(unittest.TestCase):
             }
         )
         session.commit()
-        self.hook = DatabricksHook()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
 
     @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
     def test_submit_run(self, mock_requests):
@@ -931,7 +950,7 @@ class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
             }
         )
         session.commit()
-        self.hook = DatabricksHook()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
 
     @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
     def test_submit_run(self, mock_requests):
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 0e93138..0d1bd09 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -224,7 +224,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
 
         db_mock.submit_run.assert_called_once_with(expected)
@@ -257,7 +260,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
             }
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
@@ -297,7 +303,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
 
         db_mock.submit_run.assert_called_once_with(expected)
@@ -322,7 +331,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
             {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
 
         db_mock.submit_run.assert_called_once_with(expected)
@@ -445,7 +457,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
         )
 
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
@@ -475,7 +490,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
             }
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
@@ -513,7 +531,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
             }
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
 
         db_mock.run_now.assert_called_once_with(expected)
@@ -540,7 +561,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
             }
         )
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
 
         db_mock.run_now.assert_called_once_with(expected)
@@ -582,7 +606,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
         )
 
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
         )
         db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
         db_mock.run_now.assert_called_once_with(expected)

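For reference, a hypothetical operator instantiation using the new ``databricks_retry_args``
parameter (it would normally live inside a DAG definition; the task id, job_id and the
tenacity settings below are made-up illustrations, not taken from this commit):

    import tenacity

    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    # Illustrative only: retry Databricks API calls up to 5 times with
    # exponential backoff; HTTP 429 responses are now retried as well.
    notebook_run = DatabricksRunNowOperator(
        task_id='notebook_run',
        job_id=42,
        databricks_retry_args=dict(
            wait=tenacity.wait_exponential(multiplier=2, max=60),
            stop=tenacity.stop_after_attempt(5),
        ),
    )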