This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5305f4b696 Fix WasbPrefixSensor arg inconsistency between sync and async mode (#36806)
5305f4b696 is described below

commit 5305f4b696cf5a786f30e5ebbeab25949b5bbdd4
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jan 17 06:24:39 2024 +0800

    Fix WasbPrefixSensor arg inconsistency between sync and async mode (#36806)
---
 airflow/providers/microsoft/azure/sensors/wasb.py  |  10 +-
 airflow/providers/microsoft/azure/triggers/wasb.py |  20 +-
 tests/always/test_project_structure.py             |   1 -
 .../microsoft/azure/triggers/test_wasb.py          | 220 +++++++++++++++++++++
 4 files changed, 237 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py
index 3cd3457527..a0a6bfb81c 100644
--- a/airflow/providers/microsoft/azure/sensors/wasb.py
+++ b/airflow/providers/microsoft/azure/sensors/wasb.py
@@ -149,6 +149,8 @@ class WasbPrefixSensor(BaseSensorOperator):
     :param wasb_conn_id: Reference to the wasb connection.
     :param check_options: Optional keyword arguments that
         `WasbHook.check_for_prefix()` takes.
+    :param public_read: whether an anonymous public read access should be used. Default is False
+    :param deferrable: Run operator in the deferrable mode.
     """
 
     template_fields: Sequence[str] = ("container_name", "prefix")
@@ -160,21 +162,23 @@ class WasbPrefixSensor(BaseSensorOperator):
         prefix: str,
         wasb_conn_id: str = "wasb_default",
         check_options: dict | None = None,
+        public_read: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         if check_options is None:
             check_options = {}
-        self.wasb_conn_id = wasb_conn_id
         self.container_name = container_name
         self.prefix = prefix
+        self.wasb_conn_id = wasb_conn_id
         self.check_options = check_options
+        self.public_read = public_read
         self.deferrable = deferrable
 
     def poke(self, context: Context) -> bool:
         self.log.info("Poking for prefix: %s in wasb://%s", self.prefix, self.container_name)
-        hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+        hook = WasbHook(wasb_conn_id=self.wasb_conn_id, public_read=self.public_read)
         return hook.check_for_prefix(self.container_name, self.prefix, **self.check_options)
 
     def execute(self, context: Context) -> None:
@@ -193,6 +197,8 @@ class WasbPrefixSensor(BaseSensorOperator):
                         container_name=self.container_name,
                         prefix=self.prefix,
                         wasb_conn_id=self.wasb_conn_id,
+                        check_options=self.check_options,
+                        public_read=self.public_read,
                         poke_interval=self.poke_interval,
                     ),
                     method_name="execute_complete",
diff --git a/airflow/providers/microsoft/azure/triggers/wasb.py b/airflow/providers/microsoft/azure/triggers/wasb.py
index 944c7ddae1..6d74a3023b 100644
--- a/airflow/providers/microsoft/azure/triggers/wasb.py
+++ b/airflow/providers/microsoft/azure/triggers/wasb.py
@@ -102,26 +102,28 @@ class WasbPrefixSensorTrigger(BaseTrigger):
             ``copy``, ``deleted``
     :param delimiter: filters objects based on the delimiter (for e.g '.csv')
     :param wasb_conn_id: the connection identifier for connecting to Azure WASB
-    :param poke_interval:  polling period in seconds to check for the status
+    :param check_options: Optional keyword arguments that
+        `WasbAsyncHook.check_for_prefix_async()` takes.
     :param public_read: whether an anonymous public read access should be used. Default is False
+    :param poke_interval:  polling period in seconds to check for the status
     """
 
     def __init__(
         self,
         container_name: str,
         prefix: str,
-        include: list[str] | None = None,
-        delimiter: str = "/",
         wasb_conn_id: str = "wasb_default",
+        check_options: dict | None = None,
         public_read: bool = False,
         poke_interval: float = 5.0,
     ):
+        if not check_options:
+            check_options = {}
         super().__init__()
         self.container_name = container_name
         self.prefix = prefix
-        self.include = include
-        self.delimiter = delimiter
         self.wasb_conn_id = wasb_conn_id
+        self.check_options = check_options
         self.poke_interval = poke_interval
         self.public_read = public_read
 
@@ -132,10 +134,9 @@ class WasbPrefixSensorTrigger(BaseTrigger):
             {
                 "container_name": self.container_name,
                 "prefix": self.prefix,
-                "include": self.include,
-                "delimiter": self.delimiter,
                 "wasb_conn_id": self.wasb_conn_id,
                 "poke_interval": self.poke_interval,
+                "check_options": self.check_options,
                 "public_read": self.public_read,
             },
         )
@@ -148,10 +149,7 @@ class WasbPrefixSensorTrigger(BaseTrigger):
             async with await hook.get_async_conn():
                 while not prefix_exists:
                     prefix_exists = await hook.check_for_prefix_async(
-                        container_name=self.container_name,
-                        prefix=self.prefix,
-                        include=self.include,
-                        delimiter=self.delimiter,
+                        container_name=self.container_name, prefix=self.prefix, **self.check_options
                     )
                     if prefix_exists:
                         message = f"Prefix {self.prefix} found in container {self.container_name}."
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index ca59b70c6f..db026aa6bf 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -158,7 +158,6 @@ class TestProjectStructure:
             "tests/providers/microsoft/azure/fs/test_adls.py",
             "tests/providers/microsoft/azure/operators/test_adls.py",
             "tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py",
-            "tests/providers/microsoft/azure/triggers/test_wasb.py",
             "tests/providers/mongo/sensors/test_mongo.py",
             "tests/providers/openlineage/extractors/test_manager.py",
             "tests/providers/openlineage/plugins/test_adapter.py",
diff --git a/tests/providers/microsoft/azure/triggers/test_wasb.py b/tests/providers/microsoft/azure/triggers/test_wasb.py
new file mode 100644
index 0000000000..1d6a185d48
--- /dev/null
+++ b/tests/providers/microsoft/azure/triggers/test_wasb.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from unittest import mock
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.wasb import (
+    WasbBlobSensorTrigger,
+    WasbPrefixSensorTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers_team.txt"
+TEST_DATA_STORAGE_CONTAINER_NAME = "test-container-providers-team"
+TEST_DATA_STORAGE_BLOB_PREFIX = TEST_DATA_STORAGE_BLOB_NAME[:10]
+TEST_WASB_CONN_ID = "wasb_default"
+POKE_INTERVAL = 5.0
+
+
+class TestWasbBlobSensorTrigger:
+    TRIGGER = WasbBlobSensorTrigger(
+        container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
+        blob_name=TEST_DATA_STORAGE_BLOB_NAME,
+        wasb_conn_id=TEST_WASB_CONN_ID,
+        poke_interval=POKE_INTERVAL,
+    )
+
+    def test_serialization(self):
+        """
+        Asserts that the WasbBlobSensorTrigger correctly serializes its arguments
+        and classpath.
+        """
+
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == "airflow.providers.microsoft.azure.triggers.wasb.WasbBlobSensorTrigger"
+        assert kwargs == {
+            "container_name": TEST_DATA_STORAGE_CONTAINER_NAME,
+            "blob_name": TEST_DATA_STORAGE_BLOB_NAME,
+            "wasb_conn_id": TEST_WASB_CONN_ID,
+            "poke_interval": POKE_INTERVAL,
+            "public_read": False,
+        }
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "blob_exists",
+        [True, False],
+    )
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+    async def test_running(self, mock_check_for_blob, blob_exists):
+        """
+        Test if the task is run in trigger successfully.
+        """
+        mock_check_for_blob.return_value = blob_exists
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+    async def test_success(self, mock_check_for_blob):
+        """Tests the success state for that the WasbBlobSensorTrigger."""
+        mock_check_for_blob.return_value = True
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was returned
+        assert task.done() is True
+        asyncio.get_event_loop().stop()
+
+        message = f"Blob {TEST_DATA_STORAGE_BLOB_NAME} found in container {TEST_DATA_STORAGE_CONTAINER_NAME}."
+        assert task.result() == TriggerEvent({"status": "success", "message": message})
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+    async def test_waiting_for_blob(self, mock_check_for_blob, caplog):
+        """Tests the WasbBlobSensorTrigger sleeps waiting for the blob to arrive."""
+        mock_check_for_blob.side_effect = [False, True]
+        caplog.set_level(logging.INFO)
+
+        with mock.patch.object(self.TRIGGER.log, "info"):
+            task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+        await asyncio.sleep(POKE_INTERVAL + 0.5)
+
+        if not task.done():
+            message = (
+                f"Blob {TEST_DATA_STORAGE_BLOB_NAME} not available yet in container {TEST_DATA_STORAGE_CONTAINER_NAME}."
+                f" Sleeping for {POKE_INTERVAL} seconds"
+            )
+            assert message in caplog.text
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+    async def test_trigger_exception(self, mock_check_for_blob):
+        """Tests the WasbBlobSensorTrigger yields an error event if there is an exception."""
+        mock_check_for_blob.side_effect = Exception("Test exception")
+
+        task = [i async for i in self.TRIGGER.run()]
+        assert len(task) == 1
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+class TestWasbPrefixSensorTrigger:
+    TRIGGER = WasbPrefixSensorTrigger(
+        container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
+        prefix=TEST_DATA_STORAGE_BLOB_PREFIX,
+        wasb_conn_id=TEST_WASB_CONN_ID,
+        poke_interval=POKE_INTERVAL,
+        check_options={"delimiter": "/", "include": None},
+    )
+
+    def test_serialization(self):
+        """
+        Asserts that the WasbPrefixSensorTrigger correctly serializes its arguments and classpath."""
+
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == "airflow.providers.microsoft.azure.triggers.wasb.WasbPrefixSensorTrigger"
+        assert kwargs == {
+            "container_name": TEST_DATA_STORAGE_CONTAINER_NAME,
+            "prefix": TEST_DATA_STORAGE_BLOB_PREFIX,
+            "wasb_conn_id": TEST_WASB_CONN_ID,
+            "public_read": False,
+            "check_options": {
+                "delimiter": "/",
+                "include": None,
+            },
+            "poke_interval": POKE_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "prefix_exists",
+        [True, False],
+    )
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+    async def test_running(self, mock_check_for_prefix, prefix_exists):
+        """Test if the task is run in trigger successfully."""
+        mock_check_for_prefix.return_value = prefix_exists
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+    async def test_success(self, mock_check_for_prefix):
+        """Tests the success state for that the WasbPrefixSensorTrigger."""
+        mock_check_for_prefix.return_value = True
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was returned
+        assert task.done() is True
+        asyncio.get_event_loop().stop()
+
+        message = (
+            f"Prefix {TEST_DATA_STORAGE_BLOB_PREFIX} found in container {TEST_DATA_STORAGE_CONTAINER_NAME}."
+        )
+        assert task.result() == TriggerEvent({"status": "success", "message": message})
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+    async def test_waiting_for_blob(self, mock_check_for_prefix):
+        """Tests the WasbPrefixSensorTrigger sleeps waiting for the blob to arrive."""
+        mock_check_for_prefix.side_effect = [False, True]
+
+        with mock.patch.object(self.TRIGGER.log, "info") as mock_log_info:
+            task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+        await asyncio.sleep(POKE_INTERVAL + 0.5)
+
+        if not task.done():
+            message = (
+                f"Prefix {TEST_DATA_STORAGE_BLOB_PREFIX} not available yet in container "
+                f"{TEST_DATA_STORAGE_CONTAINER_NAME}. Sleeping for {POKE_INTERVAL} seconds"
+            )
+            mock_log_info.assert_called_once_with(message)
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+    async def test_trigger_exception(self, mock_check_for_prefix):
+        """Tests the WasbPrefixSensorTrigger yields an error event if there is an exception."""
+        mock_check_for_prefix.side_effect = Exception("Test exception")
+
+        task = [i async for i in self.TRIGGER.run()]
+        assert len(task) == 1
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task

Reply via email to