This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d58a0f22153 feat: add create_collection function to MongoHook (#50518)
d58a0f22153 is described below

commit d58a0f22153f77549488e541e3967445a240faff
Author: Aaron Chen <[email protected]>
AuthorDate: Thu May 22 11:29:44 2025 -0700

    feat: add create_collection function to MongoHook (#50518)
    
    * feat: add create_collection function with unit tests
    
    * rm comments
    
    * fix static checks
    
    * make the type of `create_kwargs` clearer
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * modify variable name: create_if_exists -> return_if_exists
    
    * move import CollectionInvalid to top
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 .../src/airflow/providers/mongo/hooks/mongo.py     | 32 ++++++++++++++++++++++
 .../mongo/tests/unit/mongo/hooks/test_mongo.py     | 31 +++++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
index c71ce92ea35..20dbf9f0dda 100644
--- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
+++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
@@ -25,6 +25,7 @@ from urllib.parse import quote_plus, urlunsplit
 
 import pymongo
 from pymongo import MongoClient, ReplaceOne
+from pymongo.errors import CollectionInvalid
 
 from airflow.exceptions import AirflowConfigException
 from airflow.hooks.base import BaseHook
@@ -225,6 +226,37 @@ class MongoHook(BaseHook):
 
         return 
mongo_conn.get_database(mongo_db).get_collection(mongo_collection)
 
+    def create_collection(
+        self,
+        mongo_collection: str,
+        mongo_db: str | None = None,
+        return_if_exists: bool = True,
+        **create_kwargs: Any,
+    ) -> MongoCollection:
+        """
+        Create the collection (optionally a time‑series collection) and return it.
+
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/database.html#pymongo.database.Database.create_collection
+
+        :param mongo_collection: Name of the collection.
+        :param mongo_db: Target database; defaults to the schema in the connection string.
+        :param return_if_exists: If True and the collection already exists, return it instead of raising.
+        :param create_kwargs: Additional keyword arguments forwarded to ``db.create_collection()``,
+                                  e.g. ``timeseries={...}``, ``capped=True``.
+        """
+        mongo_db = mongo_db or self.connection.schema
+        mongo_conn: MongoClient = self.get_conn()
+        db = mongo_conn.get_database(mongo_db)
+
+        try:
+            db.create_collection(mongo_collection, **create_kwargs)
+        except CollectionInvalid:
+            if not return_if_exists:
+                raise
+            # Collection already exists – fall through and fetch it.
+
+        return db.get_collection(mongo_collection)
+
     def aggregate(
         self, mongo_collection: str, aggregate_query: list, mongo_db: str | 
None = None, **kwargs
     ) -> CommandCursor:
diff --git a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py
index 0d646e6e6cb..126617af88f 100644
--- a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py
+++ b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
 
 import pymongo
 import pytest
+from pymongo.errors import CollectionInvalid
 
 from airflow.exceptions import AirflowConfigException
 from airflow.models import Connection
@@ -387,6 +388,36 @@ class TestMongoHook:
         results = self.hook.distinct(collection, "test_id", {"test_status": 
"failure"})
         assert len(results) == 1
 
+    def test_create_standard_collection(self):
+        mock_client = mongomock.MongoClient()
+        self.hook.get_conn = lambda: mock_client
+        self.hook.connection.schema = "test_db"
+
+        collection = 
self.hook.create_collection(mongo_collection="plain_collection")
+        assert collection.name == "plain_collection"
+        assert "plain_collection" in 
mock_client["test_db"].list_collection_names()
+
+    def test_return_if_exists_true_returns_existing(self):
+        mock_client = mongomock.MongoClient()
+        self.hook.get_conn = lambda: mock_client
+        self.hook.connection.schema = "test_db"
+
+        first = self.hook.create_collection(mongo_collection="foo")
+        second = self.hook.create_collection(mongo_collection="foo", 
return_if_exists=True)
+
+        assert first.full_name == second.full_name
+        assert "foo" in mock_client["test_db"].list_collection_names()
+
+    def test_return_if_exists_false_raises(self):
+        mock_client = mongomock.MongoClient()
+        self.hook.get_conn = lambda: mock_client
+        self.hook.connection.schema = "test_db"
+
+        self.hook.create_collection(mongo_collection="bar")
+
+        with pytest.raises(CollectionInvalid):
+            self.hook.create_collection(mongo_collection="bar", 
return_if_exists=False)
+
 
 def test_context_manager():
     with MongoHook(mongo_conn_id="mongo_default") as ctx_hook:

Reply via email to