jason810496 commented on code in PR #55462:
URL: https://github.com/apache/airflow/pull/55462#discussion_r2337256584
##########
providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py:
##########
@@ -768,3 +768,151 @@ def
test_get_error_code_with_http_error_and_valid_error_code(self):
exception.response = mock_response
hook = BaseDatabricksHook()
assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_normal(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ mock_get.return_value.json.return_value = mock_response
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ hook._check_azure_metadata_service()
+
+ assert hook._metadata_cache == mock_response
+ assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_cached(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ hook._metadata_cache = mock_response
+ hook._metadata_expiry = travel_time + 1000
+
+ hook._check_azure_metadata_service()
+ mock_get.assert_not_called()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_http_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+ mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+ with pytest.raises(AirflowException, match="Can't reach Azure Metadata
Service"):
+ hook._check_azure_metadata_service()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_retry_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ resp_429 = mock.Mock()
+ resp_429.status_code = 429
+ resp_429.content = b"Too many requests"
+ http_error = requests_exceptions.HTTPError(response=resp_429)
+ mock_get.side_effect = http_error
+
+ with pytest.raises(AirflowException, match="Failed to reach Azure
Metadata Service after"):
+ hook._check_azure_metadata_service()
Review Comment:
Would it be better to validate how many times `mock_get` is called after
calling `_check_azure_metadata_service`?
##########
providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py:
##########
@@ -768,3 +768,151 @@ def
test_get_error_code_with_http_error_and_valid_error_code(self):
exception.response = mock_response
hook = BaseDatabricksHook()
assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_normal(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ mock_get.return_value.json.return_value = mock_response
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ hook._check_azure_metadata_service()
+
+ assert hook._metadata_cache == mock_response
+ assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_cached(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ hook._metadata_cache = mock_response
+ hook._metadata_expiry = travel_time + 1000
+
+ hook._check_azure_metadata_service()
+ mock_get.assert_not_called()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_http_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+ mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+ with pytest.raises(AirflowException, match="Can't reach Azure Metadata
Service"):
+ hook._check_azure_metadata_service()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_retry_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
Review Comment:
It seems we could add a factory method for constructing the
`BaseDatabricksHook` with the `_metadata_cache` and `_metadata_expiry` properties set.
##########
providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py:
##########
@@ -768,3 +768,151 @@ def
test_get_error_code_with_http_error_and_valid_error_code(self):
exception.response = mock_response
hook = BaseDatabricksHook()
assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_normal(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ mock_get.return_value.json.return_value = mock_response
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ hook._check_azure_metadata_service()
+
+ assert hook._metadata_cache == mock_response
+ assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+ @mock.patch("requests.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ def test_check_azure_metadata_service_cached(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+ hook._metadata_cache = mock_response
+ hook._metadata_expiry = travel_time + 1000
+
+ hook._check_azure_metadata_service()
+ mock_get.assert_not_called()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_http_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+ mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+ with pytest.raises(AirflowException, match="Can't reach Azure Metadata
Service"):
+ hook._check_azure_metadata_service()
+
+ @mock.patch("requests.get")
+ def test_check_azure_metadata_service_retry_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ resp_429 = mock.Mock()
+ resp_429.status_code = 429
+ resp_429.content = b"Too many requests"
+ http_error = requests_exceptions.HTTPError(response=resp_429)
+ mock_get.side_effect = http_error
+
+ with pytest.raises(AirflowException, match="Failed to reach Azure
Metadata Service after"):
+ hook._check_azure_metadata_service()
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ async def test_a_check_azure_metadata_service_normal(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ async_mock = mock.AsyncMock()
+ async_mock.__aenter__.return_value.json = mock.AsyncMock(
+ return_value={"compute": {"azEnvironment": "AzurePublicCloud"}}
+ )
+ mock_get.return_value = async_mock
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+ mock_attempt = mock.Mock()
+ mock_attempt.__enter__ = mock.Mock(return_value=None)
+ mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+ async def mock_retry_generator():
+ yield mock_attempt
+
+ hook._a_get_retry_object =
mock.Mock(return_value=mock_retry_generator())
+ await hook._a_check_azure_metadata_service()
+
+ assert hook._metadata_cache["compute"]["azEnvironment"] ==
"AzurePublicCloud"
+ assert hook._metadata_expiry > 0
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ @time_machine.travel("2025-07-12 12:00:00")
+ async def test_a_check_azure_metadata_service_cached(self, mock_get):
+ travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = {"compute": {"azEnvironment":
"AzurePublicCloud"}}
+ hook._metadata_expiry = travel_time + 1000
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+ await hook._a_check_azure_metadata_service()
+ mock_get.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ async def test_a_check_azure_metadata_service_http_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ async_mock = mock.AsyncMock()
+ async_mock.__aenter__.side_effect = aiohttp.ClientError("Fail")
+ async_mock.__aexit__.return_value = None
+ mock_get.return_value = async_mock
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+ mock_attempt = mock.Mock()
+ mock_attempt.__enter__ = mock.Mock(return_value=None)
+ mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+ async def mock_retry_generator():
+ yield mock_attempt
+
+ hook._a_get_retry_object =
mock.Mock(return_value=mock_retry_generator())
+
+ with pytest.raises(AirflowException, match="Can't reach Azure
Metadata Service"):
+ await hook._a_check_azure_metadata_service()
+ assert hook._metadata_cache is None
+ assert hook._metadata_expiry == 0
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.ClientSession.get")
+ async def test_a_check_azure_metadata_service_retry_error(self, mock_get):
+ hook = BaseDatabricksHook()
+ hook._metadata_cache = None
+ hook._metadata_expiry = 0
+
+ mock_get.side_effect = aiohttp.ClientResponseError(
+ request_info=mock.Mock(), history=(), status=429, message="429 Too
Many Requests"
+ )
+
+ async with aiohttp.ClientSession() as session:
+ hook._session = session
+
+ hook._a_get_retry_object = lambda: AsyncRetrying(
+ stop=stop_after_attempt(hook.retry_limit),
+ wait=wait_fixed(0),
+ retry=retry_if_exception(hook._retryable_error),
+ )
+
+ hook._validate_azure_metadata_service = mock.Mock()
+
+ with pytest.raises(AirflowException, match="Failed to reach Azure
Metadata Service after"):
Review Comment:
It seems we could match against the full exception message, including the default retry count.
##########
providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -515,44 +518,69 @@ def _is_oauth_token_valid(token: dict,
time_key="expires_on") -> bool:
return int(token[time_key]) > (int(time.time()) +
TOKEN_REFRESH_LEAD_TIME)
- @staticmethod
- def _check_azure_metadata_service() -> None:
+ def _check_azure_metadata_service(self) -> None:
"""
- Check for Azure Metadata Service.
+ Check for Azure Metadata Service (with caching).
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
"""
+ if self._metadata_cache and time.time() < self._metadata_expiry:
+ return
try:
- jsn = requests.get(
- AZURE_METADATA_SERVICE_INSTANCE_URL,
- params={"api-version": "2021-02-01"},
- headers={"Metadata": "true"},
- timeout=2,
- ).json()
- if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
- raise AirflowException(
- f"Was able to fetch some metadata, but it doesn't look
like Azure Metadata: {jsn}"
- )
+ for attempt in self._get_retry_object():
+ with attempt:
+ response = requests.get(
+ AZURE_METADATA_SERVICE_INSTANCE_URL,
+ params={"api-version": "2021-02-01"},
+ headers={"Metadata": "true"},
+ timeout=2,
+ )
+ response.raise_for_status()
+ jsn = response.json()
+
+ self._metadata_cache = jsn
+ self._metadata_expiry = time.time() + self._metadata_ttl
+ self._validate_azure_metadata_service(jsn)
+ break
+ except RetryError:
+ raise AirflowException(
+ f"Failed to reach Azure Metadata Service after
{self.retry_limit} retries."
+ )
Review Comment:
It seems the [[LAZY CONSENSUS] Avoid adding new direct `AirflowException`
usage](https://lists.apache.org/thread/t8bnhyqy77kq4fk7fj3fmjd5wo9kv6w0) _will_
be reached soon.
Perhaps we could add a new exception class for this case in advance.
##########
providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -515,44 +518,69 @@ def _is_oauth_token_valid(token: dict,
time_key="expires_on") -> bool:
return int(token[time_key]) > (int(time.time()) +
TOKEN_REFRESH_LEAD_TIME)
- @staticmethod
- def _check_azure_metadata_service() -> None:
+ def _check_azure_metadata_service(self) -> None:
"""
- Check for Azure Metadata Service.
+ Check for Azure Metadata Service (with caching).
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
"""
+ if self._metadata_cache and time.time() < self._metadata_expiry:
+ return
try:
- jsn = requests.get(
- AZURE_METADATA_SERVICE_INSTANCE_URL,
- params={"api-version": "2021-02-01"},
- headers={"Metadata": "true"},
- timeout=2,
- ).json()
- if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
- raise AirflowException(
- f"Was able to fetch some metadata, but it doesn't look
like Azure Metadata: {jsn}"
- )
+ for attempt in self._get_retry_object():
+ with attempt:
+ response = requests.get(
+ AZURE_METADATA_SERVICE_INSTANCE_URL,
+ params={"api-version": "2021-02-01"},
+ headers={"Metadata": "true"},
+ timeout=2,
+ )
+ response.raise_for_status()
+ jsn = response.json()
+
+ self._metadata_cache = jsn
+ self._metadata_expiry = time.time() + self._metadata_ttl
+ self._validate_azure_metadata_service(jsn)
Review Comment:
Would it be better to validate the response before storing it?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]