This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 768e1169b1 fix cosmos hook static checks by making providing 
partition_key mandatory (#38199)
768e1169b1 is described below

commit 768e1169b1946fe536c02ee968a95594d43ebba2
Author: Hussein Awala <[email protected]>
AuthorDate: Wed Apr 10 10:47:59 2024 +0200

    fix cosmos hook static checks by making providing partition_key mandatory 
(#38199)
---
 airflow/providers/microsoft/azure/CHANGELOG.rst    |  5 +++
 airflow/providers/microsoft/azure/hooks/cosmos.py  | 37 ++++++++++++++++------
 airflow/providers/microsoft/azure/provider.yaml    |  3 +-
 .../connections/azure_cosmos.rst                   |  1 +
 generated/provider_dependencies.json               |  2 +-
 .../providers/microsoft/azure/hooks/test_cosmos.py | 11 +++++--
 .../microsoft/azure/operators/test_cosmos.py       |  2 ++
 7 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/airflow/providers/microsoft/azure/CHANGELOG.rst 
b/airflow/providers/microsoft/azure/CHANGELOG.rst
index 4030c3c7ed..9fe19f899e 100644
--- a/airflow/providers/microsoft/azure/CHANGELOG.rst
+++ b/airflow/providers/microsoft/azure/CHANGELOG.rst
@@ -36,6 +36,11 @@ Breaking changes
    * ``azure_synapse_pipeline`` connection type has been changed to 
``azure_synapse``.
    * The usage of ``default_conn_name=azure_synapse_connection`` is deprecated 
and will be removed in future. The new default connection name for 
``AzureSynapsePipelineHook`` is: ``default_conn_name=azure_synapse_default``.
 
+Significant changes
+~~~~~~~~~~~~~~~~~~~
+.. warning::
+   * We bumped the minimum version of azure-cosmos to 4.6.0. A partition key 
(passed explicitly or set via the ``partition_key`` connection extra) is now 
required to create a collection and to get, query or delete documents.
+
 9.0.1
 .....
 
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py 
b/airflow/providers/microsoft/azure/hooks/cosmos.py
index fe4aa88b5f..9d8b360a75 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -27,9 +27,10 @@ the default database and collection to use (see connection 
`azure_cosmos_default
 from __future__ import annotations
 
 import uuid
-from typing import Any
+from typing import TYPE_CHECKING, Any, List, Union
 from urllib.parse import urlparse
 
+from azure.cosmos import PartitionKey
 from azure.cosmos.cosmos_client import CosmosClient
 from azure.cosmos.exceptions import CosmosHttpResponseError
 from azure.mgmt.cosmosdb import CosmosDBManagementClient
@@ -42,6 +43,9 @@ from airflow.providers.microsoft.azure.utils import (
     get_sync_default_azure_credential,
 )
 
+if TYPE_CHECKING:
+    PartitionKeyType = Union[str, List[str]]
+
 
 class AzureCosmosDBHook(BaseHook):
     """
@@ -111,6 +115,7 @@ class AzureCosmosDBHook(BaseHook):
 
         self.default_database_name = None
         self.default_collection_name = None
+        self.default_partition_key = None
 
     def _get_field(self, extras, name):
         return get_field(
@@ -153,6 +158,7 @@ class AzureCosmosDBHook(BaseHook):
 
             self.default_database_name = self._get_field(extras, 
"database_name")
             self.default_collection_name = self._get_field(extras, 
"collection_name")
+            self.default_partition_key = self._get_field(extras, 
"partition_key")
 
             # Initialize the Python Azure Cosmos DB client
             self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
@@ -180,6 +186,18 @@ class AzureCosmosDBHook(BaseHook):
 
         return coll_name
 
+    def __get_partition_key(self, partition_key: PartitionKeyType | None = 
None) -> PartitionKeyType:
+        self.get_conn()
+        if partition_key is None:
+            part_key = self.default_partition_key
+        else:
+            part_key = partition_key
+
+        if part_key is None:
+            raise AirflowBadRequest("Partition key must be specified")
+
+        return part_key
+
     def does_collection_exist(self, collection_name: str, database_name: str) 
-> bool:
         """Check if a collection exists in CosmosDB."""
         if collection_name is None:
@@ -204,7 +222,7 @@ class AzureCosmosDBHook(BaseHook):
         self,
         collection_name: str,
         database_name: str | None = None,
-        partition_key: str | None = None,
+        partition_key: PartitionKeyType | None = None,
     ) -> None:
         """Create a new collection in the CosmosDB database."""
         if collection_name is None:
@@ -226,7 +244,8 @@ class AzureCosmosDBHook(BaseHook):
         # Only create if we did not find it already existing
         if not existing_container:
             
self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
-                collection_name, partition_key=partition_key
+                collection_name,
+                
partition_key=PartitionKey(path=self.__get_partition_key(partition_key)),
             )
 
     def does_database_exist(self, database_name: str) -> bool:
@@ -328,7 +347,7 @@ class AzureCosmosDBHook(BaseHook):
         document_id: str,
         database_name: str | None = None,
         collection_name: str | None = None,
-        partition_key: str | None = None,
+        partition_key: PartitionKeyType | None = None,
     ) -> None:
         """Delete an existing document out of a collection in the CosmosDB 
database."""
         if document_id is None:
@@ -337,7 +356,7 @@ class AzureCosmosDBHook(BaseHook):
             self.get_conn()
             .get_database_client(self.__get_database_name(database_name))
             .get_container_client(self.__get_collection_name(collection_name))
-            .delete_item(document_id, partition_key=partition_key)
+            .delete_item(document_id, 
partition_key=self.__get_partition_key(partition_key))
         )
 
     def get_document(
@@ -345,7 +364,7 @@ class AzureCosmosDBHook(BaseHook):
         document_id: str,
         database_name: str | None = None,
         collection_name: str | None = None,
-        partition_key: str | None = None,
+        partition_key: PartitionKeyType | None = None,
     ):
         """Get a document from an existing collection in the CosmosDB 
database."""
         if document_id is None:
@@ -356,7 +375,7 @@ class AzureCosmosDBHook(BaseHook):
                 self.get_conn()
                 .get_database_client(self.__get_database_name(database_name))
                 
.get_container_client(self.__get_collection_name(collection_name))
-                .read_item(document_id, partition_key=partition_key)
+                .read_item(document_id, 
partition_key=self.__get_partition_key(partition_key))
             )
         except CosmosHttpResponseError:
             return None
@@ -366,7 +385,7 @@ class AzureCosmosDBHook(BaseHook):
         sql_string: str,
         database_name: str | None = None,
         collection_name: str | None = None,
-        partition_key: str | None = None,
+        partition_key: PartitionKeyType | None = None,
     ) -> list | None:
         """Get a list of documents from an existing collection in the CosmosDB 
database via SQL query."""
         if sql_string is None:
@@ -377,7 +396,7 @@ class AzureCosmosDBHook(BaseHook):
                 self.get_conn()
                 .get_database_client(self.__get_database_name(database_name))
                 
.get_container_client(self.__get_collection_name(collection_name))
-                .query_items(sql_string, partition_key=partition_key)
+                .query_items(sql_string, 
partition_key=self.__get_partition_key(partition_key))
             )
             return list(result_iterable)
         except CosmosHttpResponseError:
diff --git a/airflow/providers/microsoft/azure/provider.yaml 
b/airflow/providers/microsoft/azure/provider.yaml
index fe3e45a92f..2ddc479b63 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -79,8 +79,7 @@ dependencies:
   - apache-airflow>=2.6.0
   - adlfs>=2023.10.0
   - azure-batch>=8.0.0
-  # azure-cosmos 4.6.0 fail on mypy, limit version till we fix the issue
-  - azure-cosmos>=4.0.0,<4.6.0
+  - azure-cosmos>=4.6.0
   - azure-mgmt-cosmosdb
   - azure-datalake-store>=0.0.45
   - azure-identity>=1.3.1
diff --git 
a/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst 
b/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
index e7933fc897..2d7754f2cc 100644
--- a/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
@@ -56,6 +56,7 @@ Extra (optional)
 
     * ``database_name``: Specify the azure cosmos database to use.
     * ``collection_name``: Specify the azure cosmos collection to use.
+    * ``partition_key``: Specify the default partition key to use when none is passed to the hook methods.
     * ``subscription_id``: Specify the ID of the subscription used for the 
initial connection. Required for falling back to DefaultAzureCredential_
     * ``resource_group_name``: Specify the  Azure Resource Group Name under 
which the desired azure cosmos resides. Required for falling back to 
DefaultAzureCredential_
     * ``managed_identity_client_id``:  The client ID of a user-assigned 
managed identity. If provided with `workload_identity_tenant_id`, they'll pass 
to ``DefaultAzureCredential``.
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 284ff97cc7..9315766f81 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -682,7 +682,7 @@
       "adlfs>=2023.10.0",
       "apache-airflow>=2.6.0",
       "azure-batch>=8.0.0",
-      "azure-cosmos>=4.0.0,<4.6.0",
+      "azure-cosmos>=4.6.0",
       "azure-datalake-store>=0.0.45",
       "azure-identity>=1.3.1",
       "azure-keyvault-secrets>=4.1.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_cosmos.py 
b/tests/providers/microsoft/azure/hooks/test_cosmos.py
index b63b46410e..bc6b4f4277 100644
--- a/tests/providers/microsoft/azure/hooks/test_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_cosmos.py
@@ -23,6 +23,7 @@ from unittest import mock
 from unittest.mock import PropertyMock
 
 import pytest
+from azure.cosmos import PartitionKey
 from azure.cosmos.cosmos_client import CosmosClient
 
 from airflow.exceptions import AirflowException
@@ -43,6 +44,7 @@ class TestAzureCosmosDbHook:
         self.test_collection_name = "test_collection_name"
         self.test_database_default = "test_database_default"
         self.test_collection_default = "test_collection_default"
+        self.test_partition_key = "/test_partition_key"
         create_mock_connection(
             Connection(
                 conn_id="azure_cosmos_test_key_id",
@@ -52,6 +54,7 @@ class TestAzureCosmosDbHook:
                 extra={
                     "database_name": self.test_database_default,
                     "collection_name": self.test_collection_default,
+                    "partition_key": self.test_partition_key,
                 },
             )
         )
@@ -115,11 +118,11 @@ class TestAzureCosmosDbHook:
     @mock.patch(f"{MODULE}.CosmosClient")
     def test_create_container(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
-        hook.create_collection(self.test_collection_name, 
self.test_database_name)
+        hook.create_collection(self.test_collection_name, 
self.test_database_name, partition_key="/id")
         expected_calls = [
             mock.call()
             .get_database_client("test_database_name")
-            .create_container("test_collection_name", partition_key=None)
+            .create_container("test_collection_name", 
partition_key=PartitionKey(path="/id"))
         ]
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
@@ -131,7 +134,9 @@ class TestAzureCosmosDbHook:
         expected_calls = [
             mock.call()
             .get_database_client("test_database_name")
-            .create_container("test_collection_name", partition_key=None)
+            .create_container(
+                "test_collection_name", 
partition_key=PartitionKey(path=self.test_partition_key)
+            )
         ]
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
diff --git a/tests/providers/microsoft/azure/operators/test_cosmos.py 
b/tests/providers/microsoft/azure/operators/test_cosmos.py
index 231a56c2b1..d3627482bb 100644
--- a/tests/providers/microsoft/azure/operators/test_cosmos.py
+++ b/tests/providers/microsoft/azure/operators/test_cosmos.py
@@ -35,6 +35,7 @@ class TestAzureCosmosDbHook:
         self.test_master_key = "magic_test_key"
         self.test_database_name = "test_database_name"
         self.test_collection_name = "test_collection_name"
+        self.test_partition_key = "test_partition_key"
         create_mock_connection(
             Connection(
                 conn_id="azure_cosmos_test_key_id",
@@ -44,6 +45,7 @@ class TestAzureCosmosDbHook:
                 extra={
                     "database_name": self.test_database_name,
                     "collection_name": self.test_collection_name,
+                    "partition_key": self.test_partition_key,
                 },
             )
         )

Reply via email to