This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 54a59d7cae Add `WasbBlobAsyncSensor` (#30197)
54a59d7cae is described below
commit 54a59d7cae5d49993e018ad408954c44f15dc509
Author: Phani Kumar <[email protected]>
AuthorDate: Fri Mar 31 22:31:53 2023 +0530
Add `WasbBlobAsyncSensor` (#30197)
* Add WasbBlobAsyncSensor
* Add WasbBlobAsyncSensor
* Add WasbBlobAsyncSensor
* add tests and example DAG
* Fix failure during deferral
* Apply review suggestions
* Add wasb async conn
* Specify timeout in the tests
* Fix tests
---
airflow/providers/microsoft/azure/hooks/wasb.py | 167 ++++++++++++++++++++-
airflow/providers/microsoft/azure/sensors/wasb.py | 64 +++++++-
airflow/providers/microsoft/azure/triggers/wasb.py | 88 +++++++++++
.../providers/microsoft/azure/sensors/test_wasb.py | 103 ++++++++++++-
.../microsoft/azure/example_azure_blob_to_gcs.py | 6 +-
5 files changed, 422 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py
b/airflow/providers/microsoft/azure/hooks/wasb.py
index ce6870eee0..1c5b7d8b3e 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -28,15 +28,28 @@ from __future__ import annotations
import logging
import os
from functools import wraps
-from typing import Any
+from typing import Any, Union
+from asgiref.sync import sync_to_async
from azure.core.exceptions import HttpResponseError, ResourceExistsError,
ResourceNotFoundError
from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.identity.aio import (
+ ClientSecretCredential as AsyncClientSecretCredential,
+ DefaultAzureCredential as AsyncDefaultAzureCredential,
+)
from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient,
StorageStreamDownloader
+from azure.storage.blob._models import BlobProperties
+from azure.storage.blob.aio import (
+ BlobClient as AsyncBlobClient,
+ BlobServiceClient as AsyncBlobServiceClient,
+ ContainerClient as AsyncContainerClient,
+)
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
+AsyncCredentials = Union[AsyncClientSecretCredential,
AsyncDefaultAzureCredential]
+
def _ensure_prefixes(conn_type):
"""
@@ -502,3 +515,155 @@ class WasbHook(BaseHook):
return success
except Exception as e:
return False, str(e)
+
+
+class WasbAsyncHook(WasbHook):
+ """
+ An async hook that connects to Azure WASB to perform operations.
+
+ :param wasb_conn_id: reference to the :ref:`wasb connection
<howto/connection:wasb>`
+ :param public_read: whether an anonymous public read access should be
used. default is False
+ """
+
+ def __init__(
+ self,
+ wasb_conn_id: str = "wasb_default",
+ public_read: bool = False,
+ ) -> None:
+ """Initialize the hook instance."""
+ self.conn_id = wasb_conn_id
+ self.public_read = public_read
+ self.blob_service_client: AsyncBlobServiceClient = None # type: ignore
+
+ async def get_async_conn(self) -> AsyncBlobServiceClient:
+ """Return the Async BlobServiceClient object."""
+ if self.blob_service_client is not None:
+ return self.blob_service_client
+
+ conn = await sync_to_async(self.get_connection)(self.conn_id)
+ extra = conn.extra_dejson or {}
+
+ if self.public_read:
+ # Here we use anonymous public read
+ # more info
+ #
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
+ self.blob_service_client =
AsyncBlobServiceClient(account_url=conn.host, **extra)
+ return self.blob_service_client
+
+ connection_string = self._get_field(extra, "connection_string")
+ if connection_string:
+ # connection_string auth takes priority
+ self.blob_service_client =
AsyncBlobServiceClient.from_connection_string(
+ connection_string, **extra
+ )
+ return self.blob_service_client
+
+ shared_access_key = self._get_field(extra, "shared_access_key")
+ if shared_access_key:
+ # using shared access key
+ self.blob_service_client = AsyncBlobServiceClient(
+ account_url=conn.host, credential=shared_access_key, **extra
+ )
+ return self.blob_service_client
+
+ tenant = self._get_field(extra, "tenant_id")
+ if tenant:
+ # use Active Directory auth
+ app_id = conn.login
+ app_secret = conn.password
+ token_credential = AsyncClientSecretCredential(tenant, app_id,
app_secret)
+ self.blob_service_client = AsyncBlobServiceClient(
+ account_url=conn.host, credential=token_credential, **extra #
type:ignore[arg-type]
+ )
+ return self.blob_service_client
+
+ sas_token = self._get_field(extra, "sas_token")
+ if sas_token:
+ if sas_token.startswith("https"):
+ self.blob_service_client =
AsyncBlobServiceClient(account_url=sas_token, **extra)
+ else:
+ self.blob_service_client = AsyncBlobServiceClient(
+
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
+ )
+ return self.blob_service_client
+
+ # Fall back to old auth (password) or use managed identity if not
provided.
+ credential = conn.password
+ if not credential:
+ credential = AsyncDefaultAzureCredential()
+ self.log.info("Using DefaultAzureCredential as credential")
+ self.blob_service_client = AsyncBlobServiceClient(
+ account_url=f"https://{conn.login}.blob.core.windows.net/",
+ credential=credential,
+ **extra,
+ )
+
+ return self.blob_service_client
+
+ def _get_blob_client(self, container_name: str, blob_name: str) ->
AsyncBlobClient:
+ """
+ Instantiate a blob client.
+
+ :param container_name: the name of the blob container
+ :param blob_name: the name of the blob. This blob need not already exist
+ """
+ return
self.blob_service_client.get_blob_client(container=container_name,
blob=blob_name)
+
+ async def check_for_blob_async(self, container_name: str, blob_name: str,
**kwargs: Any) -> bool:
+ """
+ Check if a blob exists on Azure Blob Storage.
+
+ :param container_name: name of the container
+ :param blob_name: name of the blob
+ :param kwargs: optional keyword arguments for
``BlobClient.get_blob_properties``
+ """
+ try:
+ await self._get_blob_client(container_name,
blob_name).get_blob_properties(**kwargs)
+ except ResourceNotFoundError:
+ return False
+ return True
+
+ def _get_container_client(self, container_name: str) ->
AsyncContainerClient:
+ """
+ Instantiate a container client.
+
+ :param container_name: the name of the container
+ """
+ return self.blob_service_client.get_container_client(container_name)
+
+ async def get_blobs_list_async(
+ self,
+ container_name: str,
+ prefix: str | None = None,
+ include: list[str] | None = None,
+ delimiter: str = "/",
+ **kwargs: Any,
+ ) -> list[BlobProperties]:
+ """
+ List blobs in a given container.
+
+ :param container_name: the name of the container
+ :param prefix: filters the results to return only blobs whose names
+ begin with the specified prefix.
+ :param include: specifies one or more additional datasets to include
in the
+ response. Options include: ``snapshots``, ``metadata``,
``uncommittedblobs``,
+ ``copy``, ``deleted``.
+ :param delimiter: filters objects based on the delimiter (e.g.
'.csv')
+ """
+ container = self._get_container_client(container_name)
+ blob_list = []
+ blobs = container.walk_blobs(name_starts_with=prefix, include=include,
delimiter=delimiter, **kwargs)
+ async for blob in blobs:
+ blob_list.append(blob.name)
+ return blob_list
+
+ async def check_for_prefix_async(self, container_name: str, prefix: str,
**kwargs: Any) -> bool:
+ """
+ Check if a prefix exists on Azure Blob storage.
+
+ :param container_name: Name of the container.
+ :param prefix: Prefix of the blob.
+ :param kwargs: Optional keyword arguments for
``ContainerClient.walk_blobs``
+ """
+ blobs = await self.get_blobs_list_async(container_name=container_name,
prefix=prefix, **kwargs)
+ return len(blobs) > 0
diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py
b/airflow/providers/microsoft/azure/sensors/wasb.py
index 388a571f7d..017d73720d 100644
--- a/airflow/providers/microsoft/azure/sensors/wasb.py
+++ b/airflow/providers/microsoft/azure/sensors/wasb.py
@@ -17,9 +17,12 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+from airflow.providers.microsoft.azure.triggers.wasb import
WasbBlobSensorTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -62,6 +65,65 @@ class WasbBlobSensor(BaseSensorOperator):
return hook.check_for_blob(self.container_name, self.blob_name,
**self.check_options)
+class WasbBlobAsyncSensor(WasbBlobSensor):
+ """
+ Polls asynchronously for the existence of a blob in a WASB container.
+
+ :param container_name: name of the container in which the blob should be
searched for
+ :param blob_name: name of the blob to check existence for
+ :param wasb_conn_id: the connection identifier for connecting to Azure WASB
+ :param poke_interval: polling period in seconds to check for the status
+ :param public_read: whether an anonymous public read access should be
used. Default is False
+ :param timeout: Time, in seconds, before the task times out and fails.
+ """
+
+ def __init__(
+ self,
+ *,
+ container_name: str,
+ blob_name: str,
+ wasb_conn_id: str = "wasb_default",
+ public_read: bool = False,
+ poke_interval: float = 5.0,
+ **kwargs: Any,
+ ):
+ self.container_name = container_name
+ self.blob_name = blob_name
+ self.poke_interval = poke_interval
+ super().__init__(container_name=container_name, blob_name=blob_name,
**kwargs)
+ self.wasb_conn_id = wasb_conn_id
+ self.public_read = public_read
+
+ def execute(self, context: Context) -> None:
+ """Defer to the trigger class, which polls for the state of the job run
+ until it reaches a failure state or a success state.
+ """
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=WasbBlobSensorTrigger(
+ container_name=self.container_name,
+ blob_name=self.blob_name,
+ wasb_conn_id=self.wasb_conn_id,
+ public_read=self.public_read,
+ poke_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Context, event: dict[str, str]) ->
None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if event:
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(event["message"])
+ else:
+ raise AirflowException("Did not receive valid event from the
triggerer")
+
+
class WasbPrefixSensor(BaseSensorOperator):
"""
Waits for blobs matching a prefix to arrive on Azure Blob Storage.
diff --git a/airflow/providers/microsoft/azure/triggers/wasb.py
b/airflow/providers/microsoft/azure/triggers/wasb.py
new file mode 100644
index 0000000000..cea5b76fbf
--- /dev/null
+++ b/airflow/providers/microsoft/azure/triggers/wasb.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from airflow.providers.microsoft.azure.hooks.wasb import WasbAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class WasbBlobSensorTrigger(BaseTrigger):
+ """
+ WasbBlobSensorTrigger is fired as deferred class with params to run the
task in
+ trigger worker to check for existence of the given blob in the provided
container.
+
+ :param container_name: name of the container in which the blob should be
searched for
+ :param blob_name: name of the blob to check existence for
+ :param wasb_conn_id: the connection identifier for connecting to Azure WASB
+ :param poke_interval: polling period in seconds to check for the status
+ :param public_read: whether an anonymous public read access should be
used. Default is False
+ """
+
+ def __init__(
+ self,
+ container_name: str,
+ blob_name: str,
+ wasb_conn_id: str = "wasb_default",
+ public_read: bool = False,
+ poke_interval: float = 5.0,
+ ):
+ super().__init__()
+ self.container_name = container_name
+ self.blob_name = blob_name
+ self.wasb_conn_id = wasb_conn_id
+ self.poke_interval = poke_interval
+ self.public_read = public_read
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes WasbBlobSensorTrigger arguments and classpath."""
+ return (
+
"airflow.providers.microsoft.azure.triggers.wasb.WasbBlobSensorTrigger",
+ {
+ "container_name": self.container_name,
+ "blob_name": self.blob_name,
+ "wasb_conn_id": self.wasb_conn_id,
+ "poke_interval": self.poke_interval,
+ "public_read": self.public_read,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]:
+ """Makes async connection to Azure WASB and polls for existence of the
given blob name."""
+ blob_exists = False
+ hook = WasbAsyncHook(wasb_conn_id=self.wasb_conn_id,
public_read=self.public_read)
+ try:
+ async with await hook.get_async_conn():
+ while not blob_exists:
+ blob_exists = await hook.check_for_blob_async(
+ container_name=self.container_name,
+ blob_name=self.blob_name,
+ )
+ if blob_exists:
+ message = f"Blob {self.blob_name} found in container
{self.container_name}."
+ yield TriggerEvent({"status": "success", "message":
message})
+ else:
+ message = (
+ f"Blob {self.blob_name} not available yet in
container {self.container_name}."
+ f" Sleeping for {self.poke_interval} seconds"
+ )
+ self.log.info(message)
+ await asyncio.sleep(self.poke_interval)
+ except Exception as e:
+ yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py
b/tests/providers/microsoft/azure/sensors/test_wasb.py
index 5edc9bcd84..67830c44e3 100644
--- a/tests/providers/microsoft/azure/sensors/test_wasb.py
+++ b/tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -20,8 +20,26 @@ from __future__ import annotations
import datetime
from unittest import mock
-from airflow.models.dag import DAG
-from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor,
WasbPrefixSensor
+import pendulum
+import pytest
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import DAG, Connection
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+from airflow.providers.microsoft.azure.sensors.wasb import (
+ WasbBlobAsyncSensor,
+ WasbBlobSensor,
+ WasbPrefixSensor,
+)
+from airflow.providers.microsoft.azure.triggers.wasb import
WasbBlobSensorTrigger
+from airflow.utils import timezone
+from airflow.utils.types import DagRunType
+
+TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers.txt"
+TEST_DATA_STORAGE_CONTAINER_NAME = "test-container-providers"
+TEST_DATA_STORAGE_BLOB_PREFIX = TEST_DATA_STORAGE_BLOB_NAME[:10]
class TestWasbBlobSensor:
@@ -59,6 +77,87 @@ class TestWasbBlobSensor:
mock_instance.check_for_blob.assert_called_once_with("container",
"blob", timeout=2)
+class TestWasbBlobAsyncSensor:
+ def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str =
"test_dag_id") -> DagRun:
+ dag_run = DagRun(
+ dag_id=dag_id, run_type="manual",
execution_date=timezone.datetime(2022, 1, 1), run_id=run_id
+ )
+ return dag_run
+
+ def get_task_instance(self, task: BaseOperator) -> TaskInstance:
+ return TaskInstance(task, timezone.datetime(2022, 1, 1))
+
+ def get_conn(self) -> Connection:
+ return Connection(
+ conn_id="test_conn",
+ extra={},
+ )
+
+ def create_context(self, task, dag=None):
+ if dag is None:
+ dag = DAG(dag_id="dag")
+ tzinfo = pendulum.timezone("UTC")
+ execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ execution_date=execution_date,
+ run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+ )
+
+ task_instance = TaskInstance(task=task)
+ task_instance.dag_run = dag_run
+ task_instance.xcom_push = mock.Mock()
+ return {
+ "dag": dag,
+ "ts": execution_date.isoformat(),
+ "task": task,
+ "ti": task_instance,
+ "task_instance": task_instance,
+ "run_id": dag_run.run_id,
+ "dag_run": dag_run,
+ "execution_date": execution_date,
+ "data_interval_end": execution_date,
+ "logical_date": execution_date,
+ }
+
+ SENSOR = WasbBlobAsyncSensor(
+ task_id="wasb_blob_async_sensor",
+ container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
+ blob_name=TEST_DATA_STORAGE_BLOB_NAME,
+ timeout=5,
+ )
+
+ def test_wasb_blob_sensor_async(self):
+ """Assert execute method defer for wasb blob sensor"""
+
+ with pytest.raises(TaskDeferred) as exc:
+ self.SENSOR.execute(self.create_context(self.SENSOR))
+ assert isinstance(exc.value.trigger, WasbBlobSensorTrigger), "Trigger
is not a WasbBlobSensorTrigger"
+ assert exc.value.timeout == datetime.timedelta(seconds=5)
+
+ @pytest.mark.parametrize(
+ "event",
+ [None, {"status": "success", "message": "Job completed"}],
+ )
+ def test_wasb_blob_sensor_execute_complete_success(self, event):
+ """Assert execute_complete log success message when trigger fire with
target status."""
+
+ if not event:
+ with pytest.raises(AirflowException) as exception_info:
+ self.SENSOR.execute_complete(context=None, event=None)
+ assert exception_info.value.args[0] == "Did not receive valid
event from the triggerer"
+ else:
+ with mock.patch.object(self.SENSOR.log, "info") as mock_log_info:
+ self.SENSOR.execute_complete(context={}, event=event)
+ mock_log_info.assert_called_with(event["message"])
+
+ def test_wasb_blob_sensor_execute_complete_failure(self):
+ """Assert execute_complete method raises an exception when the
triggerer fires an error event."""
+
+ with pytest.raises(AirflowException):
+ self.SENSOR.execute_complete(context={}, event={"status": "error",
"message": ""})
+
+
class TestWasbPrefixSensor:
_config = {
"container_name": "container",
diff --git
a/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
b/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
index b2243e8bc8..83da48d985 100644
--- a/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
+++ b/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
@@ -21,7 +21,7 @@ import os
from datetime import datetime
from airflow import DAG
-from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
+from airflow.providers.microsoft.azure.sensors.wasb import
WasbBlobAsyncSensor, WasbBlobSensor
from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import
AzureBlobStorageToGCSOperator
# Ignore missing args provided by default_args
@@ -46,6 +46,8 @@ with DAG(
wait_for_blob = WasbBlobSensor(task_id="wait_for_blob")
+ wait_for_blob_async = WasbBlobAsyncSensor(task_id="wait_for_blob_async")
+
transfer_files_to_gcs = AzureBlobStorageToGCSOperator(
task_id="transfer_files_to_gcs",
# AZURE arg
@@ -60,7 +62,7 @@ with DAG(
)
# [END how_to_azure_blob_to_gcs]
- wait_for_blob >> transfer_files_to_gcs
+ wait_for_blob >> wait_for_blob_async >> transfer_files_to_gcs
from tests.system.utils.watcher import watcher