This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 12e9e2c Databricks hook - retry on HTTP Status 429 as well (#21852)
12e9e2c is described below
commit 12e9e2c695f9ebb9d3dde9c0f7dfaa112654f0d6
Author: Alex Ott <[email protected]>
AuthorDate: Mon Mar 14 00:19:01 2022 +0100
Databricks hook - retry on HTTP Status 429 as well (#21852)
* Databricks hook - retry on HTTP Status 429 as well
This fixes #21559
* Reimplement retries using tenacity
It now uses exponential backoff by default.
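For callers who want to tune this behaviour, the new ``retry_args`` parameter
accepts keyword arguments for ``tenacity.Retrying``. A minimal sketch of how it
could be used (argument values are illustrative only; note that per this change
the hook always overrides the ``retry`` and ``after`` keys with its own
internal callbacks):

    import tenacity
    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    # Illustrative values: retry API calls up to 5 times, backing off
    # exponentially with a 60-second cap between attempts.
    hook = DatabricksHook(
        retry_args=dict(
            stop=tenacity.stop_after_attempt(5),
            wait=tenacity.wait_exponential(multiplier=2, max=60),
        ),
    )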
---
airflow/providers/databricks/hooks/databricks.py | 4 +-
.../providers/databricks/hooks/databricks_base.py | 196 +++++++++++----------
.../providers/databricks/operators/databricks.py | 8 +
.../operators/run_now.rst | 2 +
.../operators/submit_run.rst | 2 +
.../providers/databricks/hooks/test_databricks.py | 61 ++++---
.../databricks/operators/test_databricks.py | 45 ++++-
7 files changed, 195 insertions(+), 123 deletions(-)
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 20255e3..cdab4d0 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -100,6 +100,7 @@ class DatabricksHook(BaseDatabricksHook):
service outages.
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
+ :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""
hook_name = 'Databricks'
@@ -110,8 +111,9 @@ class DatabricksHook(BaseDatabricksHook):
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
+ retry_args: Optional[Dict[Any, Any]] = None,
) -> None:
- super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay)
+ super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args)
def run_now(self, json: dict) -> int:
"""
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 1a0d639..f545f26 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -22,15 +22,16 @@ This hook enable the submitting and running of jobs to the Databricks platform.
operators talk to the ``api/2.0/jobs/runs/submit``
`endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
"""
+import copy
import sys
import time
-from time import sleep
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse
import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
+from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
from airflow import __version__
from airflow.exceptions import AirflowException
@@ -68,6 +69,7 @@ class BaseDatabricksHook(BaseHook):
service outages.
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
+ :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""
conn_name_attr = 'databricks_conn_id'
@@ -89,17 +91,33 @@ class BaseDatabricksHook(BaseHook):
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
+ retry_args: Optional[Dict[Any, Any]] = None,
) -> None:
super().__init__()
self.databricks_conn_id = databricks_conn_id
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
- raise ValueError('Retry limit must be greater than equal to 1')
+ raise ValueError('Retry limit must be greater than or equal to 1')
self.retry_limit = retry_limit
self.retry_delay = retry_delay
self.aad_tokens: Dict[str, dict] = {}
self.aad_timeout_seconds = 10
+ def my_after_func(retry_state):
+ self._log_request_error(retry_state.attempt_number, retry_state.outcome)
+
+ if retry_args:
+ self.retry_args = copy.copy(retry_args)
+ self.retry_args['retry'] = retry_if_exception(self._retryable_error)
+ self.retry_args['after'] = my_after_func
+ else:
+ self.retry_args = dict(
+ stop=stop_after_attempt(self.retry_limit),
+ wait=wait_exponential(min=self.retry_delay, max=(2 ** retry_limit)),
+ retry=retry_if_exception(self._retryable_error),
+ after=my_after_func,
+ )
+
@cached_property
def databricks_conn(self) -> Connection:
return self.get_connection(self.databricks_conn_id)
@@ -143,6 +161,13 @@ class BaseDatabricksHook(BaseHook):
# In this case, host = xx.cloud.databricks.com
return host
+ def _get_retry_object(self) -> Retrying:
+ """
+ Instantiates a retry object
+ :return: instance of Retrying class
+ """
+ return Retrying(**self.retry_args)
+
def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource. Supports managed identity or service principal auth
@@ -154,60 +179,59 @@ class BaseDatabricksHook(BaseHook):
return aad_token['token']
self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
- attempt_num = 1
- while True:
- try:
- if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
- params = {
- "api-version": "2018-02-01",
- "resource": resource,
- }
- resp = requests.get(
- AZURE_METADATA_SERVICE_TOKEN_URL,
- params=params,
- headers={**USER_AGENT_HEADER, "Metadata": "true"},
- timeout=self.aad_timeout_seconds,
- )
- else:
- tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
- data = {
- "grant_type": "client_credentials",
- "client_id": self.databricks_conn.login,
- "resource": resource,
- "client_secret": self.databricks_conn.password,
- }
- azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
- "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
- )
- resp = requests.post(
- AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
- data=data,
- headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'},
- timeout=self.aad_timeout_seconds,
- )
-
- resp.raise_for_status()
- jsn = resp.json()
- if 'access_token' not in jsn or jsn.get('token_type') != 'Bearer' or 'expires_on' not in jsn:
- raise AirflowException(f"Can't get necessary data from AAD
token: {jsn}")
-
- token = jsn['access_token']
- self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
-
- return token
- except requests_exceptions.RequestException as e:
- if not self._retryable_error(e):
- raise AirflowException(
- f'Response: {e.response.content}, Status Code: {e.response.status_code}'
- )
-
- self._log_request_error(attempt_num, e.strerror)
-
- if attempt_num == self.retry_limit:
- raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
-
- attempt_num += 1
- sleep(self.retry_delay)
+ try:
+ for attempt in self._get_retry_object():
+ with attempt:
+ if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+ params = {
+ "api-version": "2018-02-01",
+ "resource": resource,
+ }
+ resp = requests.get(
+ AZURE_METADATA_SERVICE_TOKEN_URL,
+ params=params,
+ headers={**USER_AGENT_HEADER, "Metadata": "true"},
+ timeout=self.aad_timeout_seconds,
+ )
+ else:
+ tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
+ data = {
+ "grant_type": "client_credentials",
+ "client_id": self.databricks_conn.login,
+ "resource": resource,
+ "client_secret": self.databricks_conn.password,
+ }
+ azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
+ "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
+ )
+ resp = requests.post(
+ AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
+ data=data,
+ headers={
+ **USER_AGENT_HEADER,
+ 'Content-Type': 'application/x-www-form-urlencoded',
+ },
+ timeout=self.aad_timeout_seconds,
+ )
+
+ resp.raise_for_status()
+ jsn = resp.json()
+ if (
+ 'access_token' not in jsn
+ or jsn.get('token_type') != 'Bearer'
+ or 'expires_on' not in jsn
+ ):
+ raise AirflowException(f"Can't get necessary data from
AAD token: {jsn}")
+
+ token = jsn['access_token']
+ self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
+ break
+ except RetryError:
+ raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
+ except requests_exceptions.HTTPError as e:
+ raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')
+
+ return token
def _get_aad_headers(self) -> dict:
"""
@@ -279,14 +303,6 @@ class BaseDatabricksHook(BaseHook):
return None
- @staticmethod
- def _retryable_error(exception) -> bool:
- return (
- isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout))
- or exception.response is not None
- and exception.response.status_code >= 500
- )
-
def _log_request_error(self, attempt_num: int, error: str) -> None:
self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)
@@ -327,36 +343,32 @@ class BaseDatabricksHook(BaseHook):
else:
raise AirflowException('Unexpected HTTP Method: ' + method)
- attempt_num = 1
- while True:
- try:
- response = request_func(
- url,
- json=json if method in ('POST', 'PATCH') else None,
- params=json if method == 'GET' else None,
- auth=auth,
- headers=headers,
- timeout=self.timeout_seconds,
- )
- response.raise_for_status()
- return response.json()
- except requests_exceptions.RequestException as e:
- if not self._retryable_error(e):
- # In this case, the user probably made a mistake.
- # Don't retry.
- raise AirflowException(
- f'Response: {e.response.content}, Status Code: {e.response.status_code}'
+ try:
+ for attempt in self._get_retry_object():
+ with attempt:
+ response = request_func(
+ url,
+ json=json if method in ('POST', 'PATCH') else None,
+ params=json if method == 'GET' else None,
+ auth=auth,
+ headers=headers,
+ timeout=self.timeout_seconds,
)
+ response.raise_for_status()
+ return response.json()
+ except RetryError:
+ raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
+ except requests_exceptions.HTTPError as e:
+ raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')
- self._log_request_error(attempt_num, str(e))
-
- if attempt_num == self.retry_limit:
- raise AirflowException(
- f'API requests to Databricks failed {self.retry_limit} times. Giving up.'
- )
-
- attempt_num += 1
- sleep(self.retry_delay)
+ @staticmethod
+ def _retryable_error(exception: BaseException) -> bool:
+ if not isinstance(exception, requests_exceptions.RequestException):
+ return False
+ return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
+ exception.response is not None
+ and (exception.response.status_code >= 500 or exception.response.status_code == 429)
+ )
class _TokenAuth(AuthBase):
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 818fb1d..ec0a9a0 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -245,6 +245,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
"""
@@ -274,6 +275,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
+ databricks_retry_args: Optional[Dict[Any, Any]] = None,
do_xcom_push: bool = False,
idempotency_token: Optional[str] = None,
access_control_list: Optional[List[Dict[str, str]]] = None,
@@ -287,6 +289,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
+ self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
if tasks is not None:
self.json['tasks'] = tasks
@@ -327,6 +330,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
+ retry_args=self.databricks_retry_args,
)
def execute(self, context: 'Context'):
@@ -484,6 +488,7 @@ class DatabricksRunNowOperator(BaseOperator):
this run. By default the operator will poll every 30 seconds.
:param databricks_retry_limit: Amount of times retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
"""
@@ -508,6 +513,7 @@ class DatabricksRunNowOperator(BaseOperator):
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
+ databricks_retry_args: Optional[Dict[Any, Any]] = None,
do_xcom_push: bool = False,
wait_for_termination: bool = True,
**kwargs,
@@ -519,6 +525,7 @@ class DatabricksRunNowOperator(BaseOperator):
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
+ self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
if job_id is not None:
@@ -546,6 +553,7 @@ class DatabricksRunNowOperator(BaseOperator):
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
+ retry_args=self.databricks_retry_args,
)
def execute(self, context: 'Context'):
diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst
index e0c0a7b..62fb4fd 100644
--- a/docs/apache-airflow-providers-databricks/operators/run_now.rst
+++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst
@@ -61,5 +61,7 @@ Note that there is exactly one named parameter for each top level parameter in t
- amount of times retry if the Databricks backend is unreachable
* - databricks_retry_delay: decimal
- number of seconds to wait between retries
+ * - databricks_retry_args: dict
+ - An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
* - do_xcom_push: boolean
- whether we should push run_id and run_page_url to xcom
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index 579a305..497f72f 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -70,6 +70,8 @@ one named parameter for each top level parameter in the ``runs/submit`` endpoint
- amount of times retry if the Databricks backend is unreachable
* - databricks_retry_delay: decimal
- number of seconds to wait between retries
+ * - databricks_retry_args: dict
+ - An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
* - do_xcom_push: boolean
- whether we should push run_id and run_page_url to xcom
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 0948641..d1adba0 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -24,6 +24,7 @@ import unittest
from unittest import mock
import pytest
+import tenacity
from requests import exceptions as requests_exceptions
from requests.auth import HTTPBasicAuth
@@ -50,6 +51,11 @@ CLUSTER_ID = 'cluster_id'
RUN_ID = 1
JOB_ID = 42
JOB_NAME = 'job-name'
+DEFAULT_RETRY_NUMBER = 3
+DEFAULT_RETRY_ARGS = dict(
+ wait=tenacity.wait_none(),
+ stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER),
+)
HOST = 'xx.cloud.databricks.com'
HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
LOGIN = 'login'
@@ -228,6 +234,8 @@ class TestDatabricksHook(unittest.TestCase):
DatabricksHook(retry_limit=0)
def test_do_api_call_retries_with_retryable_error(self):
+ hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
@@ -236,25 +244,41 @@ class TestDatabricksHook(unittest.TestCase):
requests_exceptions.HTTPError,
]:
with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
- with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ with mock.patch.object(hook.log, 'error') as mock_errors:
setup_mock_requests(mock_requests, exception)
with pytest.raises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+ hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
+
+ def test_do_api_call_retries_with_too_many_requests(self):
+ hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
+ with mock.patch.object(hook.log, 'error') as mock_errors:
+ setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=429)
+
+ with pytest.raises(AirflowException):
+ hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
- assert mock_errors.call_count == self.hook.retry_limit
+ assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
@mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
+ hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400)
- with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ with mock.patch.object(hook.log, 'error') as mock_errors:
with pytest.raises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+ hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
mock_errors.assert_not_called()
def test_do_api_call_succeeds_after_retrying(self):
+ hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
@@ -263,20 +287,18 @@ class TestDatabricksHook(unittest.TestCase):
requests_exceptions.HTTPError,
]:
with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
- with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ with mock.patch.object(hook.log, 'error') as mock_errors:
setup_mock_requests(
mock_requests, exception, error_count=2, response_content={'run_id': '1'}
)
- response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+ response = hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
assert mock_errors.call_count == 2
assert response == {'run_id': '1'}
- @mock.patch('airflow.providers.databricks.hooks.databricks_base.sleep')
- def test_do_api_call_waits_between_retries(self, mock_sleep):
- retry_delay = 5
- self.hook = DatabricksHook(retry_delay=retry_delay)
+ def test_do_api_call_custom_retry(self):
+ hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
for exception in [
requests_exceptions.ConnectionError,
@@ -286,16 +308,13 @@ class TestDatabricksHook(unittest.TestCase):
requests_exceptions.HTTPError,
]:
with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests:
- with mock.patch.object(self.hook.log, 'error'):
- mock_sleep.reset_mock()
+ with mock.patch.object(hook.log, 'error') as mock_errors:
setup_mock_requests(mock_requests, exception)
with pytest.raises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+ hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
- assert len(mock_sleep.mock_calls) == self.hook.retry_limit - 1
- calls = [mock.call(retry_delay), mock.call(retry_delay)]
- mock_sleep.assert_has_calls(calls)
+ assert mock_errors.call_count == DEFAULT_RETRY_NUMBER
@mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
def test_do_api_call_patch(self, mock_requests):
@@ -796,7 +815,7 @@ class TestDatabricksHookAadToken(unittest.TestCase):
}
)
session.commit()
- self.hook = DatabricksHook()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
@mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
def test_submit_run(self, mock_requests):
@@ -838,7 +857,7 @@ class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase):
}
)
session.commit()
- self.hook = DatabricksHook()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
@mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
def test_submit_run(self, mock_requests):
@@ -883,7 +902,7 @@ class TestDatabricksHookAadTokenSpOutside(unittest.TestCase):
}
)
session.commit()
- self.hook = DatabricksHook()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
@mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
def test_submit_run(self, mock_requests):
@@ -931,7 +950,7 @@ class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
}
)
session.commit()
- self.hook = DatabricksHook()
+ self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
@mock.patch('airflow.providers.databricks.hooks.databricks_base.requests')
def test_submit_run(self, mock_requests):
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 0e93138..0d1bd09 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -224,7 +224,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.submit_run.assert_called_once_with(expected)
@@ -257,7 +260,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
@@ -297,7 +303,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.submit_run.assert_called_once_with(expected)
@@ -322,7 +331,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase):
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.submit_run.assert_called_once_with(expected)
@@ -445,7 +457,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
@@ -475,7 +490,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
@@ -513,7 +531,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.run_now.assert_called_once_with(expected)
@@ -540,7 +561,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
}
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.run_now.assert_called_once_with(expected)
@@ -582,7 +606,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
)
db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
)
db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
db_mock.run_now.assert_called_once_with(expected)
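The operators expose the same knob as ``databricks_retry_args``, which is passed
straight through to the hook. A minimal sketch of DAG-level usage (``task_id``
and ``job_id`` values are illustrative only):

    import tenacity
    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    # Retry Databricks API calls up to 5 times with no wait between attempts,
    # mirroring DEFAULT_RETRY_ARGS from the tests above.
    notebook_run = DatabricksRunNowOperator(
        task_id='notebook_run',
        job_id=42,
        databricks_retry_args=dict(
            wait=tenacity.wait_none(),
            stop=tenacity.stop_after_attempt(5),
        ),
    )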