This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a7f450901dd Metadata service check handle 429 (#55462)
a7f450901dd is described below
commit a7f450901dd2fbde508d74f79fe5b2e7f3d6543d
Author: Xch1 <[email protected]>
AuthorDate: Wed Sep 17 06:58:21 2025 +0800
Metadata service check handle 429 (#55462)
---
.../providers/databricks/hooks/databricks_base.py | 78 +++++++----
.../unit/databricks/hooks/test_databricks_base.py | 144 ++++++++++++++++++++-
2 files changed, 194 insertions(+), 28 deletions(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index 7ece90122e7..919e21c3287 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -121,6 +121,9 @@ class BaseDatabricksHook(BaseHook):
self.oauth_tokens: dict[str, dict] = {}
self.token_timeout_seconds = 10
self.caller = caller
+ self._metadata_cache: dict[str, Any] = {}
+ self._metadata_expiry: float = 0
+ self._metadata_ttl: int = 300
def my_after_func(retry_state):
self._log_request_error(retry_state.attempt_number, retry_state.outcome)
@@ -515,43 +518,64 @@ class BaseDatabricksHook(BaseHook):
return int(token[time_key]) > (int(time.time()) + TOKEN_REFRESH_LEAD_TIME)
- @staticmethod
- def _check_azure_metadata_service() -> None:
+ def _check_azure_metadata_service(self) -> None:
"""
- Check for Azure Metadata Service.
+ Check for Azure Metadata Service (with caching).
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
"""
+ if self._metadata_cache and time.time() < self._metadata_expiry:
+ return
try:
- jsn = requests.get(
- AZURE_METADATA_SERVICE_INSTANCE_URL,
- params={"api-version": "2021-02-01"},
- headers={"Metadata": "true"},
- timeout=2,
- ).json()
- if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
- raise AirflowException(
- f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
- )
+ for attempt in self._get_retry_object():
+ with attempt:
+ response = requests.get(
+ AZURE_METADATA_SERVICE_INSTANCE_URL,
+ params={"api-version": "2021-02-01"},
+ headers={"Metadata": "true"},
+ timeout=2,
+ )
+ response.raise_for_status()
+ response_json = response.json()
+
+ self._validate_azure_metadata_service(response_json)
+ self._metadata_cache = response_json
+ self._metadata_expiry = time.time() + self._metadata_ttl
+ break
+ except RetryError:
+ raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.")
except (requests_exceptions.RequestException, ValueError) as e:
- raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+ raise ConnectionError(f"Can't reach Azure Metadata Service: {e}")
async def _a_check_azure_metadata_service(self):
"""Async version of `_check_azure_metadata_service()`."""
+ if self._metadata_cache and time.time() < self._metadata_expiry:
+ return
try:
- async with self._session.get(
- url=AZURE_METADATA_SERVICE_INSTANCE_URL,
- params={"api-version": "2021-02-01"},
- headers={"Metadata": "true"},
- timeout=2,
- ) as resp:
- jsn = await resp.json()
- if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
- raise AirflowException(
- f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
- )
- except (requests_exceptions.RequestException, ValueError) as e:
- raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
+ async for attempt in self._a_get_retry_object():
+ with attempt:
+ async with self._session.get(
+ url=AZURE_METADATA_SERVICE_INSTANCE_URL,
+ params={"api-version": "2021-02-01"},
+ headers={"Metadata": "true"},
+ timeout=2,
+ ) as resp:
+ resp.raise_for_status()
+ response_json = await resp.json()
+ self._validate_azure_metadata_service(response_json)
+ self._metadata_cache = response_json
+ self._metadata_expiry = time.time() + self._metadata_ttl
+ break
+ except RetryError:
+ raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.")
+ except (aiohttp.ClientError, ValueError) as e:
+ raise ConnectionError(f"Can't reach Azure Metadata Service: {e}")
+
+ def _validate_azure_metadata_service(self, response_json: dict) -> None:
+ if "compute" not in response_json or "azEnvironment" not in response_json["compute"]:
+ raise ValueError(
+ f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {response_json}"
+ )
def _get_token(self, raise_error: bool = False) -> str | None:
if "token" in self.databricks_conn.extra_dejson:
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
index 2d6df615dca..34b952f7c84 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
@@ -26,7 +26,7 @@ import time_machine
from aiohttp.client_exceptions import ClientConnectorError
from requests import exceptions as requests_exceptions
from requests.auth import HTTPBasicAuth
-from tenacity import Future, RetryError
+from tenacity import AsyncRetrying, Future, RetryError, retry_if_exception, stop_after_attempt, wait_fixed
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -768,3 +768,145 @@ class TestBaseDatabricksHook:
exception.response = mock_response
hook = BaseDatabricksHook()
assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_normal(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ mock_get.return_value.json.return_value = mock_response
+
+ hook._check_azure_metadata_service()
+
+ assert hook._metadata_cache == mock_response
+ assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_cached(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ hook._metadata_cache = mock_response
+ hook._metadata_expiry = travel_time + 1000
+
+ hook._check_azure_metadata_service()
+ mock_get.assert_not_called()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_http_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+ with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"):
+ hook._check_azure_metadata_service()
+ assert hook._metadata_cache == {}
+ assert hook._metadata_expiry == 0
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_retry_error(self, mock_get):
+ hook = BaseDatabricksHook()
+
+ resp_429 = mock.Mock()
+ resp_429.status_code = 429
+ resp_429.content = b"Too many requests"
+ http_error = requests_exceptions.HTTPError(response=resp_429)
+ mock_get.side_effect = http_error
+
+ with pytest.raises(ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."):
+ hook._check_azure_metadata_service()
+ assert mock_get.call_count == 3
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ async def test_a_check_azure_metadata_service_normal(self, mock_get):
+ hook = BaseDatabricksHook()
+
+ async_mock = mock.AsyncMock()
+ async_mock.__aenter__.return_value = async_mock
+ async_mock.__aexit__.return_value = None
+ async_mock.json.return_value = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+
+ mock_get.return_value = async_mock
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+ mock_attempt = mock.Mock()
+ mock_attempt.__enter__ = mock.Mock(return_value=None)
+ mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+ async def mock_retry_generator():
+ yield mock_attempt
+
+ hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator())
+ await hook._a_check_azure_metadata_service()
+
+ assert hook._metadata_cache["compute"]["azEnvironment"] == "AzurePublicCloud"
+ assert hook._metadata_expiry > 0
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ async def test_a_check_azure_metadata_service_cached(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ hook._metadata_expiry = travel_time + 1000
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+ await hook._a_check_azure_metadata_service()
+ mock_get.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ async def test_a_check_azure_metadata_service_http_error(self, mock_get):
+ hook = BaseDatabricksHook()
+
+ async_mock = mock.AsyncMock()
+ async_mock.__aenter__.side_effect = aiohttp.ClientError("Fail")
+ async_mock.__aexit__.return_value = None
+ mock_get.return_value = async_mock
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+ mock_attempt = mock.Mock()
+ mock_attempt.__enter__ = mock.Mock(return_value=None)
+ mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+ async def mock_retry_generator():
+ yield mock_attempt
+
+ hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator())
+
+ with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"):
+ await hook._a_check_azure_metadata_service()
+ assert hook._metadata_cache == {}
+ assert hook._metadata_expiry == 0
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ async def test_a_check_azure_metadata_service_retry_error(self, mock_get):
+ hook = BaseDatabricksHook()
+
+ mock_get.side_effect = aiohttp.ClientResponseError(
+ request_info=mock.Mock(), history=(), status=429, message="429 Too Many Requests"
+ )
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+
+ hook._a_get_retry_object = lambda: AsyncRetrying(
+ stop=stop_after_attempt(hook.retry_limit),
+ wait=wait_fixed(0),
+ retry=retry_if_exception(hook._retryable_error),
+ )
+
+ hook._validate_azure_metadata_service = mock.Mock()
+
+ with pytest.raises(
+ ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."
+ ):
+ await hook._a_check_azure_metadata_service()
+ assert mock_get.call_count == 3