This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 59084fd1f4 fix: add retry logic in case of google auth refresh credential error (#38961)
59084fd1f4 is described below
commit 59084fd1f4c200986433f9ff60b28cd6f8a0bcc1
Author: Sebastian Daum <[email protected]>
AuthorDate: Sun Apr 21 08:19:24 2024 +0200
fix: add retry logic in case of google auth refresh credential error (#38961)
---
airflow/providers/google/cloud/hooks/bigquery.py | 1 +
.../providers/google/common/hooks/base_google.py | 48 +++++++++++++++++++---
.../providers/google/cloud/hooks/test_bigquery.py | 32 +++++++++++++++
.../google/common/hooks/test_base_google.py | 44 +++++++++++++++++++-
4 files changed, 118 insertions(+), 7 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 7482be89fb..f8c0fd1a5d 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -1580,6 +1580,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
time.sleep(5)
@GoogleBaseHook.fallback_to_default_project_id
+ @GoogleBaseHook.refresh_credentials_retry()
def get_job(
self,
job_id: str,
diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index 0fe5d16aae..5800f8e44c 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -114,6 +114,19 @@ def is_operation_in_progress_exception(exception: Exception) -> bool:
return False
+def is_refresh_credentials_exception(exception: Exception) -> bool:
+ """
+ Handle refresh credentials exceptions.
+
+ Some calls return 502 (server error) in case a new token cannot be obtained.
+
+ * Google BigQuery
+ """
+ if isinstance(exception, RefreshError):
+ return "Unable to acquire impersonated credentials" in str(exception)
+ return False
+
+
class retry_if_temporary_quota(tenacity.retry_if_exception):
"""Retries if there was an exception for exceeding the temporary quote
limit."""
@@ -122,12 +135,19 @@ class retry_if_temporary_quota(tenacity.retry_if_exception):
class retry_if_operation_in_progress(tenacity.retry_if_exception):
- """Retries if there was an exception for exceeding the temporary quote limit."""
+ """Retries if there was an exception in case of operation in progress."""
def __init__(self):
super().__init__(is_operation_in_progress_exception)
+class retry_if_temporary_refresh_credentials(tenacity.retry_if_exception):
+ """Retries if there was an exception for refreshing credentials."""
+
+ def __init__(self):
+ super().__init__(is_refresh_credentials_exception)
+
+
# A fake project_id to use in functions decorated by fallback_to_default_project_id
# This allows the 'project_id' argument to be of type str instead of str | None,
# making it easier to type hint the function body without dealing with the None
@@ -426,7 +446,7 @@ class GoogleBaseHook(BaseHook):
def quota_retry(*args, **kwargs) -> Callable:
"""Provide a mechanism to repeat requests in response to exceeding a
temporary quota limit."""
- def decorator(fun: Callable):
+ def decorator(func: Callable):
default_kwargs = {
"wait": tenacity.wait_exponential(multiplier=1, max=100),
"retry": retry_if_temporary_quota(),
@@ -434,7 +454,7 @@ class GoogleBaseHook(BaseHook):
"after": tenacity.after_log(log, logging.DEBUG),
}
default_kwargs.update(**kwargs)
- return tenacity.retry(*args, **default_kwargs)(fun)
+ return tenacity.retry(*args, **default_kwargs)(func)
return decorator
@@ -442,7 +462,7 @@ class GoogleBaseHook(BaseHook):
def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]:
"""Provide a mechanism to repeat requests in response to operation in
progress (HTTP 409) limit."""
- def decorator(fun: T):
+ def decorator(func: T):
default_kwargs = {
"wait": tenacity.wait_exponential(multiplier=1, max=300),
"retry": retry_if_operation_in_progress(),
@@ -450,7 +470,25 @@ class GoogleBaseHook(BaseHook):
"after": tenacity.after_log(log, logging.DEBUG),
}
default_kwargs.update(**kwargs)
- return cast(T, tenacity.retry(*args, **default_kwargs)(fun))
+ return cast(T, tenacity.retry(*args, **default_kwargs)(func))
+
+ return decorator
+
+ @staticmethod
+ def refresh_credentials_retry(*args, **kwargs) -> Callable[[T], T]:
+ """Provide a mechanism to repeat requests in response to a temporary refresh credential issue."""
+
+ def decorator(func: T):
+ default_kwargs = {
+ "wait": tenacity.wait_exponential(multiplier=1, max=5),
+ "stop": tenacity.stop_after_attempt(3),
+ "retry": retry_if_temporary_refresh_credentials(),
+ "reraise": True,
+ "before": tenacity.before_log(log, logging.DEBUG),
+ "after": tenacity.after_log(log, logging.DEBUG),
+ }
+ default_kwargs.update(**kwargs)
+ return cast(T, tenacity.retry(*args, **default_kwargs)(func))
return decorator
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 37096b0ff3..c63dc581a9 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -26,6 +26,7 @@ import google.auth
import pytest
from gcloud.aio.bigquery import Job, Table as Table_async
from google.api_core import page_iterator
+from google.auth.exceptions import RefreshError
from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
from google.cloud.bigquery.table import _EmptyRowIterator
@@ -598,6 +599,37 @@ class TestBigQueryHookMethods(_BigQueryBaseTestClass):
mock_client.return_value.get_job.assert_called_once_with(job_id=JOB_ID)
mock_client.return_value.get_job.return_value.done.assert_called_once_with(retry=DEFAULT_RETRY)
+ @mock.patch("tenacity.nap.time.sleep", mock.MagicMock())
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
+ def test_get_job_credentials_refresh_error(self, mock_client):
+ error = "Unable to acquire impersonated credentials"
+ response_body = "<!DOCTYPE html>\n<html lang=en>\n <meta charset=utf-8>\n"
+ mock_job = mock.MagicMock(
+ job_id="123456_hash",
+ error_result=False,
+ state="PENDING",
+ done=lambda: False,
+ )
+ mock_client.return_value.get_job.side_effect = [RefreshError(error, response_body), mock_job]
+
+ job = self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID)
+ mock_client.assert_any_call(location=LOCATION, project_id=PROJECT_ID)
+ assert mock_client.call_count == 2
+ assert job == mock_job
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ RefreshError("Other error", "test body"),
+ ValueError(),
+ ],
+ )
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
+ def test_get_job_credentials_error(self, mock_client, error):
+ mock_client.return_value.get_job.side_effect = error
+ with pytest.raises(type(error)):
+ self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID)
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete")
@mock.patch("logging.Logger.info")
def test_cancel_query_jobs_to_cancel(
diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py
index ab2e26f59b..1a91f0742d 100644
--- a/tests/providers/google/common/hooks/test_base_google.py
+++ b/tests/providers/google/common/hooks/test_base_google.py
@@ -30,14 +30,14 @@ import google.auth.compute_engine
import pytest
import tenacity
from google.auth.environment_vars import CREDENTIALS
-from google.auth.exceptions import GoogleAuthError
+from google.auth.exceptions import GoogleAuthError, RefreshError
from google.cloud.exceptions import Forbidden
from airflow import version
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.utils.credentials_provider import _DEFAULT_SCOPES
from airflow.providers.google.common.hooks import base_google as hook
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, is_refresh_credentials_exception
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
default_creds_available = True
@@ -98,6 +98,46 @@ class TestQuotaRetry:
)
+class TestRefreshCredentialsRetry:
+ @pytest.mark.parametrize(
+ "exc, retryable",
+ [
+ (RefreshError("Other error", "test body"), False),
+ (RefreshError("Unable to acquire impersonated credentials", "test body"), True),
+ (ValueError(), False),
+ ],
+ )
+ def test_is_refresh_credentials_exception(self, exc, retryable):
+ assert is_refresh_credentials_exception(exc) is retryable
+
+ def test_do_nothing_on_non_error(self):
+ @hook.GoogleBaseHook.refresh_credentials_retry()
+ def func():
+ return 42
+
+ assert func() == 42
+
+ def test_raise_non_refresh_error(self):
+ @hook.GoogleBaseHook.refresh_credentials_retry()
+ def func():
+ raise ValueError()
+
+ with pytest.raises(ValueError):
+ func()
+
+ @mock.patch("tenacity.nap.time.sleep", mock.MagicMock())
+ def test_retry_on_refresh_error(self):
+ func_return = mock.Mock(
+ side_effect=[RefreshError("Unable to acquire impersonated credentials", "test body"), 42]
+ )
+
+ @hook.GoogleBaseHook.refresh_credentials_retry()
+ def func():
+ return func_return()
+
+ assert func() == 42
+
+
class FallbackToDefaultProjectIdFixtureClass:
def __init__(self, project_id):
self.mock = mock.Mock()