jason810496 commented on code in PR #55462:
URL: https://github.com/apache/airflow/pull/55462#discussion_r2337256584


##########
providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py:
##########
@@ -768,3 +768,151 @@ def 
test_get_error_code_with_http_error_and_valid_error_code(self):
         exception.response = mock_response
         hook = BaseDatabricksHook()
         assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_normal(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        mock_get.return_value.json.return_value = mock_response
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        hook._check_azure_metadata_service()
+
+        assert hook._metadata_cache == mock_response
+        assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_cached(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        hook._metadata_cache = mock_response
+        hook._metadata_expiry = travel_time + 1000
+
+        hook._check_azure_metadata_service()
+        mock_get.assert_not_called()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_http_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+        mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+        with pytest.raises(AirflowException, match="Can't reach Azure Metadata 
Service"):
+            hook._check_azure_metadata_service()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_retry_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        resp_429 = mock.Mock()
+        resp_429.status_code = 429
+        resp_429.content = b"Too many requests"
+        http_error = requests_exceptions.HTTPError(response=resp_429)
+        mock_get.side_effect = http_error
+
+        with pytest.raises(AirflowException, match="Failed to reach Azure 
Metadata Service after"):
+            hook._check_azure_metadata_service()

Review Comment:
   Would it be better to validate how many times `mock_get` is called after 
calling `_check_azure_metadata_service`?



##########
providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py:
##########
@@ -768,3 +768,151 @@ def 
test_get_error_code_with_http_error_and_valid_error_code(self):
         exception.response = mock_response
         hook = BaseDatabricksHook()
         assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_normal(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        mock_get.return_value.json.return_value = mock_response
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        hook._check_azure_metadata_service()
+
+        assert hook._metadata_cache == mock_response
+        assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_cached(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        hook._metadata_cache = mock_response
+        hook._metadata_expiry = travel_time + 1000
+
+        hook._check_azure_metadata_service()
+        mock_get.assert_not_called()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_http_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+        mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+        with pytest.raises(AirflowException, match="Can't reach Azure Metadata 
Service"):
+            hook._check_azure_metadata_service()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_retry_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0

Review Comment:
   It seems we could add a factory method for constructing a 
`BaseDatabricksHook` with the `_metadata_cache` and `_metadata_expiry` properties pre-set.



##########
providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py:
##########
@@ -768,3 +768,151 @@ def 
test_get_error_code_with_http_error_and_valid_error_code(self):
         exception.response = mock_response
         hook = BaseDatabricksHook()
         assert hook._get_error_code(exception) == "INVALID_REQUEST"
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_normal(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        mock_get.return_value.json.return_value = mock_response
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        hook._check_azure_metadata_service()
+
+        assert hook._metadata_cache == mock_response
+        assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl
+
+    @mock.patch("requests.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    def test_check_azure_metadata_service_cached(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
+        hook._metadata_cache = mock_response
+        hook._metadata_expiry = travel_time + 1000
+
+        hook._check_azure_metadata_service()
+        mock_get.assert_not_called()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_http_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+        mock_get.side_effect = requests_exceptions.RequestException("Fail")
+
+        with pytest.raises(AirflowException, match="Can't reach Azure Metadata 
Service"):
+            hook._check_azure_metadata_service()
+
+    @mock.patch("requests.get")
+    def test_check_azure_metadata_service_retry_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        resp_429 = mock.Mock()
+        resp_429.status_code = 429
+        resp_429.content = b"Too many requests"
+        http_error = requests_exceptions.HTTPError(response=resp_429)
+        mock_get.side_effect = http_error
+
+        with pytest.raises(AirflowException, match="Failed to reach Azure 
Metadata Service after"):
+            hook._check_azure_metadata_service()
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    async def test_a_check_azure_metadata_service_normal(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        async_mock = mock.AsyncMock()
+        async_mock.__aenter__.return_value.json = mock.AsyncMock(
+            return_value={"compute": {"azEnvironment": "AzurePublicCloud"}}
+        )
+        mock_get.return_value = async_mock
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+            mock_attempt = mock.Mock()
+            mock_attempt.__enter__ = mock.Mock(return_value=None)
+            mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+            async def mock_retry_generator():
+                yield mock_attempt
+
+            hook._a_get_retry_object = 
mock.Mock(return_value=mock_retry_generator())
+            await hook._a_check_azure_metadata_service()
+
+            assert hook._metadata_cache["compute"]["azEnvironment"] == 
"AzurePublicCloud"
+            assert hook._metadata_expiry > 0
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    @time_machine.travel("2025-07-12 12:00:00")
+    async def test_a_check_azure_metadata_service_cached(self, mock_get):
+        travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = {"compute": {"azEnvironment": 
"AzurePublicCloud"}}
+        hook._metadata_expiry = travel_time + 1000
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+            await hook._a_check_azure_metadata_service()
+            mock_get.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    async def test_a_check_azure_metadata_service_http_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        async_mock = mock.AsyncMock()
+        async_mock.__aenter__.side_effect = aiohttp.ClientError("Fail")
+        async_mock.__aexit__.return_value = None
+        mock_get.return_value = async_mock
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+            mock_attempt = mock.Mock()
+            mock_attempt.__enter__ = mock.Mock(return_value=None)
+            mock_attempt.__exit__ = mock.Mock(return_value=None)
+
+            async def mock_retry_generator():
+                yield mock_attempt
+
+            hook._a_get_retry_object = 
mock.Mock(return_value=mock_retry_generator())
+
+            with pytest.raises(AirflowException, match="Can't reach Azure 
Metadata Service"):
+                await hook._a_check_azure_metadata_service()
+            assert hook._metadata_cache is None
+            assert hook._metadata_expiry == 0
+
+    @pytest.mark.asyncio
+    @mock.patch("aiohttp.ClientSession.get")
+    async def test_a_check_azure_metadata_service_retry_error(self, mock_get):
+        hook = BaseDatabricksHook()
+        hook._metadata_cache = None
+        hook._metadata_expiry = 0
+
+        mock_get.side_effect = aiohttp.ClientResponseError(
+            request_info=mock.Mock(), history=(), status=429, message="429 Too 
Many Requests"
+        )
+
+        async with aiohttp.ClientSession() as session:
+            hook._session = session
+
+            hook._a_get_retry_object = lambda: AsyncRetrying(
+                stop=stop_after_attempt(hook.retry_limit),
+                wait=wait_fixed(0),
+                retry=retry_if_exception(hook._retryable_error),
+            )
+
+            hook._validate_azure_metadata_service = mock.Mock()
+
+            with pytest.raises(AirflowException, match="Failed to reach Azure 
Metadata Service after"):

Review Comment:
   It seems we could match against the full exception message, including the default retry count.



##########
providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -515,44 +518,69 @@ def _is_oauth_token_valid(token: dict, 
time_key="expires_on") -> bool:
 
         return int(token[time_key]) > (int(time.time()) + 
TOKEN_REFRESH_LEAD_TIME)
 
-    @staticmethod
-    def _check_azure_metadata_service() -> None:
+    def _check_azure_metadata_service(self) -> None:
         """
-        Check for Azure Metadata Service.
+        Check for Azure Metadata Service (with caching).
 
         
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
         """
+        if self._metadata_cache and time.time() < self._metadata_expiry:
+            return
         try:
-            jsn = requests.get(
-                AZURE_METADATA_SERVICE_INSTANCE_URL,
-                params={"api-version": "2021-02-01"},
-                headers={"Metadata": "true"},
-                timeout=2,
-            ).json()
-            if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
-                raise AirflowException(
-                    f"Was able to fetch some metadata, but it doesn't look 
like Azure Metadata: {jsn}"
-                )
+            for attempt in self._get_retry_object():
+                with attempt:
+                    response = requests.get(
+                        AZURE_METADATA_SERVICE_INSTANCE_URL,
+                        params={"api-version": "2021-02-01"},
+                        headers={"Metadata": "true"},
+                        timeout=2,
+                    )
+                    response.raise_for_status()
+                    jsn = response.json()
+
+                    self._metadata_cache = jsn
+                    self._metadata_expiry = time.time() + self._metadata_ttl
+                    self._validate_azure_metadata_service(jsn)
+                    break
+        except RetryError:
+            raise AirflowException(
+                f"Failed to reach Azure Metadata Service after 
{self.retry_limit} retries."
+            )

Review Comment:
   It seems the [[LAZY CONSENSUS] Avoid adding new direct `AirflowException` 
usage](https://lists.apache.org/thread/t8bnhyqy77kq4fk7fj3fmjd5wo9kv6w0) _will_ 
be reached soon.
   
   Perhaps we could add a new exception for this case in advance.



##########
providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -515,44 +518,69 @@ def _is_oauth_token_valid(token: dict, 
time_key="expires_on") -> bool:
 
         return int(token[time_key]) > (int(time.time()) + 
TOKEN_REFRESH_LEAD_TIME)
 
-    @staticmethod
-    def _check_azure_metadata_service() -> None:
+    def _check_azure_metadata_service(self) -> None:
         """
-        Check for Azure Metadata Service.
+        Check for Azure Metadata Service (with caching).
 
         
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
         """
+        if self._metadata_cache and time.time() < self._metadata_expiry:
+            return
         try:
-            jsn = requests.get(
-                AZURE_METADATA_SERVICE_INSTANCE_URL,
-                params={"api-version": "2021-02-01"},
-                headers={"Metadata": "true"},
-                timeout=2,
-            ).json()
-            if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
-                raise AirflowException(
-                    f"Was able to fetch some metadata, but it doesn't look 
like Azure Metadata: {jsn}"
-                )
+            for attempt in self._get_retry_object():
+                with attempt:
+                    response = requests.get(
+                        AZURE_METADATA_SERVICE_INSTANCE_URL,
+                        params={"api-version": "2021-02-01"},
+                        headers={"Metadata": "true"},
+                        timeout=2,
+                    )
+                    response.raise_for_status()
+                    jsn = response.json()
+
+                    self._metadata_cache = jsn
+                    self._metadata_expiry = time.time() + self._metadata_ttl
+                    self._validate_azure_metadata_service(jsn)

Review Comment:
   Would it be better to validate the response before storing it?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to