This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 03529d524f Add DefaultAzureCredential support to cosmos (#33436)
03529d524f is described below

commit 03529d524fbebb4ff2c886a085966230314022f3
Author: Wei Lee <[email protected]>
AuthorDate: Sat Aug 26 15:00:10 2023 +0800

    Add DefaultAzureCredential support to cosmos (#33436)
    
    * feat(providers/microsoft): add DefaultAzureCredential support to cosmos
    
    * feat: use CosmosDBManagementClient to authenticate cosmos client through DefaultAzureCredential
---
 airflow/providers/microsoft/azure/hooks/cosmos.py  | 35 ++++++++++++++++++++--
 airflow/providers/microsoft/azure/provider.yaml    |  1 +
 .../index.rst                                      |  1 +
 generated/provider_dependencies.json               |  1 +
 .../microsoft/azure/hooks/test_azure_cosmos.py     |  3 +-
 5 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py
index 4b42217e68..a23bd44035 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -27,11 +27,14 @@ from __future__ import annotations
 
 import uuid
 from typing import Any
+from urllib.parse import urlparse
 
 from azure.cosmos.cosmos_client import CosmosClient
 from azure.cosmos.exceptions import CosmosHttpResponseError
+from azure.identity import DefaultAzureCredential
+from azure.mgmt.cosmosdb import CosmosDBManagementClient
 
-from airflow.exceptions import AirflowBadRequest
+from airflow.exceptions import AirflowBadRequest, AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import get_field
 
@@ -67,6 +70,14 @@ class AzureCosmosDBHook(BaseHook):
             "collection_name": StringField(
                 lazy_gettext("Cosmos Collection Name (optional)"), 
widget=BS3TextFieldWidget()
             ),
+            "subscription_id": StringField(
+                lazy_gettext("Subscription ID (optional)"),
+                widget=BS3TextFieldWidget(),
+            ),
+            "resource_group_name": StringField(
+                lazy_gettext("Resource Group Name (optional)"),
+                widget=BS3TextFieldWidget(),
+            ),
         }
 
     @staticmethod
@@ -80,9 +91,11 @@ class AzureCosmosDBHook(BaseHook):
             },
             "placeholders": {
                 "login": "endpoint uri",
-                "password": "master key",
+                "password": "master key (not needed for Azure AD 
authentication)",
                 "database_name": "database name",
                 "collection_name": "collection name",
+                "subscription_id": "Subscription ID (required for Azure AD 
authentication)",
+                "resource_group_name": "Resource Group Name (required for 
Azure AD authentication)",
             },
         }
 
@@ -108,7 +121,23 @@ class AzureCosmosDBHook(BaseHook):
             conn = self.get_connection(self.conn_id)
             extras = conn.extra_dejson
             endpoint_uri = conn.login
-            master_key = conn.password
+            resource_group_name = self._get_field(extras, 
"resource_group_name")
+
+            if conn.password:
+                master_key = conn.password
+            elif resource_group_name:
+                management_client = CosmosDBManagementClient(
+                    credential=DefaultAzureCredential(),
+                    subscription_id=self._get_field(extras, "subscription_id"),
+                )
+
+                database_account = urlparse(conn.login).netloc.split(".")[0]
+                database_account_keys = 
management_client.database_accounts.list_keys(
+                    resource_group_name, database_account
+                )
+                master_key = database_account_keys.primary_master_key
+            else:
+                raise AirflowException("Either password or resource_group_name 
is required")
 
             self.default_database_name = self._get_field(extras, 
"database_name")
             self.default_collection_name = self._get_field(extras, 
"collection_name")
diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml
index 1edb4974f7..e3ef3db6f8 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -67,6 +67,7 @@ dependencies:
   - apache-airflow>=2.4.0
   - azure-batch>=8.0.0
   - azure-cosmos>=4.0.0
+  - azure-mgmt-cosmosdb
   - azure-datalake-store>=0.0.45
   - azure-identity>=1.3.1
   - azure-keyvault-secrets>=4.1.0
diff --git a/docs/apache-airflow-providers-microsoft-azure/index.rst b/docs/apache-airflow-providers-microsoft-azure/index.rst
index 5112909f4b..5f7b40ad8c 100644
--- a/docs/apache-airflow-providers-microsoft-azure/index.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/index.rst
@@ -107,6 +107,7 @@ PIP package                       Version required
 ``apache-airflow``                ``>=2.4.0``
 ``azure-batch``                   ``>=8.0.0``
 ``azure-cosmos``                  ``>=4.0.0``
+``azure-mgmt-cosmosdb``
 ``azure-datalake-store``          ``>=0.0.45``
 ``azure-identity``                ``>=1.3.1``
 ``azure-keyvault-secrets``        ``>=4.1.0``
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 5191600a98..b0dcab34c1 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -560,6 +560,7 @@
       "azure-keyvault-secrets>=4.1.0",
       "azure-kusto-data>=4.1.0",
       "azure-mgmt-containerinstance>=1.5.0,<2.0",
+      "azure-mgmt-cosmosdb",
       "azure-mgmt-datafactory>=1.0.0,<2.0",
       "azure-mgmt-datalake-store>=0.5.0",
       "azure-mgmt-resource>=2.2.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index f63b8e8dbd..aab2fd3ffa 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -32,7 +32,6 @@ from tests.test_utils.providers import get_provider_min_airflow_version
 
 
 class TestAzureCosmosDbHook:
-
     # Set up an environment to test with
     @pytest.fixture(autouse=True)
     def setup_test_cases(self, create_mock_connection):
@@ -262,6 +261,8 @@ class TestAzureCosmosDbHook:
             "password",
             "database_name",
             "collection_name",
+            "subscription_id",
+            "resource_group_name",
         ]
         if 
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= 
(2, 5):
             raise Exception(

Reply via email to