KoviAnusha commented on code in PR #57184:
URL: https://github.com/apache/airflow/pull/57184#discussion_r2459065850


##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -522,19 +522,17 @@ def get_azure_iam_token(self, conn: Connection) -> 
tuple[str, str, int]:
             azure_conn = Connection.get(azure_conn_id)
         except AttributeError:
             azure_conn = Connection.get_connection_from_secrets(azure_conn_id) 
 # type: ignore[attr-defined]
-        azure_base_hook: AzureBaseHook = azure_conn.get_hook()
-        scope = conf.get("postgres", "azure_oauth_scope", 
fallback=self.default_azure_oauth_scope)
         try:
-            token = azure_base_hook.get_token(scope).token
-        except AttributeError as e:
-            if e.name == "get_token" and e.obj == azure_base_hook:
-                raise AttributeError(
-                    "'AzureBaseHook' object has no attribute 'get_token'. "
-                    "Please upgrade 
apache-airflow-providers-microsoft-azure>=12.8.0",
-                    name=e.name,
-                    obj=e.obj,
+            azure_base_hook: AzureBaseHook = azure_conn.get_hook()
+        except TypeError as e:
+            if "required positional argument: 'sdk_client'" in str(e):
+                raise TypeError(
+                    "Getting azure token is not supported by current version 
of 'AzureBaseHook'. "
+                    "Please upgrade 
apache-airflow-providers-microsoft-azure>=12.8.0"
                 ) from e
             raise
+        scope = conf.get("postgres", "azure_oauth_scope", 
fallback=self.default_azure_oauth_scope)
+        token = azure_base_hook.get_token(scope).token

Review Comment:
   Small nit: the same config-key logic also exists in the Snowflake hook. It 
might be worth moving it into a small shared helper later to keep both hooks 
consistent.
   



##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -522,19 +522,17 @@ def get_azure_iam_token(self, conn: Connection) -> 
tuple[str, str, int]:
             azure_conn = Connection.get(azure_conn_id)
         except AttributeError:
             azure_conn = Connection.get_connection_from_secrets(azure_conn_id) 
 # type: ignore[attr-defined]
-        azure_base_hook: AzureBaseHook = azure_conn.get_hook()
-        scope = conf.get("postgres", "azure_oauth_scope", 
fallback=self.default_azure_oauth_scope)
         try:
-            token = azure_base_hook.get_token(scope).token
-        except AttributeError as e:
-            if e.name == "get_token" and e.obj == azure_base_hook:
-                raise AttributeError(
-                    "'AzureBaseHook' object has no attribute 'get_token'. "
-                    "Please upgrade 
apache-airflow-providers-microsoft-azure>=12.8.0",
-                    name=e.name,
-                    obj=e.obj,
+            azure_base_hook: AzureBaseHook = azure_conn.get_hook()
+        except TypeError as e:
+            if "required positional argument: 'sdk_client'" in str(e):
+                raise TypeError(
+                    "Getting azure token is not supported by current version 
of 'AzureBaseHook'. "
+                    "Please upgrade 
apache-airflow-providers-microsoft-azure>=12.8.0"
                 ) from e
             raise

Review Comment:
   Good that you are handling old AzureBaseHook versions. You could simplify 
this block by checking hasattr(azure_base_hook, "get_token") before calling it; 
that might read more cleanly than using try/except AttributeError. Also, maybe 
raise AirflowException instead of built-in errors, so users see a clearer 
message in task logs?
   



##########
providers/postgres/tests/unit/postgres/hooks/test_postgres.py:
##########
@@ -474,19 +474,19 @@ def test_get_conn_azure_iam(self, mocker, mock_connect):
     def test_get_azure_iam_token_expect_failure_on_get_token(self, mocker):
         """Test get_azure_iam_token method gets token from provided connection 
id"""
 
-        class MockAzureBaseHookWithoutGetToken:
-            def __init__(self):
+        class MockAzureBaseHookOldVersion:
+            def __init__(self, sdk_client, conn_id="azure_default"):
                 pass
 
         azure_conn_id = "azure_test_conn"
         mock_connection_class = 
mocker.patch("airflow.providers.postgres.hooks.postgres.Connection")
-        mock_connection_class.get.return_value.get_hook.return_value = 
MockAzureBaseHookWithoutGetToken()
+        mock_connection_class.get.return_value.get_hook = 
MockAzureBaseHookOldVersion

Review Comment:
   Nice clear naming for the mock classes. You could add a short docstring 
to each one explaining which legacy behavior it simulates (e.g., missing 
get_token vs. old SDK client); that would help future contributors understand 
the intent at a glance.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to