This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 12e9e2c Databricks hook - retry on HTTP Status 429 as well (#21852) 12e9e2c is described below commit 12e9e2c695f9ebb9d3dde9c0f7dfaa112654f0d6 Author: Alex Ott <alex...@gmail.com> AuthorDate: Mon Mar 14 00:19:01 2022 +0100 Databricks hook - retry on HTTP Status 429 as well (#21852) * Databricks hook - retry on HTTP Status 429 as well this fixes #21559 * Reimplement retries using tenacity it's now uses exponential backoff by default --- airflow/providers/databricks/hooks/databricks.py | 4 +- .../providers/databricks/hooks/databricks_base.py | 196 +++++++++++---------- .../providers/databricks/operators/databricks.py | 8 + .../operators/run_now.rst | 2 + .../operators/submit_run.rst | 2 + .../providers/databricks/hooks/test_databricks.py | 61 ++++--- .../databricks/operators/test_databricks.py | 45 ++++- 7 files changed, 195 insertions(+), 123 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 20255e3..cdab4d0 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -100,6 +100,7 @@ class DatabricksHook(BaseDatabricksHook): service outages. :param retry_delay: The number of seconds to wait between retries (it might be a floating point number). + :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. """ hook_name = 'Databricks' @@ -110,8 +111,9 @@ class DatabricksHook(BaseDatabricksHook): timeout_seconds: int = 180, retry_limit: int = 3, retry_delay: float = 1.0, + retry_args: Optional[Dict[Any, Any]] = None, ) -> None: - super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay) + super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args) def run_now(self, json: dict) -> int: """ diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py index 1a0d639..f545f26 100644 --- a/airflow/providers/databricks/hooks/databricks_base.py +++ b/airflow/providers/databricks/hooks/databricks_base.py @@ -22,15 +22,16 @@ This hook enable the submitting and running of jobs to the Databricks platform. operators talk to the ``api/2.0/jobs/runs/submit`` `endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_. """ +import copy import sys import time -from time import sleep from typing import Any, Dict, Optional, Tuple from urllib.parse import urlparse import requests from requests import PreparedRequest, exceptions as requests_exceptions from requests.auth import AuthBase, HTTPBasicAuth +from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential from airflow import __version__ from airflow.exceptions import AirflowException @@ -68,6 +69,7 @@ class BaseDatabricksHook(BaseHook): service outages. :param retry_delay: The number of seconds to wait between retries (it might be a floating point number). + :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. """ conn_name_attr = 'databricks_conn_id' @@ -89,17 +91,33 @@ class BaseDatabricksHook(BaseHook): timeout_seconds: int = 180, retry_limit: int = 3, retry_delay: float = 1.0, + retry_args: Optional[Dict[Any, Any]] = None, ) -> None: super().__init__() self.databricks_conn_id = databricks_conn_id self.timeout_seconds = timeout_seconds if retry_limit < 1: - raise ValueError('Retry limit must be greater than equal to 1') + raise ValueError('Retry limit must be greater than or equal to 1') self.retry_limit = retry_limit self.retry_delay = retry_delay self.aad_tokens: Dict[str, dict] = {} self.aad_timeout_seconds = 10 + def my_after_func(retry_state): + self._log_request_error(retry_state.attempt_number, retry_state.outcome) + + if retry_args: + self.retry_args = copy.copy(retry_args) + self.retry_args['retry'] = retry_if_exception(self._retryable_error) + self.retry_args['after'] = my_after_func + else: + self.retry_args = dict( + stop=stop_after_attempt(self.retry_limit), + wait=wait_exponential(min=self.retry_delay, max=(2 ** retry_limit)), + retry=retry_if_exception(self._retryable_error), + after=my_after_func, + ) + @cached_property def databricks_conn(self) -> Connection: return self.get_connection(self.databricks_conn_id) @@ -143,6 +161,13 @@ class BaseDatabricksHook(BaseHook): # In this case, host = xx.cloud.databricks.com return host + def _get_retry_object(self) -> Retrying: + """ + Instantiates a retry object + :return: instance of Retrying class + """ + return Retrying(**self.retry_args) + def _get_aad_token(self, resource: str) -> str: """ Function to get AAD token for given resource. Supports managed identity or service principal auth @@ -154,60 +179,59 @@ class BaseDatabricksHook(BaseHook): return aad_token['token'] self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...') - attempt_num = 1 - while True: - try: - if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): - params = { - "api-version": "2018-02-01", - "resource": resource, - } - resp = requests.get( - AZURE_METADATA_SERVICE_TOKEN_URL, - params=params, - headers={**USER_AGENT_HEADER, "Metadata": "true"}, - timeout=self.aad_timeout_seconds, - ) - else: - tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id'] - data = { - "grant_type": "client_credentials", - "client_id": self.databricks_conn.login, - "resource": resource, - "client_secret": self.databricks_conn.password, - } - azure_ad_endpoint = self.databricks_conn.extra_dejson.get( - "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT - ) - resp = requests.post( - AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), - data=data, - headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'}, - timeout=self.aad_timeout_seconds, - ) - - resp.raise_for_status() - jsn = resp.json() - if 'access_token' not in jsn or jsn.get('token_type') != 'Bearer' or 'expires_on' not in jsn: - raise AirflowException(f"Can't get necessary data from AAD token: {jsn}") - - token = jsn['access_token'] - self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])} - - return token - except requests_exceptions.RequestException as e: - if not self._retryable_error(e): - raise AirflowException( - f'Response: {e.response.content}, Status Code: {e.response.status_code}' - ) - - self._log_request_error(attempt_num, e.strerror) - - if attempt_num == self.retry_limit: - raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.') - - attempt_num += 1 - sleep(self.retry_delay) + try: + for attempt in self._get_retry_object(): + with attempt: + if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): + params = { + "api-version": "2018-02-01", + "resource": resource, + } + resp = requests.get( + AZURE_METADATA_SERVICE_TOKEN_URL, + params=params, + headers={**USER_AGENT_HEADER, "Metadata": "true"}, + timeout=self.aad_timeout_seconds, + ) + else: + tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id'] + data = { + "grant_type": "client_credentials", + "client_id": self.databricks_conn.login, + "resource": resource, + "client_secret": self.databricks_conn.password, + } + azure_ad_endpoint = self.databricks_conn.extra_dejson.get( + "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT + ) + resp = requests.post( + AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), + data=data, + headers={ + **USER_AGENT_HEADER, + 'Content-Type': 'application/x-www-form-urlencoded', + }, + timeout=self.aad_timeout_seconds, + ) + + resp.raise_for_status() + jsn = resp.json() + if ( + 'access_token' not in jsn + or jsn.get('token_type') != 'Bearer' + or 'expires_on' not in jsn + ): + raise AirflowException(f"Can't get necessary data from AAD token: {jsn}") + + token = jsn['access_token'] + self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])} + break + except RetryError: + raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.') + except requests_exceptions.HTTPError as e: + raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}') + + return token def _get_aad_headers(self) -> dict: """ @@ -279,14 +303,6 @@ class BaseDatabricksHook(BaseHook): return None - @staticmethod - def _retryable_error(exception) -> bool: - return ( - isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) - or exception.response is not None - and exception.response.status_code >= 500 - ) - def _log_request_error(self, attempt_num: int, error: str) -> None: self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error) @@ -327,36 +343,32 @@ class BaseDatabricksHook(BaseHook): else: raise AirflowException('Unexpected HTTP Method: ' + method) - attempt_num = 1 - while True: - try: - response = request_func( - url, - json=json if method in ('POST', 'PATCH') else None, - params=json if method == 'GET' else None, - auth=auth, - headers=headers, - timeout=self.timeout_seconds, - ) - response.raise_for_status() - return response.json() - except requests_exceptions.RequestException as e: - if not self._retryable_error(e): - # In this case, the user probably made a mistake. - # Don't retry. - raise AirflowException( - f'Response: {e.response.content}, Status Code: {e.response.status_code}' + try: + for attempt in self._get_retry_object(): + with attempt: + response = request_func( + url, + json=json if method in ('POST', 'PATCH') else None, + params=json if method == 'GET' else None, + auth=auth, + headers=headers, + timeout=self.timeout_seconds, ) + response.raise_for_status() + return response.json() + except RetryError: + raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.') + except requests_exceptions.HTTPError as e: + raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}') - self._log_request_error(attempt_num, str(e)) - - if attempt_num == self.retry_limit: - raise AirflowException( - f'API requests to Databricks failed {self.retry_limit} times. Giving up.' - ) - - attempt_num += 1 - sleep(self.retry_delay) + @staticmethod + def _retryable_error(exception: BaseException) -> bool: + if not isinstance(exception, requests_exceptions.RequestException): + return False + return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or ( + exception.response is not None + and (exception.response.status_code >= 500 or exception.response.status_code == 429) + ) class _TokenAuth(AuthBase): diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 818fb1d..ec0a9a0 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -245,6 +245,7 @@ class DatabricksSubmitRunOperator(BaseOperator): unreachable. Its value must be greater than or equal to 1. :param databricks_retry_delay: Number of seconds to wait between retries (it might be a floating point number). + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. """ @@ -274,6 +275,7 @@ class DatabricksSubmitRunOperator(BaseOperator): polling_period_seconds: int = 30, databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, + databricks_retry_args: Optional[Dict[Any, Any]] = None, do_xcom_push: bool = False, idempotency_token: Optional[str] = None, access_control_list: Optional[List[Dict[str, str]]] = None, @@ -287,6 +289,7 @@ class DatabricksSubmitRunOperator(BaseOperator): self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination if tasks is not None: self.json['tasks'] = tasks @@ -327,6 +330,7 @@ class DatabricksSubmitRunOperator(BaseOperator): self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, ) def execute(self, context: 'Context'): @@ -484,6 +488,7 @@ class DatabricksRunNowOperator(BaseOperator): this run. By default the operator will poll every 30 seconds. :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. """ @@ -508,6 +513,7 @@ class DatabricksRunNowOperator(BaseOperator): polling_period_seconds: int = 30, databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, + databricks_retry_args: Optional[Dict[Any, Any]] = None, do_xcom_push: bool = False, wait_for_termination: bool = True, **kwargs, @@ -519,6 +525,7 @@ class DatabricksRunNowOperator(BaseOperator): self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination if job_id is not None: @@ -546,6 +553,7 @@ class DatabricksRunNowOperator(BaseOperator): self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, ) def execute(self, context: 'Context'): diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst index e0c0a7b..62fb4fd 100644 --- a/docs/apache-airflow-providers-databricks/operators/run_now.rst +++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst @@ -61,5 +61,7 @@ Note that there is exactly one named parameter for each top level parameter in t - amount of times retry if the Databricks backend is unreachable * - databricks_retry_delay: decimal - number of seconds to wait between retries + * - databricks_retry_args: dict + - An optional dictionary with arguments passed to ``tenacity.Retrying`` class. * - do_xcom_push: boolean - whether we should push run_id and run_page_url to xcom diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst index 579a305..497f72f 100644 --- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst +++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst @@ -70,6 +70,8 @@ one named parameter for each top level parameter in the ``runs/submit`` endpoint - amount of times retry if the Databricks backend is unreachable * - databricks_retry_delay: decimal - number of seconds to wait between retries + * - databricks_retry_args: dict + - An optional dictionary with arguments passed to ``tenacity.Retrying`` class. * - do_xcom_push: boolean - whether we should push run_id and run_page_url to xcom diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 0948641..d1adba0 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -24,6 +24,7 @@ import unittest from unittest import mock import pytest +import tenacity from requests import exceptions as requests_exceptions from requests.auth import HTTPBasicAuth @@ -50,6 +51,11 @@ CLUSTER_ID = 'cluster_id' RUN_ID = 1 JOB_ID = 42 JOB_NAME = 'job-name' +DEFAULT_RETRY_NUMBER = 3 +DEFAULT_RETRY_ARGS = dict( + wait=tenacity.wait_none(), + stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER), +) HOST = 'xx.cloud.databricks.com' HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com' LOGIN = 'login' @@ -228,6 +234,8 @@ class TestDatabricksHook(unittest.TestCase): DatabricksHook(retry_limit=0) def test_do_api_call_retries_with_retryable_error(self): + hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + for exception in [ requests_exceptions.ConnectionError, requests_exceptions.SSLError, @@ -236,25 +244,41 @@ class TestDatabricksHook(unittest.TestCase): requests_exceptions.HTTPError, ]: with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests: - with mock.patch.object(self.hook.log, 'error') as mock_errors: + with mock.patch.object(hook.log, 'error') as mock_errors: setup_mock_requests(mock_requests, exception) with pytest.raises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + assert mock_errors.call_count == DEFAULT_RETRY_NUMBER + + def test_do_api_call_retries_with_too_many_requests(self): + hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests: + with mock.patch.object(hook.log, 'error') as mock_errors: + setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=429) + + with pytest.raises(AirflowException): + hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - assert mock_errors.call_count == self.hook.retry_limit + assert mock_errors.call_count == DEFAULT_RETRY_NUMBER @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests): + hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400) - with mock.patch.object(self.hook.log, 'error') as mock_errors: + with mock.patch.object(hook.log, 'error') as mock_errors: with pytest.raises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) mock_errors.assert_not_called() def test_do_api_call_succeeds_after_retrying(self): + hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + for exception in [ requests_exceptions.ConnectionError, requests_exceptions.SSLError, @@ -263,20 +287,18 @@ class TestDatabricksHook(unittest.TestCase): requests_exceptions.HTTPError, ]: with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests: - with mock.patch.object(self.hook.log, 'error') as mock_errors: + with mock.patch.object(hook.log, 'error') as mock_errors: setup_mock_requests( mock_requests, exception, error_count=2, response_content={'run_id': '1'} ) - response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + response = hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) assert mock_errors.call_count == 2 assert response == {'run_id': '1'} - @mock.patch('airflow.providers.databricks.hooks.databricks_base.sleep') - def test_do_api_call_waits_between_retries(self, mock_sleep): - retry_delay = 5 - self.hook = DatabricksHook(retry_delay=retry_delay) + def test_do_api_call_custom_retry(self): + hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) for exception in [ requests_exceptions.ConnectionError, @@ -286,16 +308,13 @@ class TestDatabricksHook(unittest.TestCase): requests_exceptions.HTTPError, ]: with mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') as mock_requests: - with mock.patch.object(self.hook.log, 'error'): - mock_sleep.reset_mock() + with mock.patch.object(hook.log, 'error') as mock_errors: setup_mock_requests(mock_requests, exception) with pytest.raises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - assert len(mock_sleep.mock_calls) == self.hook.retry_limit - 1 - calls = [mock.call(retry_delay), mock.call(retry_delay)] - mock_sleep.assert_has_calls(calls) + assert mock_errors.call_count == DEFAULT_RETRY_NUMBER @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') def test_do_api_call_patch(self, mock_requests): @@ -796,7 +815,7 @@ class TestDatabricksHookAadToken(unittest.TestCase): } ) session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') def test_submit_run(self, mock_requests): @@ -838,7 +857,7 @@ class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase): } ) session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') def test_submit_run(self, mock_requests): @@ -883,7 +902,7 @@ class TestDatabricksHookAadTokenSpOutside(unittest.TestCase): } ) session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') def test_submit_run(self, mock_requests): @@ -931,7 +950,7 @@ class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase): } ) session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch('airflow.providers.databricks.hooks.databricks_base.requests') def test_submit_run(self, mock_requests): diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 0e93138..0d1bd09 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -224,7 +224,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase): {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.submit_run.assert_called_once_with(expected) @@ -257,7 +260,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase): } ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) @@ -297,7 +303,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase): {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.submit_run.assert_called_once_with(expected) @@ -322,7 +331,10 @@ class TestDatabricksSubmitRunOperator(unittest.TestCase): {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.submit_run.assert_called_once_with(expected) @@ -445,7 +457,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase): ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.run_now.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) @@ -475,7 +490,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase): } ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.run_now.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) @@ -513,7 +531,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase): } ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.run_now.assert_called_once_with(expected) @@ -540,7 +561,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase): } ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.run_now.assert_called_once_with(expected) @@ -582,7 +606,10 @@ class TestDatabricksRunNowOperator(unittest.TestCase): ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, ) db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME) db_mock.run_now.assert_called_once_with(expected)