This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 59084fd1f4 fix: add retry logic in case of google auth refresh credential error (#38961)
59084fd1f4 is described below

commit 59084fd1f4c200986433f9ff60b28cd6f8a0bcc1
Author: Sebastian Daum <[email protected]>
AuthorDate: Sun Apr 21 08:19:24 2024 +0200

    fix: add retry logic in case of google auth refresh credential error (#38961)
---
 airflow/providers/google/cloud/hooks/bigquery.py   |  1 +
 .../providers/google/common/hooks/base_google.py   | 48 +++++++++++++++++++---
 .../providers/google/cloud/hooks/test_bigquery.py  | 32 +++++++++++++++
 .../google/common/hooks/test_base_google.py        | 44 +++++++++++++++++++-
 4 files changed, 118 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 7482be89fb..f8c0fd1a5d 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -1580,6 +1580,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
                 time.sleep(5)
 
     @GoogleBaseHook.fallback_to_default_project_id
+    @GoogleBaseHook.refresh_credentials_retry()
     def get_job(
         self,
         job_id: str,
diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index 0fe5d16aae..5800f8e44c 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -114,6 +114,19 @@ def is_operation_in_progress_exception(exception: Exception) -> bool:
     return False
 
 
+def is_refresh_credentials_exception(exception: Exception) -> bool:
+    """
+    Handle refresh credentials exceptions.
+
+    Some calls return 502 (server error) in case a new token cannot be obtained.
+
+    * Google BigQuery
+    """
+    if isinstance(exception, RefreshError):
+        return "Unable to acquire impersonated credentials" in str(exception)
+    return False
+
+
 class retry_if_temporary_quota(tenacity.retry_if_exception):
     """Retries if there was an exception for exceeding the temporary quote limit."""
 
@@ -122,12 +135,19 @@ class retry_if_temporary_quota(tenacity.retry_if_exception):
 
 
 class retry_if_operation_in_progress(tenacity.retry_if_exception):
-    """Retries if there was an exception for exceeding the temporary quote limit."""
+    """Retries if there was an exception in case of operation in progress."""
 
     def __init__(self):
         super().__init__(is_operation_in_progress_exception)
 
 
+class retry_if_temporary_refresh_credentials(tenacity.retry_if_exception):
+    """Retries if there was an exception for refreshing credentials."""
+
+    def __init__(self):
+        super().__init__(is_refresh_credentials_exception)
+
+
 # A fake project_id to use in functions decorated by fallback_to_default_project_id
 # This allows the 'project_id' argument to be of type str instead of str | None,
 # making it easier to type hint the function body without dealing with the None
@@ -426,7 +446,7 @@ class GoogleBaseHook(BaseHook):
     def quota_retry(*args, **kwargs) -> Callable:
         """Provide a mechanism to repeat requests in response to exceeding a temporary quota limit."""
 
-        def decorator(fun: Callable):
+        def decorator(func: Callable):
             default_kwargs = {
                 "wait": tenacity.wait_exponential(multiplier=1, max=100),
                 "retry": retry_if_temporary_quota(),
@@ -434,7 +454,7 @@ class GoogleBaseHook(BaseHook):
                 "after": tenacity.after_log(log, logging.DEBUG),
             }
             default_kwargs.update(**kwargs)
-            return tenacity.retry(*args, **default_kwargs)(fun)
+            return tenacity.retry(*args, **default_kwargs)(func)
 
         return decorator
 
@@ -442,7 +462,7 @@ class GoogleBaseHook(BaseHook):
     def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]:
         """Provide a mechanism to repeat requests in response to operation in progress (HTTP 409) limit."""
 
-        def decorator(fun: T):
+        def decorator(func: T):
             default_kwargs = {
                 "wait": tenacity.wait_exponential(multiplier=1, max=300),
                 "retry": retry_if_operation_in_progress(),
@@ -450,7 +470,25 @@ class GoogleBaseHook(BaseHook):
                 "after": tenacity.after_log(log, logging.DEBUG),
             }
             default_kwargs.update(**kwargs)
-            return cast(T, tenacity.retry(*args, **default_kwargs)(fun))
+            return cast(T, tenacity.retry(*args, **default_kwargs)(func))
+
+        return decorator
+
+    @staticmethod
+    def refresh_credentials_retry(*args, **kwargs) -> Callable[[T], T]:
+        """Provide a mechanism to repeat requests in response to a temporary refresh credential issue."""
+
+        def decorator(func: T):
+            default_kwargs = {
+                "wait": tenacity.wait_exponential(multiplier=1, max=5),
+                "stop": tenacity.stop_after_attempt(3),
+                "retry": retry_if_temporary_refresh_credentials(),
+                "reraise": True,
+                "before": tenacity.before_log(log, logging.DEBUG),
+                "after": tenacity.after_log(log, logging.DEBUG),
+            }
+            default_kwargs.update(**kwargs)
+            return cast(T, tenacity.retry(*args, **default_kwargs)(func))
 
         return decorator
 
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 37096b0ff3..c63dc581a9 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -26,6 +26,7 @@ import google.auth
 import pytest
 from gcloud.aio.bigquery import Job, Table as Table_async
 from google.api_core import page_iterator
+from google.auth.exceptions import RefreshError
 from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference
 from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
 from google.cloud.bigquery.table import _EmptyRowIterator
@@ -598,6 +599,37 @@ class TestBigQueryHookMethods(_BigQueryBaseTestClass):
         mock_client.return_value.get_job.assert_called_once_with(job_id=JOB_ID)
         mock_client.return_value.get_job.return_value.done.assert_called_once_with(retry=DEFAULT_RETRY)
 
+    @mock.patch("tenacity.nap.time.sleep", mock.MagicMock())
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
+    def test_get_job_credentials_refresh_error(self, mock_client):
+        error = "Unable to acquire impersonated credentials"
+        response_body = "<!DOCTYPE html>\n<html lang=en>\n  <meta charset=utf-8>\n"
+        mock_job = mock.MagicMock(
+            job_id="123456_hash",
+            error_result=False,
+            state="PENDING",
+            done=lambda: False,
+        )
+        mock_client.return_value.get_job.side_effect = [RefreshError(error, response_body), mock_job]
+
+        job = self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID)
+        mock_client.assert_any_call(location=LOCATION, project_id=PROJECT_ID)
+        assert mock_client.call_count == 2
+        assert job == mock_job
+
+    @pytest.mark.parametrize(
+        "error",
+        [
+            RefreshError("Other error", "test body"),
+            ValueError(),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
+    def test_get_job_credentials_error(self, mock_client, error):
+        mock_client.return_value.get_job.side_effect = error
+        with pytest.raises(type(error)):
+            self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID)
+
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete")
     @mock.patch("logging.Logger.info")
     def test_cancel_query_jobs_to_cancel(
diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py
index ab2e26f59b..1a91f0742d 100644
--- a/tests/providers/google/common/hooks/test_base_google.py
+++ b/tests/providers/google/common/hooks/test_base_google.py
@@ -30,14 +30,14 @@ import google.auth.compute_engine
 import pytest
 import tenacity
 from google.auth.environment_vars import CREDENTIALS
-from google.auth.exceptions import GoogleAuthError
+from google.auth.exceptions import GoogleAuthError, RefreshError
 from google.cloud.exceptions import Forbidden
 
 from airflow import version
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.utils.credentials_provider import _DEFAULT_SCOPES
 from airflow.providers.google.common.hooks import base_google as hook
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, is_refresh_credentials_exception
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
 
 default_creds_available = True
@@ -98,6 +98,46 @@ class TestQuotaRetry:
             )
 
 
+class TestRefreshCredentialsRetry:
+    @pytest.mark.parametrize(
+        "exc, retryable",
+        [
+            (RefreshError("Other error", "test body"), False),
+            (RefreshError("Unable to acquire impersonated credentials", "test body"), True),
+            (ValueError(), False),
+        ],
+    )
+    def test_is_refresh_credentials_exception(self, exc, retryable):
+        assert is_refresh_credentials_exception(exc) is retryable
+
+    def test_do_nothing_on_non_error(self):
+        @hook.GoogleBaseHook.refresh_credentials_retry()
+        def func():
+            return 42
+
+        assert func() == 42
+
+    def test_raise_non_refresh_error(self):
+        @hook.GoogleBaseHook.refresh_credentials_retry()
+        def func():
+            raise ValueError()
+
+        with pytest.raises(ValueError):
+            func()
+
+    @mock.patch("tenacity.nap.time.sleep", mock.MagicMock())
+    def test_retry_on_refresh_error(self):
+        func_return = mock.Mock(
+            side_effect=[RefreshError("Unable to acquire impersonated credentials", "test body"), 42]
+        )
+
+        @hook.GoogleBaseHook.refresh_credentials_retry()
+        def func():
+            return func_return()
+
+        assert func() == 42
+
+
 class FallbackToDefaultProjectIdFixtureClass:
     def __init__(self, project_id):
         self.mock = mock.Mock()

Reply via email to