This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 03529d524f Add DefaultAzureCredential support to cosmos (#33436)
03529d524f is described below
commit 03529d524fbebb4ff2c886a085966230314022f3
Author: Wei Lee <[email protected]>
AuthorDate: Sat Aug 26 15:00:10 2023 +0800
Add DefaultAzureCredential support to cosmos (#33436)
* feat(providers/microsoft): add DefaultAzureCredential support to cosmos
* feat: use CosmosDBManagementClient to authenticate cosmos client through
DefaultAzureCredential
---
airflow/providers/microsoft/azure/hooks/cosmos.py | 35 ++++++++++++++++++++--
airflow/providers/microsoft/azure/provider.yaml | 1 +
.../index.rst | 1 +
generated/provider_dependencies.json | 1 +
.../microsoft/azure/hooks/test_azure_cosmos.py | 3 +-
5 files changed, 37 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py
b/airflow/providers/microsoft/azure/hooks/cosmos.py
index 4b42217e68..a23bd44035 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -27,11 +27,14 @@ from __future__ import annotations
import uuid
from typing import Any
+from urllib.parse import urlparse
from azure.cosmos.cosmos_client import CosmosClient
from azure.cosmos.exceptions import CosmosHttpResponseError
+from azure.identity import DefaultAzureCredential
+from azure.mgmt.cosmosdb import CosmosDBManagementClient
-from airflow.exceptions import AirflowBadRequest
+from airflow.exceptions import AirflowBadRequest, AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field
@@ -67,6 +70,14 @@ class AzureCosmosDBHook(BaseHook):
"collection_name": StringField(
lazy_gettext("Cosmos Collection Name (optional)"),
widget=BS3TextFieldWidget()
),
+ "subscription_id": StringField(
+ lazy_gettext("Subscription ID (optional)"),
+ widget=BS3TextFieldWidget(),
+ ),
+ "resource_group_name": StringField(
+ lazy_gettext("Resource Group Name (optional)"),
+ widget=BS3TextFieldWidget(),
+ ),
}
@staticmethod
@@ -80,9 +91,11 @@ class AzureCosmosDBHook(BaseHook):
},
"placeholders": {
"login": "endpoint uri",
- "password": "master key",
+ "password": "master key (not needed for Azure AD
authentication)",
"database_name": "database name",
"collection_name": "collection name",
+ "subscription_id": "Subscription ID (required for Azure AD
authentication)",
+ "resource_group_name": "Resource Group Name (required for
Azure AD authentication)",
},
}
@@ -108,7 +121,23 @@ class AzureCosmosDBHook(BaseHook):
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
endpoint_uri = conn.login
- master_key = conn.password
+ resource_group_name = self._get_field(extras,
"resource_group_name")
+
+ if conn.password:
+ master_key = conn.password
+ elif resource_group_name:
+ management_client = CosmosDBManagementClient(
+ credential=DefaultAzureCredential(),
+ subscription_id=self._get_field(extras, "subscription_id"),
+ )
+
+ database_account = urlparse(conn.login).netloc.split(".")[0]
+ database_account_keys =
management_client.database_accounts.list_keys(
+ resource_group_name, database_account
+ )
+ master_key = database_account_keys.primary_master_key
+ else:
+ raise AirflowException("Either password or resource_group_name
is required")
self.default_database_name = self._get_field(extras,
"database_name")
self.default_collection_name = self._get_field(extras,
"collection_name")
diff --git a/airflow/providers/microsoft/azure/provider.yaml
b/airflow/providers/microsoft/azure/provider.yaml
index 1edb4974f7..e3ef3db6f8 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -67,6 +67,7 @@ dependencies:
- apache-airflow>=2.4.0
- azure-batch>=8.0.0
- azure-cosmos>=4.0.0
+ - azure-mgmt-cosmosdb
- azure-datalake-store>=0.0.45
- azure-identity>=1.3.1
- azure-keyvault-secrets>=4.1.0
diff --git a/docs/apache-airflow-providers-microsoft-azure/index.rst
b/docs/apache-airflow-providers-microsoft-azure/index.rst
index 5112909f4b..5f7b40ad8c 100644
--- a/docs/apache-airflow-providers-microsoft-azure/index.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/index.rst
@@ -107,6 +107,7 @@ PIP package Version required
``apache-airflow`` ``>=2.4.0``
``azure-batch`` ``>=8.0.0``
``azure-cosmos`` ``>=4.0.0``
+``azure-mgmt-cosmosdb``
``azure-datalake-store`` ``>=0.0.45``
``azure-identity`` ``>=1.3.1``
``azure-keyvault-secrets`` ``>=4.1.0``
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 5191600a98..b0dcab34c1 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -560,6 +560,7 @@
"azure-keyvault-secrets>=4.1.0",
"azure-kusto-data>=4.1.0",
"azure-mgmt-containerinstance>=1.5.0,<2.0",
+ "azure-mgmt-cosmosdb",
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
"azure-mgmt-resource>=2.2.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index f63b8e8dbd..aab2fd3ffa 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -32,7 +32,6 @@ from tests.test_utils.providers import
get_provider_min_airflow_version
class TestAzureCosmosDbHook:
-
# Set up an environment to test with
@pytest.fixture(autouse=True)
def setup_test_cases(self, create_mock_connection):
@@ -262,6 +261,8 @@ class TestAzureCosmosDbHook:
"password",
"database_name",
"collection_name",
+ "subscription_id",
+ "resource_group_name",
]
if
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >=
(2, 5):
raise Exception(