This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 203c044a638 Added retry logic for Snowflake OAuth token requests (#61796)
203c044a638 is described below

commit 203c044a6380bf23b0fdca35f39b8f6b06e8652b
Author: SameerMesiah97 <[email protected]>
AuthorDate: Sun Feb 15 23:45:05 2026 +0000

    Added retry logic for Snowflake OAuth token requests (#61796)
    
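    Introduced retry handling for OAuth token acquisition in SnowflakeHook
    using tenacity. Extracted the HTTP call into _request_oauth_token and added
    retry classification via _is_retryable_oauth_error. Retries (up to 3
    attempts with exponential back-off) apply only to connection errors,
    timeouts and HTTP 5xx responses, while HTTP 4xx errors fail fast. HTTP
    error responses are no longer converted to AirflowException: the requests
    HTTPError (or, once retries are exhausted, the last connection or timeout
    error) now propagates to the caller.
    
    Updated unit tests to cover retry behavior, non-retryable errors, and retry
    exhaustion. Updated the get_oauth_token docstring to reflect retry semantics.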
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../airflow/providers/snowflake/hooks/snowflake.py |  74 +++++++++++---
 .../tests/unit/snowflake/hooks/test_snowflake.py   | 106 +++++++++++++++++++++
 2 files changed, 166 insertions(+), 14 deletions(-)

diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 96725472493..d077b072330 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -29,9 +29,11 @@ from typing import TYPE_CHECKING, Any, TypeVar, overload
 from urllib.parse import urlparse
 
 import requests
+import tenacity
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from requests.auth import HTTPBasicAuth
+from requests.exceptions import ConnectionError as RequestsConnectionError, HTTPError, Timeout
 from snowflake import connector
 from snowflake.connector import DictCursor, SnowflakeConnection, util_text
 from snowflake.sqlalchemy import URL
@@ -65,6 +67,22 @@ def _try_to_boolean(value: Any):
     return value
 
 
+def _is_retryable_oauth_error(exception: BaseException) -> bool:
+    """Return True if an OAuth token request failure is transient (connection, timeout or HTTP 5xx)."""
+    if isinstance(exception, (RequestsConnectionError, Timeout)):
+        return True
+
+    # Retry only on server-side HTTP errors (5xx).
+    # Client-side errors (4xx) indicate misconfiguration or invalid credentials
+    # and should fail fast without retrying.
+    if isinstance(exception, HTTPError):
+        response = exception.response
+        if response is not None and 500 <= response.status_code < 600:
+            return True
+
+    return False
+
+
 class SnowflakeHook(DbApiHook):
     """
     A client to interact with Snowflake.
@@ -239,7 +257,11 @@ class SnowflakeHook(DbApiHook):
         token_endpoint: str | None = None,
         grant_type: str = "refresh_token",
     ) -> str:
-        """Generate temporary OAuth access token using refresh token in connection details."""
+        """
+        Generate temporary OAuth access token using refresh token in connection details.
+
+        Transient network and server-side errors are retried automatically.
+        """
         if conn_config is None:
             conn_config = self._get_static_conn_params
 
@@ -503,22 +525,13 @@ class SnowflakeHook(DbApiHook):
         else:
             raise ValueError(f"Unknown grant_type: {grant_type}")
 
-        response = requests.post(
-            url,
+        response = self._request_oauth_token(
+            url=url,
             data=data,
-            headers={
-                "Content-Type": "application/x-www-form-urlencoded",
-            },
-            auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]),  # type: ignore[arg-type]
-            timeout=OAUTH_REQUEST_TIMEOUT,
+            client_id=conn_config["client_id"],
+            client_secret=conn_config["client_secret"],
         )
 
-        try:
-            response.raise_for_status()
-        except requests.exceptions.HTTPError as e:  # pragma: no cover
-            msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
-            raise AirflowException(msg)
-
         token = response.json()["access_token"]
         expires_in = int(response.json()["expires_in"])
 
@@ -531,6 +544,39 @@ class SnowflakeHook(DbApiHook):
 
         return token
 
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(3),
+        wait=tenacity.wait_exponential(multiplier=1, min=0, max=10),
+        retry=tenacity.retry_if_exception(_is_retryable_oauth_error),
+        reraise=True,
+    )
+    def _request_oauth_token(
+        self,
+        *,
+        url: str,
+        data: dict[str, Any],
+        client_id: str,
+        client_secret: str,
+    ) -> requests.Response:
+        """
+        Request an OAuth token from ``url``, retrying transient failures.
+
+        Each attempt performs one HTTP call and raises ``HTTPError`` for 4xx and 5xx responses.
+        Up to 3 attempts are made; only errors accepted by ``_is_retryable_oauth_error`` are retried.
+        """
+        response = requests.post(
+            url,
+            data=data,
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+            auth=HTTPBasicAuth(client_id, client_secret),
+            timeout=OAUTH_REQUEST_TIMEOUT,
+        )
+
+        # Raise HTTPError for non-success responses so retry logic can decide
+        # whether the failure is retryable.
+        response.raise_for_status()
+        return response
+
     def get_uri(self) -> str:
         """Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
         conn_params = self._get_conn_params()
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index c28fd895d12..da10a1efd72 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -30,6 +30,7 @@ import pytest
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.primitives.asymmetric import rsa
+from requests.exceptions import ConnectionError as RequestsConnectionError, HTTPError
 
 from airflow.models import Connection
 from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
@@ -1131,6 +1132,111 @@ class TestPytestSnowflakeHook:
             timeout=30,
         )
 
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.timezone.utcnow")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.requests.post")
+    def test_get_oauth_token_retries_and_succeeds(self, requests_post, mock_timezone_utcnow, mocker):
+        mocker.patch("tenacity.nap.time.sleep")  # Skip real back-off waits between retries.
+        # Freeze time to prevent access token expiration.
+        t0 = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
+        mock_timezone_utcnow.side_effect = [t0, t0]
+
+        requests_post.side_effect = [
+            RequestsConnectionError("temporary network error"),
+            Mock(
+                status_code=200,
+                json=lambda: {"access_token": "retry_token", "expires_in": 600},
+                raise_for_status=lambda: None,
+            ),
+        ]
+
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "client_id",
+            "password": "client_secret",
+            "extra": {
+                "account": "airflow",
+                "authenticator": "oauth",
+                "grant_type": "refresh_token",
+                "refresh_token": "secret_token",
+            },
+        }
+
+        with mock.patch.dict(
+            "os.environ",
+            {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            token = hook.get_oauth_token()
+
+        # Should retry.
+        assert token == "retry_token"
+        assert requests_post.call_count == 2
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.requests.post")
+    def test_get_oauth_token_does_not_retry_on_client_error(self, requests_post):
+        """A 4xx response is a client error and must fail on the first attempt."""
+        response = Mock(status_code=401)
+        http_error = HTTPError(response=response)
+
+        mock_response = Mock()
+        mock_response.raise_for_status.side_effect = http_error
+
+        requests_post.return_value = mock_response
+
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "client_id",
+            "password": "client_secret",
+            "extra": {
+                "account": "airflow",
+                "authenticator": "oauth",
+                "grant_type": "refresh_token",
+                "refresh_token": "secret_token",
+            },
+        }
+
+        with mock.patch.dict(
+            "os.environ",
+            {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+
+            with pytest.raises(HTTPError):
+                hook.get_oauth_token()
+
+        # Should not retry.
+        assert requests_post.call_count == 1
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.requests.post")
+    def test_get_oauth_token_fails_after_max_retries(self, requests_post, mocker):
+        mocker.patch("tenacity.nap.time.sleep")  # Skip real back-off waits between retries.
+        # Always fail with retryable error.
+        requests_post.side_effect = RequestsConnectionError("persistent network failure")
+
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "client_id",
+            "password": "client_secret",
+            "extra": {
+                "account": "airflow",
+                "authenticator": "oauth",
+                "grant_type": "refresh_token",
+                "refresh_token": "secret_token",
+            },
+        }
+
+        with mock.patch.dict(
+            "os.environ",
+            {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+
+            with pytest.raises(RequestsConnectionError):
+                hook.get_oauth_token()
+
+        # Stop after the third attempt.
+        assert requests_post.call_count == 3
+
     def test_get_azure_oauth_token(self, mocker):
         """Test get_azure_oauth_token method gets token from provided connection id"""
         azure_conn_id = "azure_test_conn"

Reply via email to