This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new faa50cbe2f feat(providers/microsoft): add DefaultAzureCredential support to AzureContainerInstanceHook (#33467)
faa50cbe2f is described below

commit faa50cbe2f6dbf816e599bbbb933ac4976a55778
Author: Wei Lee <[email protected]>
AuthorDate: Thu Aug 24 20:25:59 2023 +0800

    feat(providers/microsoft): add DefaultAzureCredential support to AzureContainerInstanceHook (#33467)
    
    * feat(provider/microsoft): add DefaultAzureCredential compatibility to azure-python-sdk through AzureIdentityCredentialAdapter wrapper
    
https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
    
    * feat(providers/microsoft): add DefaultAzureCredential support to AzureContainerInstanceHook
    
    * fix(providers/microsoft): replace AzureIdentityCredentialAdapter with DefaultAzureCredential due to backward compatibility
---
 .../providers/microsoft/azure/hooks/base_azure.py  | 16 +++++--
 airflow/providers/microsoft/azure/utils.py         | 50 ++++++++++++++++++++++
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py b/airflow/providers/microsoft/azure/hooks/base_azure.py
index 214cbd5f20..54130b3b2e 100644
--- a/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -24,6 +24,7 @@ from azure.common.credentials import ServicePrincipalCredentials
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
+from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter
 
 
 class AzureBaseHook(BaseHook):
@@ -124,10 +125,17 @@ class AzureBaseHook(BaseHook):
             self.log.info("Getting connection using a JSON config.")
             return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
 
-        self.log.info("Getting connection using specific credentials and subscription_id.")
-        return self.sdk_client(
-            credentials=ServicePrincipalCredentials(
+        credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
+        if all([conn.login, conn.password, tenant]):
+            self.log.info("Getting connection using specific credentials and subscription_id.")
+            credentials = ServicePrincipalCredentials(
                 client_id=conn.login, secret=conn.password, tenant=tenant
-            ),
+            )
+        else:
+            self.log.info("Using DefaultAzureCredential as credential")
+            credentials = AzureIdentityCredentialAdapter()
+
+        return self.sdk_client(
+            credentials=credentials,
             subscription_id=subscription_id,
         )
diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py
index 0a8edcf7c7..5afc2a48ca 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -19,6 +19,12 @@ from __future__ import annotations
 
 import warnings
 
+from azure.core.pipeline import PipelineContext, PipelineRequest
+from azure.core.pipeline.policies import BearerTokenCredentialPolicy
+from azure.core.pipeline.transport import HttpRequest
+from azure.identity import DefaultAzureCredential
+from msrest.authentication import BasicTokenAuthentication
+
 
 def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
     """Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
@@ -43,3 +49,47 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
     if ret == "":
         return None
     return ret
+
+
+class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
+    """Adapt azure-identity credentials for backward compatibility.
+
+    Adapt credentials from azure-identity to be compatible with SDK
+    that needs msrestazure or azure.common.credentials
+
+    Check https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
+    """
+
+    def __init__(self, credential=None, resource_id="https://management.azure.com/.default", **kwargs):
+        """Adapt azure-identity credentials for backward compatibility.
+
+        :param credential: Any azure-identity credential (DefaultAzureCredential by default)
+        :param str resource_id: The scope to use to get the token (default ARM)
+        """
+        super().__init__(None)
+        if credential is None:
+            credential = DefaultAzureCredential()
+        self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs)
+
+    def _make_request(self):
+        return PipelineRequest(
+            HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"), PipelineContext(None)
+        )
+
+    def set_token(self):
+        """Ask the azure-core BearerTokenCredentialPolicy policy to get a token.
+
+        Using the policy gives us for free the caching system of azure-core.
+        We could make this code simpler by using private method, but by definition
+        I can't assure they will be there forever, so mocking a fake call to the policy
+        to extract the token, using 100% public API.
+        """
+        request = self._make_request()
+        self._policy.on_request(request)
+        # Read Authorization, and get the second part after Bearer
+        token = request.http_request.headers["Authorization"].split(" ", 1)[1]
+        self.token = {"access_token": token}
+
+    def signed_session(self, azure_session=None):
+        self.set_token()
+        return super().signed_session(azure_session)

Reply via email to