This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new faa50cbe2f feat(providers/microsoft): add DefaultAzureCredential
support to AzureContainerInstanceHook (#33467)
faa50cbe2f is described below
commit faa50cbe2f6dbf816e599bbbb933ac4976a55778
Author: Wei Lee <[email protected]>
AuthorDate: Thu Aug 24 20:25:59 2023 +0800
feat(providers/microsoft): add DefaultAzureCredential support to
AzureContainerInstanceHook (#33467)
* feat(provider/microsoft): add DefaultAzureCredential compatibility to
azure-python-sdk through AzureIdentityCredentialAdapter wrapper
https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
* feat(providers/microsoft): add DefaultAzureCredential support to
AzureContainerInstanceHook
* fix(providers/microsoft): replace AzureIdentityCredentialAdapter with
DefaultAzureCredential due to backward compatibility
---
.../providers/microsoft/azure/hooks/base_azure.py | 16 +++++--
airflow/providers/microsoft/azure/utils.py | 50 ++++++++++++++++++++++
2 files changed, 62 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py
b/airflow/providers/microsoft/azure/hooks/base_azure.py
index 214cbd5f20..54130b3b2e 100644
--- a/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -24,6 +24,7 @@ from azure.common.credentials import
ServicePrincipalCredentials
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
+from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter
class AzureBaseHook(BaseHook):
@@ -124,10 +125,17 @@ class AzureBaseHook(BaseHook):
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client,
config_dict=key_json)
- self.log.info("Getting connection using specific credentials and
subscription_id.")
- return self.sdk_client(
- credentials=ServicePrincipalCredentials(
+ credentials: ServicePrincipalCredentials |
AzureIdentityCredentialAdapter
+ if all([conn.login, conn.password, tenant]):
+ self.log.info("Getting connection using specific credentials and
subscription_id.")
+ credentials = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
- ),
+ )
+ else:
+ self.log.info("Using DefaultAzureCredential as credential")
+ credentials = AzureIdentityCredentialAdapter()
+
+ return self.sdk_client(
+ credentials=credentials,
subscription_id=subscription_id,
)
diff --git a/airflow/providers/microsoft/azure/utils.py
b/airflow/providers/microsoft/azure/utils.py
index 0a8edcf7c7..5afc2a48ca 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -19,6 +19,12 @@ from __future__ import annotations
import warnings
+from azure.core.pipeline import PipelineContext, PipelineRequest
+from azure.core.pipeline.policies import BearerTokenCredentialPolicy
+from azure.core.pipeline.transport import HttpRequest
+from azure.identity import DefaultAzureCredential
+from msrest.authentication import BasicTokenAuthentication
+
def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we
check for prefixed name."""
@@ -43,3 +49,47 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict,
field_name: str):
if ret == "":
return None
return ret
+
+
+class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
+ """Adapt azure-identity credentials for backward compatibility.
+
+ Adapt credentials from azure-identity to be compatible with SDKs
+ that need msrestazure or azure.common.credentials
+
+ Check
https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
+ """
+
+ def __init__(self, credential=None,
resource_id="https://management.azure.com/.default", **kwargs):
+ """Adapt azure-identity credentials for backward compatibility.
+
+ :param credential: Any azure-identity credential
(DefaultAzureCredential by default)
+ :param str resource_id: The scope to use to get the token (default ARM)
+ """
+ super().__init__(None)
+ if credential is None:
+ credential = DefaultAzureCredential()
+ self._policy = BearerTokenCredentialPolicy(credential, resource_id,
**kwargs)
+
+ def _make_request(self):
+ return PipelineRequest(
+ HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"),
PipelineContext(None)
+ )
+
+ def set_token(self):
+ """Ask the azure-core BearerTokenCredentialPolicy policy to get a
token.
+
+ Using the policy gives us the caching system of azure-core for free.
+ We could make this code simpler by using private methods, but by
definition
+ they are not guaranteed to stay around forever, so we mock a fake call to
the policy
+ to extract the token, using 100% public API.
+ """
+ request = self._make_request()
+ self._policy.on_request(request)
+ # Read Authorization, and get the second part after Bearer
+ token = request.http_request.headers["Authorization"].split(" ", 1)[1]
+ self.token = {"access_token": token}
+
+ def signed_session(self, azure_session=None):
+ self.set_token()
+ return super().signed_session(azure_session)