This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 54a59d7cae Add `WasbBlobAsyncSensor` (#30197)
54a59d7cae is described below

commit 54a59d7cae5d49993e018ad408954c44f15dc509
Author: Phani Kumar <[email protected]>
AuthorDate: Fri Mar 31 22:31:53 2023 +0530

    Add `WasbBlobAsyncSensor` (#30197)
    
    * Add WasbBlobAsyncSensor
    
    * Add WasbBlobAsyncSensor
    
    * Add WasbBlobAsyncSensor
    
    * add tests and example DAG
    
    * Fix failure during deferral
    
    * Apply review suggestions
    
    * Add wasb async conn
    
    * Specify timeout in the tests
    
    * Fix tests
---
 airflow/providers/microsoft/azure/hooks/wasb.py    | 167 ++++++++++++++++++++-
 airflow/providers/microsoft/azure/sensors/wasb.py  |  64 +++++++-
 airflow/providers/microsoft/azure/triggers/wasb.py |  88 +++++++++++
 .../providers/microsoft/azure/sensors/test_wasb.py | 103 ++++++++++++-
 .../microsoft/azure/example_azure_blob_to_gcs.py   |   6 +-
 5 files changed, 422 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py 
b/airflow/providers/microsoft/azure/hooks/wasb.py
index ce6870eee0..1c5b7d8b3e 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -28,15 +28,28 @@ from __future__ import annotations
 import logging
 import os
 from functools import wraps
-from typing import Any
+from typing import Any, Union
 
+from asgiref.sync import sync_to_async
 from azure.core.exceptions import HttpResponseError, ResourceExistsError, 
ResourceNotFoundError
 from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.identity.aio import (
+    ClientSecretCredential as AsyncClientSecretCredential,
+    DefaultAzureCredential as AsyncDefaultAzureCredential,
+)
 from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient, 
StorageStreamDownloader
+from azure.storage.blob._models import BlobProperties
+from azure.storage.blob.aio import (
+    BlobClient as AsyncBlobClient,
+    BlobServiceClient as AsyncBlobServiceClient,
+    ContainerClient as AsyncContainerClient,
+)
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+AsyncCredentials = Union[AsyncClientSecretCredential, 
AsyncDefaultAzureCredential]
+
 
 def _ensure_prefixes(conn_type):
     """
@@ -502,3 +515,155 @@ class WasbHook(BaseHook):
             return success
         except Exception as e:
             return False, str(e)
+
+
+class WasbAsyncHook(WasbHook):
+    """
+    An async hook that connects to Azure WASB to perform operations.
+
+    :param wasb_conn_id: reference to the :ref:`wasb connection 
<howto/connection:wasb>`
+    :param public_read: whether an anonymous public read access should be 
used. default is False
+    """
+
+    def __init__(
+        self,
+        wasb_conn_id: str = "wasb_default",
+        public_read: bool = False,
+    ) -> None:
+        """Initialize the hook instance."""
+        self.conn_id = wasb_conn_id
+        self.public_read = public_read
+        self.blob_service_client: AsyncBlobServiceClient = None  # type: ignore
+
+    async def get_async_conn(self) -> AsyncBlobServiceClient:
+        """Return the Async BlobServiceClient object."""
+        if self.blob_service_client is not None:
+            return self.blob_service_client
+
+        conn = await sync_to_async(self.get_connection)(self.conn_id)
+        extra = conn.extra_dejson or {}
+
+        if self.public_read:
+            # Here we use anonymous public read
+            # more info
+            # 
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
+            self.blob_service_client = 
AsyncBlobServiceClient(account_url=conn.host, **extra)
+            return self.blob_service_client
+
+        connection_string = self._get_field(extra, "connection_string")
+        if connection_string:
+            # connection_string auth takes priority
+            self.blob_service_client = 
AsyncBlobServiceClient.from_connection_string(
+                connection_string, **extra
+            )
+            return self.blob_service_client
+
+        shared_access_key = self._get_field(extra, "shared_access_key")
+        if shared_access_key:
+            # using shared access key
+            self.blob_service_client = AsyncBlobServiceClient(
+                account_url=conn.host, credential=shared_access_key, **extra
+            )
+            return self.blob_service_client
+
+        tenant = self._get_field(extra, "tenant_id")
+        if tenant:
+            # use Active Directory auth
+            app_id = conn.login
+            app_secret = conn.password
+            token_credential = AsyncClientSecretCredential(tenant, app_id, 
app_secret)
+            self.blob_service_client = AsyncBlobServiceClient(
+                account_url=conn.host, credential=token_credential, **extra  # 
type:ignore[arg-type]
+            )
+            return self.blob_service_client
+
+        sas_token = self._get_field(extra, "sas_token")
+        if sas_token:
+            if sas_token.startswith("https"):
+                self.blob_service_client = 
AsyncBlobServiceClient(account_url=sas_token, **extra)
+            else:
+                self.blob_service_client = AsyncBlobServiceClient(
+                    
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
+                )
+            return self.blob_service_client
+
+        # Fall back to old auth (password) or use managed identity if not 
provided.
+        credential = conn.password
+        if not credential:
+            credential = AsyncDefaultAzureCredential()
+            self.log.info("Using DefaultAzureCredential as credential")
+        self.blob_service_client = AsyncBlobServiceClient(
+            account_url=f"https://{conn.login}.blob.core.windows.net/",
+            credential=credential,
+            **extra,
+        )
+
+        return self.blob_service_client
+
+    def _get_blob_client(self, container_name: str, blob_name: str) -> 
AsyncBlobClient:
+        """
+        Instantiate a blob client.
+
+        :param container_name: the name of the blob container
+        :param blob_name: the name of the blob; it need not already exist
+        """
+        return 
self.blob_service_client.get_blob_client(container=container_name, 
blob=blob_name)
+
+    async def check_for_blob_async(self, container_name: str, blob_name: str, 
**kwargs: Any) -> bool:
+        """
+        Check if a blob exists on Azure Blob Storage.
+
+        :param container_name: name of the container
+        :param blob_name: name of the blob
+        :param kwargs: optional keyword arguments for 
``BlobClient.get_blob_properties``
+        """
+        try:
+            await self._get_blob_client(container_name, 
blob_name).get_blob_properties(**kwargs)
+        except ResourceNotFoundError:
+            return False
+        return True
+
+    def _get_container_client(self, container_name: str) -> 
AsyncContainerClient:
+        """
+        Instantiate a container client.
+
+        :param container_name: the name of the container
+        """
+        return self.blob_service_client.get_container_client(container_name)
+
+    async def get_blobs_list_async(
+        self,
+        container_name: str,
+        prefix: str | None = None,
+        include: list[str] | None = None,
+        delimiter: str = "/",
+        **kwargs: Any,
+    ) -> list[BlobProperties]:
+        """
+        List blobs in a given container.
+
+        :param container_name: the name of the container
+        :param prefix: filters the results to return only blobs whose names
+            begin with the specified prefix.
+        :param include: specifies one or more additional datasets to include 
in the
+            response. Options include: ``snapshots``, ``metadata``, 
``uncommittedblobs``,
+            ``copy``, ``deleted``.
+        :param delimiter: filters objects based on the delimiter (e.g. '.csv')
+        """
+        container = self._get_container_client(container_name)
+        blob_list = []
+        blobs = container.walk_blobs(name_starts_with=prefix, include=include, 
delimiter=delimiter, **kwargs)
+        async for blob in blobs:
+            blob_list.append(blob.name)
+        return blob_list
+
+    async def check_for_prefix_async(self, container_name: str, prefix: str, 
**kwargs: Any) -> bool:
+        """
+        Check if a prefix exists on Azure Blob storage.
+
+        :param container_name: Name of the container.
+        :param prefix: Prefix of the blob.
+        :param kwargs: Optional keyword arguments for 
``ContainerClient.walk_blobs``
+        """
+        blobs = await self.get_blobs_list_async(container_name=container_name, 
prefix=prefix, **kwargs)
+        return len(blobs) > 0
diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py 
b/airflow/providers/microsoft/azure/sensors/wasb.py
index 388a571f7d..017d73720d 100644
--- a/airflow/providers/microsoft/azure/sensors/wasb.py
+++ b/airflow/providers/microsoft/azure/sensors/wasb.py
@@ -17,9 +17,12 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+from airflow.providers.microsoft.azure.triggers.wasb import 
WasbBlobSensorTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -62,6 +65,65 @@ class WasbBlobSensor(BaseSensorOperator):
         return hook.check_for_blob(self.container_name, self.blob_name, 
**self.check_options)
 
 
+class WasbBlobAsyncSensor(WasbBlobSensor):
+    """
+    Polls asynchronously for the existence of a blob in a WASB container.
+
+    :param container_name: name of the container in which the blob should be 
searched for
+    :param blob_name: name of the blob to check existence for
+    :param wasb_conn_id: the connection identifier for connecting to Azure WASB
+    :param poke_interval:  polling period in seconds to check for the status
+    :param public_read: whether an anonymous public read access should be 
used. Default is False
+    :param timeout: Time, in seconds before the task times out and fails.
+    """
+
+    def __init__(
+        self,
+        *,
+        container_name: str,
+        blob_name: str,
+        wasb_conn_id: str = "wasb_default",
+        public_read: bool = False,
+        poke_interval: float = 5.0,
+        **kwargs: Any,
+    ):
+        self.container_name = container_name
+        self.blob_name = blob_name
+        self.poke_interval = poke_interval
+        super().__init__(container_name=container_name, blob_name=blob_name, 
**kwargs)
+        self.wasb_conn_id = wasb_conn_id
+        self.public_read = public_read
+
+    def execute(self, context: Context) -> None:
+        """Defers trigger class to poll for state of the job run until it 
reaches
+        a failure state or success state
+        """
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=WasbBlobSensorTrigger(
+                container_name=self.container_name,
+                blob_name=self.blob_name,
+                wasb_conn_id=self.wasb_conn_id,
+                public_read=self.public_read,
+                poke_interval=self.poke_interval,
+            ),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict[str, str]) -> 
None:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes 
execution was
+        successful.
+        """
+        if event:
+            if event["status"] == "error":
+                raise AirflowException(event["message"])
+            self.log.info(event["message"])
+        else:
+            raise AirflowException("Did not receive valid event from the 
triggerer")
+
+
 class WasbPrefixSensor(BaseSensorOperator):
     """
     Waits for blobs matching a prefix to arrive on Azure Blob Storage.
diff --git a/airflow/providers/microsoft/azure/triggers/wasb.py 
b/airflow/providers/microsoft/azure/triggers/wasb.py
new file mode 100644
index 0000000000..cea5b76fbf
--- /dev/null
+++ b/airflow/providers/microsoft/azure/triggers/wasb.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from airflow.providers.microsoft.azure.hooks.wasb import WasbAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class WasbBlobSensorTrigger(BaseTrigger):
+    """
+    WasbBlobSensorTrigger is fired as deferred class with params to run the 
task in
+    trigger worker to check for existence of the given blob in the provided 
container.
+
+    :param container_name: name of the container in which the blob should be 
searched for
+    :param blob_name: name of the blob to check existence for
+    :param wasb_conn_id: the connection identifier for connecting to Azure WASB
+    :param poke_interval:  polling period in seconds to check for the status
+    :param public_read: whether an anonymous public read access should be 
used. Default is False
+    """
+
+    def __init__(
+        self,
+        container_name: str,
+        blob_name: str,
+        wasb_conn_id: str = "wasb_default",
+        public_read: bool = False,
+        poke_interval: float = 5.0,
+    ):
+        super().__init__()
+        self.container_name = container_name
+        self.blob_name = blob_name
+        self.wasb_conn_id = wasb_conn_id
+        self.poke_interval = poke_interval
+        self.public_read = public_read
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes WasbBlobSensorTrigger arguments and classpath."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.wasb.WasbBlobSensorTrigger",
+            {
+                "container_name": self.container_name,
+                "blob_name": self.blob_name,
+                "wasb_conn_id": self.wasb_conn_id,
+                "poke_interval": self.poke_interval,
+                "public_read": self.public_read,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """Makes async connection to Azure WASB and polls for existence of the 
given blob name."""
+        blob_exists = False
+        hook = WasbAsyncHook(wasb_conn_id=self.wasb_conn_id, 
public_read=self.public_read)
+        try:
+            async with await hook.get_async_conn():
+                while not blob_exists:
+                    blob_exists = await hook.check_for_blob_async(
+                        container_name=self.container_name,
+                        blob_name=self.blob_name,
+                    )
+                    if blob_exists:
+                        message = f"Blob {self.blob_name} found in container 
{self.container_name}."
+                        yield TriggerEvent({"status": "success", "message": 
message})
+                    else:
+                        message = (
+                            f"Blob {self.blob_name} not available yet in 
container {self.container_name}."
+                            f" Sleeping for {self.poke_interval} seconds"
+                        )
+                        self.log.info(message)
+                        await asyncio.sleep(self.poke_interval)
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py 
b/tests/providers/microsoft/azure/sensors/test_wasb.py
index 5edc9bcd84..67830c44e3 100644
--- a/tests/providers/microsoft/azure/sensors/test_wasb.py
+++ b/tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -20,8 +20,26 @@ from __future__ import annotations
 import datetime
 from unittest import mock
 
-from airflow.models.dag import DAG
-from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor, 
WasbPrefixSensor
+import pendulum
+import pytest
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import DAG, Connection
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+from airflow.providers.microsoft.azure.sensors.wasb import (
+    WasbBlobAsyncSensor,
+    WasbBlobSensor,
+    WasbPrefixSensor,
+)
+from airflow.providers.microsoft.azure.triggers.wasb import 
WasbBlobSensorTrigger
+from airflow.utils import timezone
+from airflow.utils.types import DagRunType
+
+TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers.txt"
+TEST_DATA_STORAGE_CONTAINER_NAME = "test-container-providers"
+TEST_DATA_STORAGE_BLOB_PREFIX = TEST_DATA_STORAGE_BLOB_NAME[:10]
 
 
 class TestWasbBlobSensor:
@@ -59,6 +77,87 @@ class TestWasbBlobSensor:
         mock_instance.check_for_blob.assert_called_once_with("container", 
"blob", timeout=2)
 
 
+class TestWasbBlobAsyncSensor:
+    def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str = 
"test_dag_id") -> DagRun:
+        dag_run = DagRun(
+            dag_id=dag_id, run_type="manual", 
execution_date=timezone.datetime(2022, 1, 1), run_id=run_id
+        )
+        return dag_run
+
+    def get_task_instance(self, task: BaseOperator) -> TaskInstance:
+        return TaskInstance(task, timezone.datetime(2022, 1, 1))
+
+    def get_conn(self) -> Connection:
+        return Connection(
+            conn_id="test_conn",
+            extra={},
+        )
+
+    def create_context(self, task, dag=None):
+        if dag is None:
+            dag = DAG(dag_id="dag")
+        tzinfo = pendulum.timezone("UTC")
+        execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+        dag_run = DagRun(
+            dag_id=dag.dag_id,
+            execution_date=execution_date,
+            run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+        )
+
+        task_instance = TaskInstance(task=task)
+        task_instance.dag_run = dag_run
+        task_instance.xcom_push = mock.Mock()
+        return {
+            "dag": dag,
+            "ts": execution_date.isoformat(),
+            "task": task,
+            "ti": task_instance,
+            "task_instance": task_instance,
+            "run_id": dag_run.run_id,
+            "dag_run": dag_run,
+            "execution_date": execution_date,
+            "data_interval_end": execution_date,
+            "logical_date": execution_date,
+        }
+
+    SENSOR = WasbBlobAsyncSensor(
+        task_id="wasb_blob_async_sensor",
+        container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
+        blob_name=TEST_DATA_STORAGE_BLOB_NAME,
+        timeout=5,
+    )
+
+    def test_wasb_blob_sensor_async(self):
+        """Assert execute method defer for wasb blob sensor"""
+
+        with pytest.raises(TaskDeferred) as exc:
+            self.SENSOR.execute(self.create_context(self.SENSOR))
+        assert isinstance(exc.value.trigger, WasbBlobSensorTrigger), "Trigger 
is not a WasbBlobSensorTrigger"
+        assert exc.value.timeout == datetime.timedelta(seconds=5)
+
+    @pytest.mark.parametrize(
+        "event",
+        [None, {"status": "success", "message": "Job completed"}],
+    )
+    def test_wasb_blob_sensor_execute_complete_success(self, event):
+        """Assert execute_complete log success message when trigger fire with 
target status."""
+
+        if not event:
+            with pytest.raises(AirflowException) as exception_info:
+                self.SENSOR.execute_complete(context=None, event=None)
+            assert exception_info.value.args[0] == "Did not receive valid 
event from the triggerer"
+        else:
+            with mock.patch.object(self.SENSOR.log, "info") as mock_log_info:
+                self.SENSOR.execute_complete(context={}, event=event)
+            mock_log_info.assert_called_with(event["message"])
+
+    def test_wasb_blob_sensor_execute_complete_failure(self):
+        """Assert execute_complete method raises an exception when the 
triggerer fires an error event."""
+
+        with pytest.raises(AirflowException):
+            self.SENSOR.execute_complete(context={}, event={"status": "error", 
"message": ""})
+
+
 class TestWasbPrefixSensor:
     _config = {
         "container_name": "container",
diff --git 
a/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py 
b/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
index b2243e8bc8..83da48d985 100644
--- a/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
+++ b/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py
@@ -21,7 +21,7 @@ import os
 from datetime import datetime
 
 from airflow import DAG
-from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
+from airflow.providers.microsoft.azure.sensors.wasb import 
WasbBlobAsyncSensor, WasbBlobSensor
 from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import 
AzureBlobStorageToGCSOperator
 
 # Ignore missing args provided by default_args
@@ -46,6 +46,8 @@ with DAG(
 
     wait_for_blob = WasbBlobSensor(task_id="wait_for_blob")
 
+    wait_for_blob_async = WasbBlobAsyncSensor(task_id="wait_for_blob_async")
+
     transfer_files_to_gcs = AzureBlobStorageToGCSOperator(
         task_id="transfer_files_to_gcs",
         # AZURE arg
@@ -60,7 +62,7 @@ with DAG(
     )
     # [END how_to_azure_blob_to_gcs]
 
-    wait_for_blob >> transfer_files_to_gcs
+    wait_for_blob >> wait_for_blob_async >> transfer_files_to_gcs
 
     from tests.system.utils.watcher import watcher
 

Reply via email to