This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e22f961071 Reuse get_default_azure_credential method from Azure utils method (#35318)
e22f961071 is described below

commit e22f96107198f9afbbfce2b7e1913ae598f1813a
Author: Wei Lee <[email protected]>
AuthorDate: Wed Nov 1 16:33:42 2023 +0800

    Reuse get_default_azure_credential method from Azure utils method (#35318)
---
 airflow/providers/microsoft/azure/secrets/key_vault.py      | 13 +++++--------
 .../microsoft/azure/secrets/test_azure_key_vault.py         | 10 ++++------
 2 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py 
b/airflow/providers/microsoft/azure/secrets/key_vault.py
index 32ec56a837..f4c51f1f64 100644
--- a/airflow/providers/microsoft/azure/secrets/key_vault.py
+++ b/airflow/providers/microsoft/azure/secrets/key_vault.py
@@ -35,6 +35,7 @@ from azure.identity import ClientSecretCredential, 
DefaultAzureCredential
 from azure.keyvault.secrets import SecretClient
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.microsoft.azure.utils import 
get_default_azure_credential
 from airflow.secrets import BaseSecretsBackend
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.version import version as airflow_version
@@ -142,14 +143,10 @@ class AzureKeyVaultBackend(BaseSecretsBackend, 
LoggingMixin):
         if all([self.tenant_id, self.client_id, self.client_secret]):
             credential = ClientSecretCredential(self.tenant_id, 
self.client_id, self.client_secret)
         else:
-            if self.managed_identity_client_id and 
self.workload_identity_tenant_id:
-                credential = DefaultAzureCredential(
-                    managed_identity_client_id=self.managed_identity_client_id,
-                    
workload_identity_tenant_id=self.workload_identity_tenant_id,
-                    
additionally_allowed_tenants=[self.workload_identity_tenant_id],
-                )
-            else:
-                credential = DefaultAzureCredential()
+            credential = get_default_azure_credential(
+                managed_identity_client_id=self.managed_identity_client_id,
+                workload_identity_tenant_id=self.workload_identity_tenant_id,
+            )
         client = SecretClient(vault_url=self.vault_url, credential=credential, 
**self.kwargs)
         return client
 
diff --git a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py 
b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
index 7292b4fb7c..51b20084ed 100644
--- a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
+++ b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
@@ -35,7 +35,7 @@ class TestAzureKeyVaultBackend:
         conn = AzureKeyVaultBackend().get_connection("fake_conn")
         assert conn.host == "host"
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
         mock_cred = mock.Mock()
@@ -153,7 +153,7 @@ class TestAzureKeyVaultBackend:
         assert backend.get_config("test_mysql") is None
         mock_get_secret.assert_not_called()
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_client_authenticate_with_default_azure_credential(
@@ -169,7 +169,7 @@ class TestAzureKeyVaultBackend:
         assert not mock_client_secret_credential.called
         mock_defaul_azure_credential.assert_called_once()
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def 
test_client_authenticate_with_default_azure_credential_and_customized_configuration(
@@ -179,17 +179,15 @@ class TestAzureKeyVaultBackend:
             vault_url="https://example-akv-resource-name.vault.azure.net/";,
             managed_identity_client_id="managed_identity_client_id",
             workload_identity_tenant_id="workload_identity_tenant_id",
-            additionally_allowed_tenants=["workload_identity_tenant_id"],
         )
         backend.client
         assert not mock_client_secret_credential.called
         mock_defaul_azure_credential.assert_called_once_with(
             managed_identity_client_id="managed_identity_client_id",
             workload_identity_tenant_id="workload_identity_tenant_id",
-            additionally_allowed_tenants=["workload_identity_tenant_id"],
         )
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_client_authenticate_with_client_secret_credential(

Reply via email to