This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ebcb16201a  make DefaultAzureCredential configurable in 
AzureKeyVaultBackend (#35052)
ebcb16201a is described below

commit ebcb16201af08f9815124f27e2fba841c2b9cd9f
Author: Wei Lee <[email protected]>
AuthorDate: Tue Oct 31 02:01:43 2023 +0900

     make DefaultAzureCredential configurable in AzureKeyVaultBackend (#35052)
    
    * feat(azure): make DefaultAzureCredential configurable in 
AzureKeyVaultBackend
    
    * test(providers/azure): extract common module string as a variable
    
    * test(providers/azure): add test case 
test_client_authenticate_with_default_azure_credential_and_customized_configuration
    
    * docs(providers/microsoft): update document for azure secret backend kwargs
---
 .../providers/microsoft/azure/secrets/key_vault.py | 28 ++++++++--
 .../secrets-backends/azure-key-vault.rst           | 10 ++++
 .../azure/secrets/test_azure_key_vault.py          | 61 ++++++++++++++++------
 3 files changed, 79 insertions(+), 20 deletions(-)

diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py 
b/airflow/providers/microsoft/azure/secrets/key_vault.py
index 3c6b909570..32ec56a837 100644
--- a/airflow/providers/microsoft/azure/secrets/key_vault.py
+++ b/airflow/providers/microsoft/azure/secrets/key_vault.py
@@ -14,6 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""
+This module contains Azure Key Vault Backend.
+
+.. spelling:word-list::
+
+    Entra
+"""
+
 from __future__ import annotations
 
 import logging
@@ -76,8 +84,11 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
         If not given, it falls back to ``DefaultAzureCredential``
     :param client_id: The client id of an Azure Key Vault to use.
         If not given, it falls back to ``DefaultAzureCredential``
-    :param client_secret: The client secret of an Azure Key Vault to use.
-        If not given, it falls back to ``DefaultAzureCredential``
+    :param managed_identity_client_id: The client ID of a user-assigned managed identity.
+        If provided with ``workload_identity_tenant_id``, both are passed to ``DefaultAzureCredential``.
+    :param workload_identity_tenant_id: ID of the application's Microsoft Entra tenant.
+        Also called its "directory" ID.
+        If provided with ``managed_identity_client_id``, both are passed to ``DefaultAzureCredential``.
     """
 
     def __init__(
@@ -91,6 +102,8 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
         tenant_id: str = "",
         client_id: str = "",
         client_secret: str = "",
+        managed_identity_client_id: str = "",
+        workload_identity_tenant_id: str = "",
         **kwargs,
     ) -> None:
         super().__init__()
@@ -118,6 +131,8 @@ class AzureKeyVaultBackend(BaseSecretsBackend, 
LoggingMixin):
         self.tenant_id = tenant_id
         self.client_id = client_id
         self.client_secret = client_secret
+        self.managed_identity_client_id = managed_identity_client_id
+        self.workload_identity_tenant_id = workload_identity_tenant_id
         self.kwargs = kwargs
 
     @cached_property
@@ -127,7 +142,14 @@ class AzureKeyVaultBackend(BaseSecretsBackend, 
LoggingMixin):
         if all([self.tenant_id, self.client_id, self.client_secret]):
             credential = ClientSecretCredential(self.tenant_id, 
self.client_id, self.client_secret)
         else:
-            credential = DefaultAzureCredential()
+            if self.managed_identity_client_id and 
self.workload_identity_tenant_id:
+                credential = DefaultAzureCredential(
+                    managed_identity_client_id=self.managed_identity_client_id,
+                    
workload_identity_tenant_id=self.workload_identity_tenant_id,
+                    
additionally_allowed_tenants=[self.workload_identity_tenant_id],
+                )
+            else:
+                credential = DefaultAzureCredential()
         client = SecretClient(vault_url=self.vault_url, credential=credential, 
**self.kwargs)
         return client
 
diff --git 
a/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst
 
b/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst
index b7430d3269..942d6d806f 100644
--- 
a/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst
+++ 
b/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst
@@ -67,6 +67,16 @@ Storing and Retrieving Variables
 If you have set ``variables_prefix`` as ``airflow-variables``, then for an 
Variable key of ``hello``,
 you would want to store your Variable at ``airflow-variables-hello``.
 
+
+Authentication
+""""""""""""""
+There are three ways to authenticate the Azure Key Vault backend:
+
+1. Set all of ``tenant_id``, ``client_id`` and ``client_secret`` (using `ClientSecretCredential <https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python>`_)
+2. Set both ``managed_identity_client_id`` and ``workload_identity_tenant_id`` (passed as arguments to `DefaultAzureCredential <https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python>`_)
+3. Set none of the above to fall back to `DefaultAzureCredential <https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python>`_ with its default configuration
+
+
 Reference
 """""""""
 
diff --git a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py 
b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
index 6b5c8d5dd2..7292b4fb7c 100644
--- a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
+++ b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
@@ -25,16 +25,18 @@ from azure.core.exceptions import ResourceNotFoundError
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.microsoft.azure.secrets.key_vault import 
AzureKeyVaultBackend
 
+KEY_VAULT_MODULE = "airflow.providers.microsoft.azure.secrets.key_vault"
+
 
 class TestAzureKeyVaultBackend:
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_value")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.get_conn_value")
     def test_get_connection(self, mock_get_value):
         mock_get_value.return_value = "scheme://user:pass@host:100"
         conn = AzureKeyVaultBackend().get_connection("fake_conn")
         assert conn.host == "host"
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.SecretClient")
+    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
         mock_cred = mock.Mock()
         mock_sec_client = mock.Mock()
@@ -53,7 +55,7 @@ class TestAzureKeyVaultBackend:
         )
         assert returned_uri == "postgresql://airflow:airflow@host:5432/airflow"
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
     def test_get_conn_uri_non_existent_key(self, mock_client):
         """
         Test that if the key with connection ID is not present,
@@ -67,7 +69,7 @@ class TestAzureKeyVaultBackend:
             assert backend.get_conn_uri(conn_id=conn_id) is None
         assert backend.get_connection(conn_id=conn_id) is None
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
     def test_get_variable(self, mock_client):
         mock_client.get_secret.return_value = mock.Mock(value="world")
         backend = AzureKeyVaultBackend()
@@ -75,7 +77,7 @@ class TestAzureKeyVaultBackend:
         
mock_client.get_secret.assert_called_with(name="airflow-variables-hello")
         assert "world" == returned_uri
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
     def test_get_variable_non_existent_key(self, mock_client):
         """
         Test that if Variable key is not present,
@@ -85,7 +87,7 @@ class TestAzureKeyVaultBackend:
         backend = AzureKeyVaultBackend()
         assert backend.get_variable("test_mysql") is None
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
     def test_get_secret_value_not_found(self, mock_client):
         """
         Test that if a non-existent secret returns None
@@ -96,7 +98,7 @@ class TestAzureKeyVaultBackend:
             backend._get_secret(path_prefix=backend.connections_prefix, 
secret_id="test_non_existent") is None
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
     def test_get_secret_value(self, mock_client):
         """
         Test that get_secret returns the secret value
@@ -107,7 +109,7 @@ class TestAzureKeyVaultBackend:
         
mock_client.get_secret.assert_called_with(name="af-secrets-test-mysql-password")
         assert secret_val == "super-secret"
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend._get_secret")
     def test_connection_prefix_none_value(self, mock_get_secret):
         """
         Test that if Connections prefix is None,
@@ -125,7 +127,7 @@ class TestAzureKeyVaultBackend:
             assert backend.get_conn_uri("test_mysql") is None
             mock_get_secret.assert_not_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend._get_secret")
     def test_variable_prefix_none_value(self, mock_get_secret):
         """
         Test that if Variables prefix is None,
@@ -138,7 +140,7 @@ class TestAzureKeyVaultBackend:
         assert backend.get_variable("hello") is None
         mock_get_secret.assert_not_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret")
+    @mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend._get_secret")
     def test_config_prefix_none_value(self, mock_get_secret):
         """
         Test that if Config prefix is None,
@@ -151,20 +153,45 @@ class TestAzureKeyVaultBackend:
         assert backend.get_config("test_mysql") is None
         mock_get_secret.assert_not_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.ClientSecretCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.SecretClient")
+    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_client_authenticate_with_default_azure_credential(
         self, mock_client, mock_client_secret_credential, 
mock_defaul_azure_credential
     ):
+        """
+        Test that AzureKeyVaultBackend authenticates with DefaultAzureCredential
+        when tenant_id, client_id and client_secret are not provided.
+
+        """
         backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
         backend.client
         assert not mock_client_secret_credential.called
         mock_defaul_azure_credential.assert_called_once()
 
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.ClientSecretCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.SecretClient")
+    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
+    def 
test_client_authenticate_with_default_azure_credential_and_customized_configuration(
+        self, mock_client, mock_client_secret_credential, 
mock_defaul_azure_credential
+    ):
+        backend = AzureKeyVaultBackend(
+            vault_url="https://example-akv-resource-name.vault.azure.net/",
+            managed_identity_client_id="managed_identity_client_id",
+            workload_identity_tenant_id="workload_identity_tenant_id",
+            additionally_allowed_tenants=["workload_identity_tenant_id"],
+        )
+        backend.client
+        assert not mock_client_secret_credential.called
+        mock_defaul_azure_credential.assert_called_once_with(
+            managed_identity_client_id="managed_identity_client_id",
+            workload_identity_tenant_id="workload_identity_tenant_id",
+            additionally_allowed_tenants=["workload_identity_tenant_id"],
+        )
+
+    @mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_client_authenticate_with_client_secret_credential(
         self, mock_client, mock_client_secret_credential, 
mock_defaul_azure_credential
     ):

Reply via email to