This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5305f4b696 Fix WasbPrefixSensor arg inconsistency between sync and
async mode (#36806)
5305f4b696 is described below
commit 5305f4b696cf5a786f30e5ebbeab25949b5bbdd4
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jan 17 06:24:39 2024 +0800
Fix WasbPrefixSensor arg inconsistency between sync and async mode (#36806)
---
airflow/providers/microsoft/azure/sensors/wasb.py | 10 +-
airflow/providers/microsoft/azure/triggers/wasb.py | 20 +-
tests/always/test_project_structure.py | 1 -
.../microsoft/azure/triggers/test_wasb.py | 220 +++++++++++++++++++++
4 files changed, 237 insertions(+), 14 deletions(-)
diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py
b/airflow/providers/microsoft/azure/sensors/wasb.py
index 3cd3457527..a0a6bfb81c 100644
--- a/airflow/providers/microsoft/azure/sensors/wasb.py
+++ b/airflow/providers/microsoft/azure/sensors/wasb.py
@@ -149,6 +149,8 @@ class WasbPrefixSensor(BaseSensorOperator):
:param wasb_conn_id: Reference to the wasb connection.
:param check_options: Optional keyword arguments that
`WasbHook.check_for_prefix()` takes.
+ :param public_read: whether anonymous public read access should be
used. Default is False
+ :param deferrable: Run operator in the deferrable mode.
"""
template_fields: Sequence[str] = ("container_name", "prefix")
@@ -160,21 +162,23 @@ class WasbPrefixSensor(BaseSensorOperator):
prefix: str,
wasb_conn_id: str = "wasb_default",
check_options: dict | None = None,
+ public_read: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
if check_options is None:
check_options = {}
- self.wasb_conn_id = wasb_conn_id
self.container_name = container_name
self.prefix = prefix
+ self.wasb_conn_id = wasb_conn_id
self.check_options = check_options
+ self.public_read = public_read
self.deferrable = deferrable
def poke(self, context: Context) -> bool:
self.log.info("Poking for prefix: %s in wasb://%s", self.prefix,
self.container_name)
- hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+ hook = WasbHook(wasb_conn_id=self.wasb_conn_id,
public_read=self.public_read)
return hook.check_for_prefix(self.container_name, self.prefix,
**self.check_options)
def execute(self, context: Context) -> None:
@@ -193,6 +197,8 @@ class WasbPrefixSensor(BaseSensorOperator):
container_name=self.container_name,
prefix=self.prefix,
wasb_conn_id=self.wasb_conn_id,
+ check_options=self.check_options,
+ public_read=self.public_read,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
diff --git a/airflow/providers/microsoft/azure/triggers/wasb.py
b/airflow/providers/microsoft/azure/triggers/wasb.py
index 944c7ddae1..6d74a3023b 100644
--- a/airflow/providers/microsoft/azure/triggers/wasb.py
+++ b/airflow/providers/microsoft/azure/triggers/wasb.py
@@ -102,26 +102,28 @@ class WasbPrefixSensorTrigger(BaseTrigger):
``copy``, ``deleted``
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param wasb_conn_id: the connection identifier for connecting to Azure WASB
- :param poke_interval: polling period in seconds to check for the status
+ :param check_options: Optional keyword arguments that
+ `WasbAsyncHook.check_for_prefix_async()` takes.
+ :param public_read: whether anonymous public read access should be
used. Default is False
+ :param poke_interval: polling period in seconds to check for the status
"""
def __init__(
self,
container_name: str,
prefix: str,
- include: list[str] | None = None,
- delimiter: str = "/",
wasb_conn_id: str = "wasb_default",
+ check_options: dict | None = None,
public_read: bool = False,
poke_interval: float = 5.0,
):
+ if not check_options:
+ check_options = {}
super().__init__()
self.container_name = container_name
self.prefix = prefix
- self.include = include
- self.delimiter = delimiter
self.wasb_conn_id = wasb_conn_id
+ self.check_options = check_options
self.poke_interval = poke_interval
self.public_read = public_read
@@ -132,10 +134,9 @@ class WasbPrefixSensorTrigger(BaseTrigger):
{
"container_name": self.container_name,
"prefix": self.prefix,
- "include": self.include,
- "delimiter": self.delimiter,
"wasb_conn_id": self.wasb_conn_id,
"poke_interval": self.poke_interval,
+ "check_options": self.check_options,
"public_read": self.public_read,
},
)
@@ -148,10 +149,7 @@ class WasbPrefixSensorTrigger(BaseTrigger):
async with await hook.get_async_conn():
while not prefix_exists:
prefix_exists = await hook.check_for_prefix_async(
- container_name=self.container_name,
- prefix=self.prefix,
- include=self.include,
- delimiter=self.delimiter,
+ container_name=self.container_name,
prefix=self.prefix, **self.check_options
)
if prefix_exists:
message = f"Prefix {self.prefix} found in container
{self.container_name}."
diff --git a/tests/always/test_project_structure.py
b/tests/always/test_project_structure.py
index ca59b70c6f..db026aa6bf 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -158,7 +158,6 @@ class TestProjectStructure:
"tests/providers/microsoft/azure/fs/test_adls.py",
"tests/providers/microsoft/azure/operators/test_adls.py",
"tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py",
- "tests/providers/microsoft/azure/triggers/test_wasb.py",
"tests/providers/mongo/sensors/test_mongo.py",
"tests/providers/openlineage/extractors/test_manager.py",
"tests/providers/openlineage/plugins/test_adapter.py",
diff --git a/tests/providers/microsoft/azure/triggers/test_wasb.py
b/tests/providers/microsoft/azure/triggers/test_wasb.py
new file mode 100644
index 0000000000..1d6a185d48
--- /dev/null
+++ b/tests/providers/microsoft/azure/triggers/test_wasb.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from unittest import mock
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.wasb import (
+ WasbBlobSensorTrigger,
+ WasbPrefixSensorTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers_team.txt"
+TEST_DATA_STORAGE_CONTAINER_NAME = "test-container-providers-team"
+TEST_DATA_STORAGE_BLOB_PREFIX = TEST_DATA_STORAGE_BLOB_NAME[:10]
+TEST_WASB_CONN_ID = "wasb_default"
+POKE_INTERVAL = 5.0
+
+
+class TestWasbBlobSensorTrigger:
+ TRIGGER = WasbBlobSensorTrigger(
+ container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
+ blob_name=TEST_DATA_STORAGE_BLOB_NAME,
+ wasb_conn_id=TEST_WASB_CONN_ID,
+ poke_interval=POKE_INTERVAL,
+ )
+
+ def test_serialization(self):
+ """
+ Asserts that the WasbBlobSensorTrigger correctly serializes its
arguments
+ and classpath.
+ """
+
+ classpath, kwargs = self.TRIGGER.serialize()
+ assert classpath ==
"airflow.providers.microsoft.azure.triggers.wasb.WasbBlobSensorTrigger"
+ assert kwargs == {
+ "container_name": TEST_DATA_STORAGE_CONTAINER_NAME,
+ "blob_name": TEST_DATA_STORAGE_BLOB_NAME,
+ "wasb_conn_id": TEST_WASB_CONN_ID,
+ "poke_interval": POKE_INTERVAL,
+ "public_read": False,
+ }
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "blob_exists",
+ [True, False],
+ )
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+ async def test_running(self, mock_check_for_blob, blob_exists):
+ """
+ Test if the task is run in trigger successfully.
+ """
+ mock_check_for_blob.return_value = blob_exists
+
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+ async def test_success(self, mock_check_for_blob):
+ """Tests the success state of the WasbBlobSensorTrigger."""
+ mock_check_for_blob.return_value = True
+
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was returned
+ assert task.done() is True
+ asyncio.get_event_loop().stop()
+
+ message = f"Blob {TEST_DATA_STORAGE_BLOB_NAME} found in container
{TEST_DATA_STORAGE_CONTAINER_NAME}."
+ assert task.result() == TriggerEvent({"status": "success", "message":
message})
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+ async def test_waiting_for_blob(self, mock_check_for_blob, caplog):
+ """Tests that the WasbBlobSensorTrigger sleeps while waiting for the
blob to arrive."""
+ mock_check_for_blob.side_effect = [False, True]
+ caplog.set_level(logging.INFO)
+
+ with mock.patch.object(self.TRIGGER.log, "info"):
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+ await asyncio.sleep(POKE_INTERVAL + 0.5)
+
+ if not task.done():
+ message = (
+ f"Blob {TEST_DATA_STORAGE_BLOB_NAME} not available yet in
container {TEST_DATA_STORAGE_CONTAINER_NAME}."
+ f" Sleeping for {POKE_INTERVAL} seconds"
+ )
+ assert message in caplog.text
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
+ async def test_trigger_exception(self, mock_check_for_blob):
+ """Tests the WasbBlobSensorTrigger yields an error event if there is
an exception."""
+ mock_check_for_blob.side_effect = Exception("Test exception")
+
+ task = [i async for i in self.TRIGGER.run()]
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"})
in task
+
+
+class TestWasbPrefixSensorTrigger:
+ TRIGGER = WasbPrefixSensorTrigger(
+ container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
+ prefix=TEST_DATA_STORAGE_BLOB_PREFIX,
+ wasb_conn_id=TEST_WASB_CONN_ID,
+ poke_interval=POKE_INTERVAL,
+ check_options={"delimiter": "/", "include": None},
+ )
+
+ def test_serialization(self):
+ """
+ Asserts that the WasbPrefixSensorTrigger correctly serializes its
arguments and classpath."""
+
+ classpath, kwargs = self.TRIGGER.serialize()
+ assert classpath ==
"airflow.providers.microsoft.azure.triggers.wasb.WasbPrefixSensorTrigger"
+ assert kwargs == {
+ "container_name": TEST_DATA_STORAGE_CONTAINER_NAME,
+ "prefix": TEST_DATA_STORAGE_BLOB_PREFIX,
+ "wasb_conn_id": TEST_WASB_CONN_ID,
+ "public_read": False,
+ "check_options": {
+ "delimiter": "/",
+ "include": None,
+ },
+ "poke_interval": POKE_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "prefix_exists",
+ [True, False],
+ )
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+ async def test_running(self, mock_check_for_prefix, prefix_exists):
+ """Test if the task is run in trigger successfully."""
+ mock_check_for_prefix.return_value = prefix_exists
+
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+ async def test_success(self, mock_check_for_prefix):
+ """Tests the success state of the WasbPrefixSensorTrigger."""
+ mock_check_for_prefix.return_value = True
+
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was returned
+ assert task.done() is True
+ asyncio.get_event_loop().stop()
+
+ message = (
+ f"Prefix {TEST_DATA_STORAGE_BLOB_PREFIX} found in container
{TEST_DATA_STORAGE_CONTAINER_NAME}."
+ )
+ assert task.result() == TriggerEvent({"status": "success", "message":
message})
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+ async def test_waiting_for_blob(self, mock_check_for_prefix):
+ """Tests that the WasbPrefixSensorTrigger sleeps while waiting for the
blob to arrive."""
+ mock_check_for_prefix.side_effect = [False, True]
+
+ with mock.patch.object(self.TRIGGER.log, "info") as mock_log_info:
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+
+ await asyncio.sleep(POKE_INTERVAL + 0.5)
+
+ if not task.done():
+ message = (
+ f"Prefix {TEST_DATA_STORAGE_BLOB_PREFIX} not available yet in
container "
+ f"{TEST_DATA_STORAGE_CONTAINER_NAME}. Sleeping for
{POKE_INTERVAL} seconds"
+ )
+ mock_log_info.assert_called_once_with(message)
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
+ async def test_trigger_exception(self, mock_check_for_prefix):
+ """Tests the WasbPrefixSensorTrigger yields an error event if there is
an exception."""
+ mock_check_for_prefix.side_effect = Exception("Test exception")
+
+ task = [i async for i in self.TRIGGER.run()]
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"})
in task