This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d58a0f22153 feat: add create_collection function to MongoHook (#50518)
d58a0f22153 is described below
commit d58a0f22153f77549488e541e3967445a240faff
Author: Aaron Chen <[email protected]>
AuthorDate: Thu May 22 11:29:44 2025 -0700
feat: add create_collection function to MongoHook (#50518)
* feat: add create_collection function with unit tests
* rm comments
* fix static checks
* make the type of `create_kwargs` clearer
Co-authored-by: Wei Lee <[email protected]>
* modify variable name: create_if_exists -> return_if_exists
* move import CollectionInvalid to top
---------
Co-authored-by: Wei Lee <[email protected]>
---
.../src/airflow/providers/mongo/hooks/mongo.py | 32 ++++++++++++++++++++++
.../mongo/tests/unit/mongo/hooks/test_mongo.py | 31 +++++++++++++++++++++
2 files changed, 63 insertions(+)
diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
index c71ce92ea35..20dbf9f0dda 100644
--- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
+++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
@@ -25,6 +25,7 @@ from urllib.parse import quote_plus, urlunsplit
import pymongo
from pymongo import MongoClient, ReplaceOne
+from pymongo.errors import CollectionInvalid
from airflow.exceptions import AirflowConfigException
from airflow.hooks.base import BaseHook
@@ -225,6 +226,37 @@ class MongoHook(BaseHook):
return mongo_conn.get_database(mongo_db).get_collection(mongo_collection)
+ def create_collection(
+ self,
+ mongo_collection: str,
+ mongo_db: str | None = None,
+ return_if_exists: bool = True,
+ **create_kwargs: Any,
+ ) -> MongoCollection:
+ """
+ Create the collection (optionally a time‑series collection) and return it.
+
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/database.html#pymongo.database.Database.create_collection
+
+ :param mongo_collection: Name of the collection.
+ :param mongo_db: Target database; defaults to the schema in the connection string.
+ :param return_if_exists: If True and the collection already exists, return it instead of raising.
+ :param create_kwargs: Additional keyword arguments forwarded to ``db.create_collection()``,
+ e.g. ``timeseries={...}``, ``capped=True``.
+ """
+ mongo_db = mongo_db or self.connection.schema
+ mongo_conn: MongoClient = self.get_conn()
+ db = mongo_conn.get_database(mongo_db)
+
+ try:
+ db.create_collection(mongo_collection, **create_kwargs)
+ except CollectionInvalid:
+ if not return_if_exists:
+ raise
+ # Collection already exists – fall through and fetch it.
+
+ return db.get_collection(mongo_collection)
+
def aggregate(
self, mongo_collection: str, aggregate_query: list, mongo_db: str | None = None, **kwargs
) -> CommandCursor:
diff --git a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py
index 0d646e6e6cb..126617af88f 100644
--- a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py
+++ b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
import pymongo
import pytest
+from pymongo.errors import CollectionInvalid
from airflow.exceptions import AirflowConfigException
from airflow.models import Connection
@@ -387,6 +388,36 @@ class TestMongoHook:
results = self.hook.distinct(collection, "test_id", {"test_status": "failure"})
assert len(results) == 1
+ def test_create_standard_collection(self):
+ mock_client = mongomock.MongoClient()
+ self.hook.get_conn = lambda: mock_client
+ self.hook.connection.schema = "test_db"
+
+ collection = self.hook.create_collection(mongo_collection="plain_collection")
+ assert collection.name == "plain_collection"
+ assert "plain_collection" in mock_client["test_db"].list_collection_names()
+
+ def test_return_if_exists_true_returns_existing(self):
+ mock_client = mongomock.MongoClient()
+ self.hook.get_conn = lambda: mock_client
+ self.hook.connection.schema = "test_db"
+
+ first = self.hook.create_collection(mongo_collection="foo")
+ second = self.hook.create_collection(mongo_collection="foo", return_if_exists=True)
+
+ assert first.full_name == second.full_name
+ assert "foo" in mock_client["test_db"].list_collection_names()
+
+ def test_return_if_exists_false_raises(self):
+ mock_client = mongomock.MongoClient()
+ self.hook.get_conn = lambda: mock_client
+ self.hook.connection.schema = "test_db"
+
+ self.hook.create_collection(mongo_collection="bar")
+
+ with pytest.raises(CollectionInvalid):
+ self.hook.create_collection(mongo_collection="bar", return_if_exists=False)
+
def test_context_manager():
with MongoHook(mongo_conn_id="mongo_default") as ctx_hook: