This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 768e1169b1 fix cosmos hook static checks by making providing partition_key mandatory (#38199)
768e1169b1 is described below
commit 768e1169b1946fe536c02ee968a95594d43ebba2
Author: Hussein Awala <[email protected]>
AuthorDate: Wed Apr 10 10:47:59 2024 +0200
fix cosmos hook static checks by making providing partition_key mandatory (#38199)
---
airflow/providers/microsoft/azure/CHANGELOG.rst | 5 +++
airflow/providers/microsoft/azure/hooks/cosmos.py | 37 ++++++++++++++++------
airflow/providers/microsoft/azure/provider.yaml | 3 +-
.../connections/azure_cosmos.rst | 1 +
generated/provider_dependencies.json | 2 +-
.../providers/microsoft/azure/hooks/test_cosmos.py | 11 +++++--
.../microsoft/azure/operators/test_cosmos.py | 2 ++
7 files changed, 46 insertions(+), 15 deletions(-)
diff --git a/airflow/providers/microsoft/azure/CHANGELOG.rst
b/airflow/providers/microsoft/azure/CHANGELOG.rst
index 4030c3c7ed..9fe19f899e 100644
--- a/airflow/providers/microsoft/azure/CHANGELOG.rst
+++ b/airflow/providers/microsoft/azure/CHANGELOG.rst
@@ -36,6 +36,11 @@ Breaking changes
* ``azure_synapse_pipeline`` connection type has been changed to
``azure_synapse``.
* The usage of ``default_conn_name=azure_synapse_connection`` is deprecated
and will be removed in future. The new default connection name for
``AzureSynapsePipelineHook`` is: ``default_conn_name=azure_synapse_default``.
+Significant changes
+~~~~~~~~~~~~~~~~~~~
+.. warning::
+ * We bumped the minimum version of azure-cosmos to 4.6.0, and providing a partition key is now required to create, get or delete a container and to get a document.
+
9.0.1
.....
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py
b/airflow/providers/microsoft/azure/hooks/cosmos.py
index fe4aa88b5f..9d8b360a75 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -27,9 +27,10 @@ the default database and collection to use (see connection
`azure_cosmos_default
from __future__ import annotations
import uuid
-from typing import Any
+from typing import TYPE_CHECKING, Any, List, Union
from urllib.parse import urlparse
+from azure.cosmos import PartitionKey
from azure.cosmos.cosmos_client import CosmosClient
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.mgmt.cosmosdb import CosmosDBManagementClient
@@ -42,6 +43,9 @@ from airflow.providers.microsoft.azure.utils import (
get_sync_default_azure_credential,
)
+if TYPE_CHECKING:
+ PartitionKeyType = Union[str, List[str]]
+
class AzureCosmosDBHook(BaseHook):
"""
@@ -111,6 +115,7 @@ class AzureCosmosDBHook(BaseHook):
self.default_database_name = None
self.default_collection_name = None
+ self.default_partition_key = None
def _get_field(self, extras, name):
return get_field(
@@ -153,6 +158,7 @@ class AzureCosmosDBHook(BaseHook):
self.default_database_name = self._get_field(extras, "database_name")
self.default_collection_name = self._get_field(extras, "collection_name")
+ self.default_partition_key = self._get_field(extras, "partition_key")
# Initialize the Python Azure Cosmos DB client
self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
@@ -180,6 +186,18 @@ class AzureCosmosDBHook(BaseHook):
return coll_name
+ def __get_partition_key(self, partition_key: PartitionKeyType | None = None) -> PartitionKeyType:
+ self.get_conn()
+ if partition_key is None:
+ part_key = self.default_partition_key
+ else:
+ part_key = partition_key
+
+ if part_key is None:
+ raise AirflowBadRequest("Partition key must be specified")
+
+ return part_key
+
def does_collection_exist(self, collection_name: str, database_name: str) -> bool:
"""Check if a collection exists in CosmosDB."""
if collection_name is None:
@@ -204,7 +222,7 @@ class AzureCosmosDBHook(BaseHook):
self,
collection_name: str,
database_name: str | None = None,
- partition_key: str | None = None,
+ partition_key: PartitionKeyType | None = None,
) -> None:
"""Create a new collection in the CosmosDB database."""
if collection_name is None:
@@ -226,7 +244,8 @@ class AzureCosmosDBHook(BaseHook):
# Only create if we did not find it already existing
if not existing_container:
self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
- collection_name, partition_key=partition_key
+ collection_name,
+ partition_key=PartitionKey(path=self.__get_partition_key(partition_key)),
)
def does_database_exist(self, database_name: str) -> bool:
@@ -328,7 +347,7 @@ class AzureCosmosDBHook(BaseHook):
document_id: str,
database_name: str | None = None,
collection_name: str | None = None,
- partition_key: str | None = None,
+ partition_key: PartitionKeyType | None = None,
) -> None:
"""Delete an existing document out of a collection in the CosmosDB
database."""
if document_id is None:
@@ -337,7 +356,7 @@ class AzureCosmosDBHook(BaseHook):
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
- .delete_item(document_id, partition_key=partition_key)
+ .delete_item(document_id, partition_key=self.__get_partition_key(partition_key))
)
def get_document(
@@ -345,7 +364,7 @@ class AzureCosmosDBHook(BaseHook):
document_id: str,
database_name: str | None = None,
collection_name: str | None = None,
- partition_key: str | None = None,
+ partition_key: PartitionKeyType | None = None,
):
"""Get a document from an existing collection in the CosmosDB
database."""
if document_id is None:
@@ -356,7 +375,7 @@ class AzureCosmosDBHook(BaseHook):
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
- .read_item(document_id, partition_key=partition_key)
+ .read_item(document_id, partition_key=self.__get_partition_key(partition_key))
)
except CosmosHttpResponseError:
return None
@@ -366,7 +385,7 @@ class AzureCosmosDBHook(BaseHook):
sql_string: str,
database_name: str | None = None,
collection_name: str | None = None,
- partition_key: str | None = None,
+ partition_key: PartitionKeyType | None = None,
) -> list | None:
"""Get a list of documents from an existing collection in the CosmosDB
database via SQL query."""
if sql_string is None:
@@ -377,7 +396,7 @@ class AzureCosmosDBHook(BaseHook):
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
- .query_items(sql_string, partition_key=partition_key)
+ .query_items(sql_string, partition_key=self.__get_partition_key(partition_key))
)
return list(result_iterable)
except CosmosHttpResponseError:
diff --git a/airflow/providers/microsoft/azure/provider.yaml
b/airflow/providers/microsoft/azure/provider.yaml
index fe3e45a92f..2ddc479b63 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -79,8 +79,7 @@ dependencies:
- apache-airflow>=2.6.0
- adlfs>=2023.10.0
- azure-batch>=8.0.0
- # azure-cosmos 4.6.0 fail on mypy, limit version till we fix the issue
- - azure-cosmos>=4.0.0,<4.6.0
+ - azure-cosmos>=4.6.0
- azure-mgmt-cosmosdb
- azure-datalake-store>=0.0.45
- azure-identity>=1.3.1
diff --git
a/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
b/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
index e7933fc897..2d7754f2cc 100644
--- a/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst
@@ -56,6 +56,7 @@ Extra (optional)
* ``database_name``: Specify the azure cosmos database to use.
* ``collection_name``: Specify the azure cosmos collection to use.
+ * ``partition_key``: Specify the partition key for the collection.
* ``subscription_id``: Specify the ID of the subscription used for the initial connection. Required for falling back to DefaultAzureCredential_
* ``resource_group_name``: Specify the Azure Resource Group Name under
which the desired azure cosmos resides. Required for falling back to
DefaultAzureCredential_
* ``managed_identity_client_id``: The client ID of a user-assigned
managed identity. If provided with `workload_identity_tenant_id`, they'll pass
to ``DefaultAzureCredential``.
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 284ff97cc7..9315766f81 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -682,7 +682,7 @@
"adlfs>=2023.10.0",
"apache-airflow>=2.6.0",
"azure-batch>=8.0.0",
- "azure-cosmos>=4.0.0,<4.6.0",
+ "azure-cosmos>=4.6.0",
"azure-datalake-store>=0.0.45",
"azure-identity>=1.3.1",
"azure-keyvault-secrets>=4.1.0",
diff --git a/tests/providers/microsoft/azure/hooks/test_cosmos.py
b/tests/providers/microsoft/azure/hooks/test_cosmos.py
index b63b46410e..bc6b4f4277 100644
--- a/tests/providers/microsoft/azure/hooks/test_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_cosmos.py
@@ -23,6 +23,7 @@ from unittest import mock
from unittest.mock import PropertyMock
import pytest
+from azure.cosmos import PartitionKey
from azure.cosmos.cosmos_client import CosmosClient
from airflow.exceptions import AirflowException
@@ -43,6 +44,7 @@ class TestAzureCosmosDbHook:
self.test_collection_name = "test_collection_name"
self.test_database_default = "test_database_default"
self.test_collection_default = "test_collection_default"
+ self.test_partition_key = "/test_partition_key"
create_mock_connection(
Connection(
conn_id="azure_cosmos_test_key_id",
@@ -52,6 +54,7 @@ class TestAzureCosmosDbHook:
extra={
"database_name": self.test_database_default,
"collection_name": self.test_collection_default,
+ "partition_key": self.test_partition_key,
},
)
)
@@ -115,11 +118,11 @@ class TestAzureCosmosDbHook:
@mock.patch(f"{MODULE}.CosmosClient")
def test_create_container(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
- hook.create_collection(self.test_collection_name, self.test_database_name)
+ hook.create_collection(self.test_collection_name, self.test_database_name, partition_key="/id")
expected_calls = [
mock.call()
.get_database_client("test_database_name")
- .create_container("test_collection_name", partition_key=None)
+ .create_container("test_collection_name",
partition_key=PartitionKey(path="/id"))
]
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
@@ -131,7 +134,9 @@ class TestAzureCosmosDbHook:
expected_calls = [
mock.call()
.get_database_client("test_database_name")
- .create_container("test_collection_name", partition_key=None)
+ .create_container(
+ "test_collection_name",
partition_key=PartitionKey(path=self.test_partition_key)
+ )
]
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
diff --git a/tests/providers/microsoft/azure/operators/test_cosmos.py
b/tests/providers/microsoft/azure/operators/test_cosmos.py
index 231a56c2b1..d3627482bb 100644
--- a/tests/providers/microsoft/azure/operators/test_cosmos.py
+++ b/tests/providers/microsoft/azure/operators/test_cosmos.py
@@ -35,6 +35,7 @@ class TestAzureCosmosDbHook:
self.test_master_key = "magic_test_key"
self.test_database_name = "test_database_name"
self.test_collection_name = "test_collection_name"
+ self.test_partition_key = "test_partition_key"
create_mock_connection(
Connection(
conn_id="azure_cosmos_test_key_id",
@@ -44,6 +45,7 @@ class TestAzureCosmosDbHook:
extra={
"database_name": self.test_database_name,
"collection_name": self.test_collection_name,
+ "partition_key": self.test_partition_key,
},
)
)