This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2980eb137d fix(providers/microsoft): setting use_async=True for
get_async_default_azure_credential (#35432)
2980eb137d is described below
commit 2980eb137d518d071aaec4f849a6dbbe5e1724cb
Author: Wei Lee <[email protected]>
AuthorDate: Sun Nov 5 23:07:13 2023 +0800
fix(providers/microsoft): setting use_async=True for
get_async_default_azure_credential (#35432)
---
airflow/providers/microsoft/azure/utils.py | 6 +++---
tests/providers/microsoft/azure/test_utils.py | 16 ++++++++++++++++
2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py
index 1b738ed957..a7a7e38966 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -59,8 +59,8 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
def _get_default_azure_credential(
*,
- managed_identity_client_id: str | None,
- workload_identity_tenant_id: str | None,
+ managed_identity_client_id: str | None = None,
+ workload_identity_tenant_id: str | None = None,
use_async: bool = False,
) -> DefaultAzureCredential | AsyncDefaultAzureCredential:
"""Get DefaultAzureCredential based on provided arguments.
@@ -88,7 +88,7 @@ get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(
get_async_default_azure_credential: partial[AsyncDefaultAzureCredential] = partial(
_get_default_azure_credential, # type: ignore[arg-type]
- use_async=False,
+ use_async=True,
)
diff --git a/tests/providers/microsoft/azure/test_utils.py b/tests/providers/microsoft/azure/test_utils.py
index 5a081441ca..f04acaab13 100644
--- a/tests/providers/microsoft/azure/test_utils.py
+++ b/tests/providers/microsoft/azure/test_utils.py
@@ -25,7 +25,10 @@ import pytest
from airflow.providers.microsoft.azure.utils import (
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
+ get_async_default_azure_credential,
get_field,
+ # _get_default_azure_credential
+ get_sync_default_azure_credential,
)
MODULE = "airflow.providers.microsoft.azure.utils"
@@ -77,6 +80,19 @@ def test_add_managed_identity_connection_widgets():
assert "workload_identity_tenant_id" in widgets
+@mock.patch(f"{MODULE}.DefaultAzureCredential")
+def test_get_sync_default_azure_credential(mock_default_azure_credential):
+ get_sync_default_azure_credential()
+
+ assert mock_default_azure_credential.called
+
+
+@mock.patch(f"{MODULE}.AsyncDefaultAzureCredential")
+def test_get_async_default_azure_credential(mock_default_azure_credential):
+ get_async_default_azure_credential()
+ assert mock_default_azure_credential.called
+
+
class TestAzureIdentityCredentialAdapter:
@mock.patch(f"{MODULE}.PipelineRequest")
@mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")