This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a7f450901dd Metadata service check handle 429 (#55462)
a7f450901dd is described below

commit a7f450901dd2fbde508d74f79fe5b2e7f3d6543d
Author: Xch1 <[email protected]>
AuthorDate: Wed Sep 17 06:58:21 2025 +0800

    Metadata service check handle 429 (#55462)
---
 .../providers/databricks/hooks/databricks_base.py  |  78 +++++++----
 .../unit/databricks/hooks/test_databricks_base.py  | 144 ++++++++++++++++++++-
 2 files changed, 194 insertions(+), 28 deletions(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index 7ece90122e7..919e21c3287 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -121,6 +121,9 @@ class BaseDatabricksHook(BaseHook):
         self.oauth_tokens: dict[str, dict] = {}
         self.token_timeout_seconds = 10
         self.caller = caller
+        self._metadata_cache: dict[str, Any] = {}
+        self._metadata_expiry: float = 0
+        self._metadata_ttl: int = 300
 
         def my_after_func(retry_state):
             self._log_request_error(retry_state.attempt_number, retry_state.outcome)
@@ -515,43 +518,64 @@ class BaseDatabricksHook(BaseHook):
 
         return int(token[time_key]) > (int(time.time()) + TOKEN_REFRESH_LEAD_TIME)
 
-    @staticmethod
-    def _check_azure_metadata_service() -> None:
+    def _check_azure_metadata_service(self) -> None:
         """
-        Check for Azure Metadata Service.
+        Check for Azure Metadata Service (with caching).
 
         https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
         """
+        if self._metadata_cache and time.time() < self._metadata_expiry:
+            return
         try:
-            jsn = requests.get(
-                AZURE_METADATA_SERVICE_INSTANCE_URL,
-                params={"api-version": "2021-02-01"},
-                headers={"Metadata": "true"},
-                timeout=2,
-            ).json()
-            if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
-                raise AirflowException(
-                    f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
-                )
+            for attempt in self._get_retry_object():
+                with attempt:
+                    response = requests.get(
+                        AZURE_METADATA_SERVICE_INSTANCE_URL,
+                        params={"api-version": "2021-02-01"},
+                        headers={"Metadata": "true"},
+                        timeout=2,
+                    )
+                    response.raise_for_status()
+                    response_json = response.json()
+
+                    self._validate_azure_metadata_service(response_json)
+                    self._metadata_cache = response_json
+                    self._metadata_expiry = time.time() + self._metadata_ttl
+                    break
+        except RetryError:
+            raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.")
         except (requests_exceptions.RequestException, ValueError) as e:
-            raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+            raise ConnectionError(f"Can't reach Azure Metadata Service: {e}")
 
     async def _a_check_azure_metadata_service(self):
         """Async version of `_check_azure_metadata_service()`."""
+        if self._metadata_cache and time.time() < self._metadata_expiry:
+            return
         try:
-            async with self._session.get(
-                url=AZURE_METADATA_SERVICE_INSTANCE_URL,
-                params={"api-version": "2021-02-01"},
-                headers={"Metadata": "true"},
-                timeout=2,
-            ) as resp:
-                jsn = await resp.json()
-            if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
-                raise AirflowException(
-                    f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
-                )
-        except (requests_exceptions.RequestException, ValueError) as e:
-            raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.get(
+                        url=AZURE_METADATA_SERVICE_INSTANCE_URL,
+                        params={"api-version": "2021-02-01"},
+                        headers={"Metadata": "true"},
+                        timeout=2,
+                    ) as resp:
+                        resp.raise_for_status()
+                        response_json = await resp.json()
+                    self._validate_azure_metadata_service(response_json)
+                    self._metadata_cache = response_json
+                    self._metadata_expiry = time.time() + self._metadata_ttl
+                    break
+        except RetryError:
+            raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.")
+        except (aiohttp.ClientError, ValueError) as e:
+            raise ConnectionError(f"Can't reach Azure Metadata Service: {e}")
+
+    def _validate_azure_metadata_service(self, response_json: dict) -> None:
+        if "compute" not in response_json or "azEnvironment" not in response_json["compute"]:
+            raise ValueError(
+                f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {response_json}"
+            )
 
     def _get_token(self, raise_error: bool = False) -> str | None:
         if "token" in self.databricks_conn.extra_dejson:
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
index 2d6df615dca..34b952f7c84 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
@@ -26,7 +26,7 @@ import time_machine
 from aiohttp.client_exceptions import ClientConnectorError
 from requests import exceptions as requests_exceptions
 from requests.auth import HTTPBasicAuth
-from tenacity import Future, RetryError
+from tenacity import AsyncRetrying, Future, RetryError, retry_if_exception, stop_after_attempt, wait_fixed
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
@@ -768,3 +768,145 @@ class TestBaseDatabricksHook:
         exception.response = mock_response
         hook = BaseDatabricksHook()
         assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_normal(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        mock_get.return_value.json.return_value = mock_response
+
+        hook._check_azure_metadata_service()
+
+        assert hook._metadata_cache == mock_response
+        assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_cached(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        hook._metadata_cache = mock_response
+        hook._metadata_expiry = travel_time + 1000
+
+        hook._check_azure_metadata_service()
+        mock_get.assert_not_called()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_http_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+        with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"):
+            hook._check_azure_metadata_service()
+        assert hook._metadata_cache == {}
+        assert hook._metadata_expiry == 0
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_retry_error(self, mock_get):
+        hook = BaseDatabricksHook()
+
+        resp_429 = mock.Mock()
+        resp_429.status_code = 429
+        resp_429.content = b"Too many requests"
+        http_error = requests_exceptions.HTTPError(response=resp_429)
+        mock_get.side_effect = http_error
+
+        with pytest.raises(ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."):
+            hook._check_azure_metadata_service()
+        assert mock_get.call_count == 3
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    async def test_a_check_azure_metadata_service_normal(self, mock_get):
+        hook = BaseDatabricksHook()
+
+        async_mock = mock.AsyncMock()
+        async_mock.__aenter__.return_value = async_mock
+        async_mock.__aexit__.return_value = None
+        async_mock.json.return_value = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+
+        mock_get.return_value = async_mock
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+            mock_attempt = mock.Mock()
+            mock_attempt.__enter__ = mock.Mock(return_value=None)
+            mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+            async def mock_retry_generator():
+                yield mock_attempt
+
+            hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator())
+            await hook._a_check_azure_metadata_service()
+
+            assert hook._metadata_cache["compute"]["azEnvironment"] == "AzurePublicCloud"
+            assert hook._metadata_expiry > 0
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    async def test_a_check_azure_metadata_service_cached(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        hook._metadata_expiry = travel_time + 1000
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+            await hook._a_check_azure_metadata_service()
+            mock_get.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    async def test_a_check_azure_metadata_service_http_error(self, mock_get):
+        hook = BaseDatabricksHook()
+
+        async_mock = mock.AsyncMock()
+        async_mock.__aenter__.side_effect = aiohttp.ClientError("Fail")
+        async_mock.__aexit__.return_value = None
+        mock_get.return_value = async_mock
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+            mock_attempt = mock.Mock()
+            mock_attempt.__enter__ = mock.Mock(return_value=None)
+            mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+            async def mock_retry_generator():
+                yield mock_attempt
+
+            hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator())
+
+            with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"):
+                await hook._a_check_azure_metadata_service()
+            assert hook._metadata_cache == {}
+            assert hook._metadata_expiry == 0
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    async def test_a_check_azure_metadata_service_retry_error(self, mock_get):
+        hook = BaseDatabricksHook()
+
+        mock_get.side_effect = aiohttp.ClientResponseError(
+            request_info=mock.Mock(), history=(), status=429, message="429 Too Many Requests"
+        )
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+
+            hook._a_get_retry_object = lambda: AsyncRetrying(
+                stop=stop_after_attempt(hook.retry_limit),
+                wait=wait_fixed(0),
+                retry=retry_if_exception(hook._retryable_error),
+            )
+
+            hook._validate_azure_metadata_service = mock.Mock()
+
+            with pytest.raises(
+                ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."
+            ):
+                await hook._a_check_azure_metadata_service()
+            assert mock_get.call_count == 3

Reply via email to