This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dd1095202b6 Add Azure Service Bus Queue and Subscription triggers for async message processing (#53356)
dd1095202b6 is described below

commit dd1095202b6f8c715832e3c6425d63545e88fedc
Author: Ranuga <[email protected]>
AuthorDate: Mon Nov 24 16:12:21 2025 +0530

    Add Azure Service Bus Queue and Subscription triggers for async message processing (#53356)
    
    * Add Azure Service Bus integration with triggers and documentation
    
    * Implement Azure Service Bus Queue and Subscription triggers with async message processing and unit tests
    
    * feat(Azure): Add methods to read messages from Service Bus queue and subscription
    
    * Update providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
    
    Co-authored-by: LIU ZHE YOU <[email protected]>
    
    * Update providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
    
    Co-authored-by: LIU ZHE YOU <[email protected]>
    
    * Update providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
    
    Co-authored-by: LIU ZHE YOU <[email protected]>
    
    * Remove Queue URI format documentation for Azure Service Bus Queue and Subscription Providers
    
    * Enhance Azure Service Bus integration by adding new message types and updating trigger initialization tests
    
    * Fix Azure Service Bus message trigger test failures
    
    - Simplified mocking to avoid connection and import issues
    - Fixed mock assertions for MessageHook initialization
    - Addressed compatibility issues with older Airflow versions
    - Used context managers for proper mock lifecycle management
    
    * Fix Azure Service Bus trigger test mocking issues
    
      - Replace Mock object string representations with proper Mock(body=value) pattern
      - Fix str(message.body) calls returning mock object strings instead of expected message content
      - Apply standard Azure provider mocking conventions consistent with other tests
      - Remove unnecessary comments and custom MockMessage classes
    
    * Update the Trigger Event Message to Include the Actual Content
    
    * resolve review comments
    
    * chore: change BaseTrigger to BaseEventTrigger
    
    * Update providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
    
    Co-authored-by: LIU ZHE YOU <[email protected]>
    
    * chore: fix ci/cd
    
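The test-mocking bullets above hinge on how `unittest.mock.Mock` creates attributes on demand. A minimal, standard-library-only sketch of the three cases: an unconfigured `Mock` answers `.body` with another `Mock` (so `str()` gives a repr, not content), `Mock(body=...)` returns the configured value, and an iterator body can be read only once, so a test that reads it twice sees an empty second pass. The variable names are illustrative only.

```python
from unittest.mock import Mock

# An unconfigured Mock creates a child Mock on attribute access, so the
# "body" is itself a Mock and str() returns its repr, not message content.
unconfigured = Mock()
body_repr = str(unconfigured.body)

# Mock(body=...) sets the attribute to a real value the code can decode.
bytes_message = Mock(body=b"test message")
decoded = bytes_message.body.decode("utf-8")

# An iterator body is consumed by the first read; a second read is empty.
iter_message = Mock(body=iter([b"test", b" ", b"iterator"]))
first_pass = b"".join(iter_message.body)
second_pass = b"".join(iter_message.body)

print(body_repr.startswith("<Mock"))  # True
print(decoded)  # test message
print(first_pass, second_pass)  # b'test iterator' b''
```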
    ---------
    
    Co-authored-by: LIU ZHE YOU <[email protected]>
---
 docs/spelling_wordlist.txt                         |   2 +
 .../azure/docs/connections/message_bus.rst         |  99 +++++++
 providers/microsoft/azure/provider.yaml            |   3 +
 .../providers/microsoft/azure/get_provider_info.py |   4 +
 .../airflow/providers/microsoft/azure/hooks/asb.py |  54 ++++
 .../microsoft/azure/triggers/message_bus.py        | 223 ++++++++++++++++
 .../microsoft/azure/triggers/test_message_bus.py   | 286 +++++++++++++++++++++
 7 files changed, 671 insertions(+)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 0c7db72e31d..cf0babd72e6 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1623,6 +1623,8 @@ serializer
 serializers
 serverless
 ServiceAccount
+servicebus
+ServiceBusReceivedMessage
 ServicePrincipalCredentials
 ServiceResource
 SES
diff --git a/providers/microsoft/azure/docs/connections/message_bus.rst b/providers/microsoft/azure/docs/connections/message_bus.rst
new file mode 100644
index 00000000000..4611c921a3f
--- /dev/null
+++ b/providers/microsoft/azure/docs/connections/message_bus.rst
@@ -0,0 +1,99 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/trigger:azure_service_bus_queue:
+
+Microsoft Azure Service Bus Queue Trigger
+=========================================
+
+The Microsoft Azure Service Bus Queue trigger enables you to monitor Azure Service Bus queues for new messages and trigger DAG runs when messages arrive.
+
+Authenticating to Azure
+-----------------------
+
+The trigger uses the connection you have configured in Airflow to connect to Azure.
+Please refer to the :ref:`howto/connection:azure_service_bus` documentation for details on how to configure your connection.
+
+Using the Trigger
+-----------------
+
+This example shows how to use the ``AzureServiceBusQueueTrigger``.
+
+.. code-block:: python
+
+      from airflow.providers.microsoft.azure.triggers.message_bus import AzureServiceBusQueueTrigger
+
+      trigger = AzureServiceBusQueueTrigger(
+          queues=["my_queue"],
+          azure_service_bus_conn_id="azure_service_bus_default",
+          max_wait_time=5.0,
+          poll_interval=5.0,
+      )
+
+      # The trigger will fire when a message is available in the queue.
+
+.. _howto/trigger:azure_service_bus_subscription:
+
+Microsoft Azure Service Bus Subscription Trigger
+================================================
+
+The Microsoft Azure Service Bus Subscription trigger enables you to monitor Azure Service Bus topic subscriptions for new messages and trigger DAG runs when messages arrive.
+
+Authenticating to Azure
+-----------------------
+
+The trigger uses the connection you have configured in Airflow to connect to Azure.
+Please refer to the :ref:`howto/connection:azure_service_bus` documentation for details on how to configure your connection.
+
+Using the Trigger
+-----------------
+
+This example shows how to use the ``AzureServiceBusSubscriptionTrigger``.
+
+.. code-block:: python
+
+      from airflow.providers.microsoft.azure.triggers.message_bus import AzureServiceBusSubscriptionTrigger
+
+      trigger = AzureServiceBusSubscriptionTrigger(
+          topics=["my_topic"],
+          subscription_name="my_subscription",
+          azure_service_bus_conn_id="azure_service_bus_default",
+          max_wait_time=5.0,
+          poll_interval=5.0,
+      )
+
+      # The trigger will fire when a message is available in the topic subscription.
+
+Azure Service Bus Message Queue
+===============================
+
+Azure Service Bus Queue Provider
+--------------------------------
+
+Implemented by :class:`~airflow.providers.microsoft.azure.triggers.message_bus.AzureServiceBusQueueTrigger`
+
+The Azure Service Bus Queue Provider is a message queue provider that uses Azure Service Bus queues.
+It allows you to monitor Azure Service Bus queues for new messages and trigger DAG runs when messages arrive.
+
+
+Azure Service Bus Subscription Provider
+---------------------------------------
+
+Implemented by :class:`~airflow.providers.microsoft.azure.triggers.message_bus.AzureServiceBusSubscriptionTrigger`
+
+The Azure Service Bus Subscription Provider is a message queue provider that uses Azure Service Bus topic subscriptions.
+It allows you to monitor Azure Service Bus topic subscriptions for new messages and trigger DAG runs when messages arrive.
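Both triggers poll rather than stream: each cycle checks the configured queues (or topics) in order, emits at most one event, and then sleeps for `poll_interval`. The sketch below reproduces that loop with the standard library only, as an illustration under stated assumptions rather than the provider code: `asyncio.to_thread` stands in for the asgiref `sync_to_async` wrapper the trigger uses, and `make_fake_reader`, `poll_queues`, and `collect` are hypothetical names, with the fake reader playing the role of `MessageHook.read_message`.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable


def make_fake_reader(pending: dict[str, list[bytes]]) -> Callable[[str], bytes | None]:
    """Hypothetical blocking reader standing in for MessageHook.read_message."""

    def fake_read(queue_name: str) -> bytes | None:
        # Pop one message per call; None means the queue is currently empty.
        messages = pending.get(queue_name, [])
        return messages.pop(0) if messages else None

    return fake_read


async def poll_queues(
    queues: list[str],
    read: Callable[[str], bytes | None],
    poll_interval: float,
) -> AsyncIterator[dict[str, str]]:
    """Check queues in order, yield at most one event per cycle, then sleep."""
    while True:
        for queue_name in queues:
            # Run the blocking read in a worker thread so the event loop stays responsive.
            message = await asyncio.to_thread(read, queue_name)
            if message is not None:
                yield {"message": message.decode("utf-8"), "queue": queue_name}
                break  # the next cycle starts again from the first queue
        await asyncio.sleep(poll_interval)


async def collect(count: int) -> list[dict[str, str]]:
    """Drive the loop against fake queues until `count` events have been seen."""
    reader = make_fake_reader({"q1": [b"a"], "q2": [b"b", b"c"]})
    events: list[dict[str, str]] = []
    async for event in poll_queues(["q1", "q2"], reader, poll_interval=0.001):
        events.append(event)
        if len(events) >= count:
            break
    return events


print(asyncio.run(collect(3)))
```

Because each cycle restarts from the first queue after an event, a queue near the front of the list that always has messages can keep later queues from being checked; the `break` in the triggers' `run()` has the same effect.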
diff --git a/providers/microsoft/azure/provider.yaml b/providers/microsoft/azure/provider.yaml
index 0a294fa6600..75a01269dbf 100644
--- a/providers/microsoft/azure/provider.yaml
+++ b/providers/microsoft/azure/provider.yaml
@@ -294,6 +294,9 @@ triggers:
   - integration-name: Microsoft Power BI
     python-modules:
       - airflow.providers.microsoft.azure.triggers.powerbi
+  - integration-name: Microsoft Azure Service Bus
+    python-modules:
+      - airflow.providers.microsoft.azure.triggers.message_bus
 
 transfers:
   - source-integration-name: Local
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
index ff4c09f0c1b..5c9ef3ab15a 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
@@ -281,6 +281,10 @@ def get_provider_info():
                 "integration-name": "Microsoft Power BI",
                 "python-modules": ["airflow.providers.microsoft.azure.triggers.powerbi"],
             },
+            {
+                "integration-name": "Microsoft Azure Service Bus",
+                "python-modules": ["airflow.providers.microsoft.azure.triggers.message_bus"],
+            },
         ],
         "transfers": [
             {
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
index 5b91adcce62..8dd2d93578e 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
@@ -606,6 +606,60 @@ class MessageHook(BaseAzureServiceBusHook):
             for msg in received_msgs:
                 self._process_message(msg, context, message_callback, subscription_receiver)
 
+    def read_message(
+        self,
+        queue_name: str,
+        max_wait_time: float | None = None,
+    ) -> ServiceBusReceivedMessage | None:
+        """
+        Read and complete a single message from a Service Bus queue, without callback processing.
+
+        :param queue_name: The name of the queue to read from.
+        :param max_wait_time: Maximum time to wait for messages (seconds).
+        :return: The received message or None if no message is available.
+        """
+        with (
+            self.get_conn() as service_bus_client,
+            service_bus_client.get_queue_receiver(queue_name=queue_name) as receiver,
+            receiver,
+        ):
+            received_msgs = receiver.receive_messages(max_message_count=1, max_wait_time=max_wait_time)
+            if received_msgs:
+                msg = received_msgs[0]
+                receiver.complete_message(msg)
+                return msg
+            return None
+
+    def read_subscription_message(
+        self,
+        topic_name: str,
+        subscription_name: str,
+        max_wait_time: float | None = None,
+    ) -> ServiceBusReceivedMessage | None:
+        """
+        Read and complete a single message from a topic subscription, without callback processing.
+
+        :param topic_name: The name of the topic.
+        :param subscription_name: The name of the subscription.
+        :param max_wait_time: Maximum time to wait for messages (seconds).
+        :return: The received message or None if no message is available.
+        """
+        with (
+            self.get_conn() as service_bus_client,
+            service_bus_client.get_subscription_receiver(
+                topic_name, subscription_name
+            ) as subscription_receiver,
+            subscription_receiver,
+        ):
+            received_msgs = subscription_receiver.receive_messages(
+                max_message_count=1, max_wait_time=max_wait_time
+            )
+            if received_msgs:
+                msg = received_msgs[0]
+                subscription_receiver.complete_message(msg)
+                return msg
+            return None
+
     def _process_message(
         self,
         msg: ServiceBusReceivedMessage,
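The two new hook methods share one control flow: receive at most one message, settle it with `complete_message` before returning, and return `None` when nothing arrives within `max_wait_time`. Because the message is completed before the trigger yields its event, the broker will not redeliver it if a later step fails, so delivery is effectively at-most-once. A minimal sketch of that flow, assuming a hypothetical in-memory `FakeReceiver` in place of the azure-servicebus receiver (its method names echo the SDK, but it is not the SDK):

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class FakeReceiver:
    """Hypothetical in-memory stand-in for a Service Bus receiver (not the SDK)."""

    pending: list[bytes]
    completed: list[bytes] = field(default_factory=list)

    def receive_messages(
        self, max_message_count: int = 1, max_wait_time: float | None = None
    ) -> list[bytes]:
        # Hand out up to max_message_count messages; max_wait_time is ignored here.
        batch = self.pending[:max_message_count]
        del self.pending[:max_message_count]
        return batch

    def complete_message(self, message: bytes) -> None:
        # Settling removes the message for good, so it will not be redelivered.
        self.completed.append(message)


def read_one(receiver: FakeReceiver, max_wait_time: float | None = None) -> bytes | None:
    """Same control flow as the hook: receive one message, complete it, return it or None."""
    received = receiver.receive_messages(max_message_count=1, max_wait_time=max_wait_time)
    if received:
        msg = received[0]
        receiver.complete_message(msg)
        return msg
    return None


receiver = FakeReceiver(pending=[b"first", b"second"])
print(read_one(receiver))  # b'first'
print(read_one(receiver))  # b'second'
print(read_one(receiver))  # None
print(receiver.completed)  # [b'first', b'second']
```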
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
new file mode 100644
index 00000000000..bafd1e677cf
--- /dev/null
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
@@ -0,0 +1,223 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from abc import abstractmethod
+from collections.abc import AsyncIterator
+from typing import TYPE_CHECKING, Any
+
+from asgiref.sync import sync_to_async
+
+from airflow.providers.microsoft.azure.hooks.asb import MessageHook
+from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS
+
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.triggers.base import BaseEventTrigger, TriggerEvent
+else:
+    from airflow.triggers.base import (  # type: ignore
+        BaseTrigger as BaseEventTrigger,
+        TriggerEvent,
+    )
+
+if TYPE_CHECKING:
+    from azure.servicebus import ServiceBusReceivedMessage
+
+
+class BaseAzureServiceBusTrigger(BaseEventTrigger):
+    """
+    Base trigger for Azure Service Bus message processing.
+
+    This trigger provides common functionality for listening to Azure Service Bus
+    queues and topics/subscriptions. It handles connection management and
+    async message processing.
+
+    :param poll_interval: Time interval between polling operations (seconds)
+    :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
+    :param max_wait_time: Maximum time to wait for messages (seconds)
+    """
+
+    default_conn_name = "azure_service_bus_default"
+    default_max_wait_time = None
+    default_poll_interval = 60
+
+    def __init__(
+        self,
+        poll_interval: float | None = None,
+        azure_service_bus_conn_id: str | None = None,
+        max_wait_time: float | None = None,
+    ) -> None:
+        self.connection_id = (
+            azure_service_bus_conn_id
+            if azure_service_bus_conn_id
+            else BaseAzureServiceBusTrigger.default_conn_name
+        )
+        self.max_wait_time = (
+            max_wait_time if max_wait_time else BaseAzureServiceBusTrigger.default_max_wait_time
+        )
+        self.poll_interval = (
+            poll_interval if poll_interval else BaseAzureServiceBusTrigger.default_poll_interval
+        )
+        self.message_hook = MessageHook(azure_service_bus_conn_id=self.connection_id)
+
+    @abstractmethod
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize the trigger instance."""
+
+    @abstractmethod
+    def run(self) -> AsyncIterator[TriggerEvent]:
+        """Run the trigger logic."""
+
+    @classmethod
+    def _get_message_body(cls, message: ServiceBusReceivedMessage) -> str:
+        message_body = message.body
+        if isinstance(message_body, bytes):
+            return message_body.decode("utf-8")
+        try:
+            return "".join(chunk.decode("utf-8") for chunk in message_body)
+        except Exception:
+            raise TypeError(f"Expected bytes or an iterator of bytes, but got {type(message_body).__name__}")
+
+
+class AzureServiceBusQueueTrigger(BaseAzureServiceBusTrigger):
+    """
+    Trigger for Azure Service Bus Queue message processing.
+
+    This trigger monitors one or more Azure Service Bus queues for incoming messages.
+    When messages arrive, they are processed and yielded as trigger events that can
+    be consumed by downstream tasks.
+
+    Example:
+        >>> trigger = AzureServiceBusQueueTrigger(
+        ...     queues=["queue1", "queue2"],
+        ...     azure_service_bus_conn_id="my_asb_conn",
+        ...     poll_interval=30,
+        ... )
+
+    :param queues: List of queue names to monitor
+    :param poll_interval: Time interval between polling operations (seconds)
+    :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
+    :param max_wait_time: Maximum time to wait for messages (seconds)
+    """
+
+    def __init__(
+        self,
+        queues: list[str],
+        poll_interval: float | None = None,
+        azure_service_bus_conn_id: str | None = None,
+        max_wait_time: float | None = None,
+    ) -> None:
+        super().__init__(poll_interval, azure_service_bus_conn_id, max_wait_time)
+        self.queues = queues
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "azure_service_bus_conn_id": self.connection_id,
+                "queues": self.queues,
+                "poll_interval": self.poll_interval,
+                "max_wait_time": self.max_wait_time,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        read_queue_message_async = sync_to_async(self.message_hook.read_message)
+
+        while True:
+            for queue_name in self.queues:
+                message = await read_queue_message_async(
+                    queue_name=queue_name, max_wait_time=self.max_wait_time
+                )
+                if message:
+                    yield TriggerEvent(
+                        {
+                            "message": BaseAzureServiceBusTrigger._get_message_body(message),
+                            "queue": queue_name,
+                        }
+                    )
+                    break
+            await asyncio.sleep(self.poll_interval)
+
+
+class AzureServiceBusSubscriptionTrigger(BaseAzureServiceBusTrigger):
+    """
+    Trigger for Azure Service Bus Topic Subscription message processing.
+
+    This trigger monitors topic subscriptions for incoming messages. It can handle
+    multiple topics with a single subscription name, processing messages as they
+    arrive and yielding them as trigger events.
+
+    Example:
+        >>> trigger = AzureServiceBusSubscriptionTrigger(
+        ...     topics=["topic1", "topic2"],
+        ...     subscription_name="my-subscription",
+        ...     azure_service_bus_conn_id="my_asb_conn",
+        ... )
+
+    :param topics: List of topic names to monitor
+    :param subscription_name: Name of the subscription to use
+    :param poll_interval: Time interval between polling operations (seconds)
+    :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
+    :param max_wait_time: Maximum time to wait for messages (seconds)
+    """
+
+    def __init__(
+        self,
+        topics: list[str],
+        subscription_name: str,
+        poll_interval: float | None = None,
+        azure_service_bus_conn_id: str | None = None,
+        max_wait_time: float | None = None,
+    ) -> None:
+        super().__init__(poll_interval, azure_service_bus_conn_id, max_wait_time)
+        self.topics = topics
+        self.subscription_name = subscription_name
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "azure_service_bus_conn_id": self.connection_id,
+                "topics": self.topics,
+                "subscription_name": self.subscription_name,
+                "poll_interval": self.poll_interval,
+                "max_wait_time": self.max_wait_time,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        read_subscription_message_async = sync_to_async(self.message_hook.read_subscription_message)
+
+        while True:
+            for topic_name in self.topics:
+                message = await read_subscription_message_async(
+                    topic_name=topic_name,
+                    subscription_name=self.subscription_name,
+                    max_wait_time=self.max_wait_time,
+                )
+                if message:
+                    yield TriggerEvent(
+                        {
+                            "message": BaseAzureServiceBusTrigger._get_message_body(message),
+                            "topic": topic_name,
+                            "subscription": self.subscription_name,
+                        }
+                    )
+                    break
+            await asyncio.sleep(self.poll_interval)
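`_get_message_body` accepts a body that is either `bytes` or an iterable of `bytes` chunks and raises `TypeError` otherwise. A standalone sketch of the same rule, under the hypothetical name `decode_message_body` so it can be exercised without the Azure SDK. One deliberate difference: the sketch catches only `TypeError` and `AttributeError`, so a `UnicodeDecodeError` from a malformed chunk surfaces as itself, whereas the broad `except Exception` in the trigger reports it as a misleading type error.

```python
from __future__ import annotations

from collections.abc import Iterable


def decode_message_body(body: bytes | Iterable[bytes]) -> str:
    """Hypothetical helper: decode a body given as bytes or as an iterable of bytes chunks."""
    if isinstance(body, bytes):
        return body.decode("utf-8")
    try:
        return "".join(chunk.decode("utf-8") for chunk in body)
    except (TypeError, AttributeError) as err:
        # TypeError: body is not iterable. AttributeError: a chunk has no .decode().
        raise TypeError(
            f"Expected bytes or an iterable of bytes, but got {type(body).__name__}"
        ) from err


print(decode_message_body(b"hello"))  # hello
print(decode_message_body(iter([b"hel", b"lo"])))  # hello
print(decode_message_body([b"multi", b"-", b"chunk"]))  # multi-chunk
```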
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_message_bus.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_message_bus.py
new file mode 100644
index 00000000000..cdb99b3a423
--- /dev/null
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_message_bus.py
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.message_bus import (
+    AzureServiceBusQueueTrigger,
+    AzureServiceBusSubscriptionTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+
+class TestBaseAzureServiceBusTrigger:
+    """Test the base trigger functionality."""
+
+    def test_init_with_defaults(self):
+        """Test initialization with default values using queue trigger."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(queues=["test_queue"])
+
+            assert trigger.max_wait_time is None
+            assert trigger.poll_interval == 60
+            assert hasattr(trigger, "message_hook")
+
+    def test_init_with_custom_values(self):
+        """Test initialization with custom values using queue trigger."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(
+                queues=["test_queue"],
+                poll_interval=30,
+                azure_service_bus_conn_id="custom_conn",
+                max_wait_time=120,
+            )
+
+            assert trigger.poll_interval == 30
+            assert trigger.max_wait_time == 120
+            assert trigger.connection_id == "custom_conn"
+
+
+class TestAzureServiceBusQueueTrigger:
+    """Test the queue trigger functionality."""
+
+    def test_init(self):
+        """Test queue trigger initialization."""
+        queues = ["queue1", "queue2"]
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(
+                queues=queues,
+                azure_service_bus_conn_id="test_conn",
+            )
+
+            assert trigger.queues == queues
+
+    def test_serialize(self):
+        """Test serialization of queue trigger."""
+        queues = ["queue1", "queue2"]
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(
+                queues=queues,
+                azure_service_bus_conn_id="test_conn",
+            )
+
+            class_path, config = trigger.serialize()
+
+            assert "AzureServiceBusQueueTrigger" in class_path
+            assert config["queues"] == queues
+            assert "azure_service_bus_conn_id" in config
+
+    @pytest.mark.asyncio
+    async def test_run_with_message(self):
+        """Test the main run method with a mock message as bytes."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(
+                queues=["test_queue"],
+                poll_interval=0.01,  # Very short for testing
+            )
+
+            mock_message = Mock(body=b"test message")
+            trigger.message_hook.read_message = Mock(return_value=mock_message)
+
+            # Get one event from the generator
+            events = []
+            async for event in trigger.run():
+                events.append(event)
+                if len(events) >= 1:
+                    break
+
+            assert len(events) == 1
+            assert isinstance(events[0], TriggerEvent)
+            assert events[0].payload["message"] == "test message"
+            assert events[0].payload["queue"] == "test_queue"
+
+    @pytest.mark.asyncio
+    async def test_run_with_iterator_message(self):
+        """Test the main run method with a mock message as an iterator."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(
+                queues=["test_queue"],
+                poll_interval=0.01,  # Very short for testing
+            )
+
+            mock_message = Mock(body=iter([b"test", b" ", b"iterator", b" ", b"message"]))
+            trigger.message_hook.read_message = Mock(return_value=mock_message)
+
+            # Get one event from the generator
+            events = []
+            async for event in trigger.run():
+                events.append(event)
+                if len(events) >= 1:
+                    break
+
+            assert len(events) == 1
+            assert isinstance(events[0], TriggerEvent)
+            assert events[0].payload["message"] == "test iterator message"
+            assert events[0].payload["queue"] == "test_queue"
+
+
+class TestAzureServiceBusSubscriptionTrigger:
+    """Test the subscription trigger functionality."""
+
+    def test_init(self):
+        """Test subscription trigger initialization."""
+        topics = ["topic1", "topic2"]
+        subscription = "test-subscription"
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusSubscriptionTrigger(
+                topics=topics,
+                subscription_name=subscription,
+                azure_service_bus_conn_id="test_conn",
+            )
+
+            assert trigger.topics == topics
+            assert trigger.subscription_name == subscription
+
+    def test_serialize(self):
+        """Test serialization of subscription trigger."""
+        topics = ["topic1", "topic2"]
+        subscription = "test-subscription"
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusSubscriptionTrigger(
+                topics=topics,
+                subscription_name=subscription,
+                azure_service_bus_conn_id="test_conn",
+            )
+
+            class_path, config = trigger.serialize()
+
+            assert "AzureServiceBusSubscriptionTrigger" in class_path
+            assert config["topics"] == topics
+            assert config["subscription_name"] == subscription
+
+    @pytest.mark.asyncio
+    async def test_run_subscription_with_message(self):
+        """Test the main run method with a mock message as bytes."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusSubscriptionTrigger(
+                topics=["test_topic"],
+                subscription_name="test-sub",
+                poll_interval=0.01,  # Very short for testing
+                azure_service_bus_conn_id="test_conn",
+            )
+
+            mock_message = Mock(body=b"subscription test message")
+            trigger.message_hook.read_subscription_message = Mock(return_value=mock_message)
+
+            # Get one event from the generator
+            events = []
+            async for event in trigger.run():
+                events.append(event)
+                if len(events) >= 1:
+                    break
+
+            assert len(events) == 1
+            assert isinstance(events[0], TriggerEvent)
+            assert events[0].payload["message"] == "subscription test message"
+            assert events[0].payload["topic"] == "test_topic"
+            assert events[0].payload["subscription"] == "test-sub"
+
+    @pytest.mark.asyncio
+    async def test_run_subscription_with_iterator_message(self):
+        """Test the main run method with a mock message as an iterator."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusSubscriptionTrigger(
+                topics=["test_topic"],
+                subscription_name="test-sub",
+                poll_interval=0.01,  # Very short for testing
+                azure_service_bus_conn_id="test_conn",
+            )
+
+            mock_message = Mock(body=iter([b"iterator", b" ", b"subscription"]))
+            trigger.message_hook.read_subscription_message = Mock(return_value=mock_message)
+
+            # Get one event from the generator
+            events = []
+            async for event in trigger.run():
+                events.append(event)
+                if len(events) >= 1:
+                    break
+
+            assert len(events) == 1
+            assert isinstance(events[0], TriggerEvent)
+            assert events[0].payload["message"] == "iterator subscription"
+            assert events[0].payload["topic"] == "test_topic"
+            assert events[0].payload["subscription"] == "test-sub"
+
+
+class TestIntegrationScenarios:
+    """Test integration scenarios and edge cases."""
+
+    @pytest.mark.asyncio
+    async def test_multiple_messages_processing(self):
+        """Test processing multiple messages in sequence."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(
+                queues=["test_queue"],
+                poll_interval=0.01,  # Very short for testing
+            )
+
+            messages_as_str = ["msg1", "msg2", "msg3"]
+            mock_messages = [Mock(body=msg.encode("utf-8")) for msg in messages_as_str]
+            trigger.message_hook.read_message = Mock(side_effect=mock_messages + [None])
+
+            # Collect events
+            events = []
+            async for event in trigger.run():
+                events.append(event)
+                if len(events) >= 3:
+                    break
+
+            assert len(events) == 3
+            received_messages = [event.payload["message"] for event in events]
+            assert received_messages == messages_as_str
+
+    def test_queue_trigger_with_empty_queues_list(self):
+        """Test queue trigger with empty queues list."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusQueueTrigger(queues=[])
+            assert trigger.queues == []
+
+    def test_subscription_trigger_with_empty_topics_list(self):
+        """Test subscription trigger with empty topics list."""
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
+            trigger = AzureServiceBusSubscriptionTrigger(
+                topics=[], subscription_name="test-sub", azure_service_bus_conn_id="test_conn"
+            )
+            assert trigger.topics == []
+
+    def test_message_hook_initialization(self):
+        """Test that MessageHook is properly initialized."""
+        conn_id = "test_connection"
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook") as mock_hook_class:
+            trigger = AzureServiceBusQueueTrigger(queues=["test"], azure_service_bus_conn_id=conn_id)
+
+            # Verify the hook was initialized with the correct connection ID
+            mock_hook_class.assert_called_once_with(azure_service_bus_conn_id=conn_id)
+            # Also verify the trigger has the message_hook attribute
+            assert hasattr(trigger, "message_hook")
+
+    def test_message_hook_properly_configured(self):
+        """Test that MessageHook is properly configured with connection."""
+        conn_id = "test_connection"
+        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook") as mock_hook_class:
+            trigger = AzureServiceBusQueueTrigger(queues=["test"], azure_service_bus_conn_id=conn_id)
+
+            # Verify the hook was called with the correct parameters
+            mock_hook_class.assert_called_once_with(azure_service_bus_conn_id=conn_id)
+            assert hasattr(trigger, "message_hook")
+            # Verify the connection_id is set correctly
+            assert trigger.connection_id == conn_id

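The serialization tests above check only that the class name appears in the path and that a few keys are present. Airflow's triggerer rebuilds a deferred trigger from the `(classpath, kwargs)` pair returned by `serialize()`, so a stricter test is a round trip: rebuild from the kwargs and compare the two serializations. A minimal, stdlib-only sketch of that check; `DemoQueueTrigger` and `check_round_trip` are hypothetical, and the demo class only copies the queue trigger's constructor signature and `serialize()` shape.

```python
from __future__ import annotations

from typing import Any


class DemoQueueTrigger:
    """Hypothetical stand-in with the queue trigger's constructor signature and serialize() shape."""

    default_conn_name = "azure_service_bus_default"
    default_poll_interval = 60

    def __init__(
        self,
        queues: list[str],
        poll_interval: float | None = None,
        azure_service_bus_conn_id: str | None = None,
        max_wait_time: float | None = None,
    ) -> None:
        self.queues = queues
        self.poll_interval = poll_interval if poll_interval else self.default_poll_interval
        self.connection_id = azure_service_bus_conn_id or self.default_conn_name
        self.max_wait_time = max_wait_time

    def serialize(self) -> tuple[str, dict[str, Any]]:
        return (
            f"{type(self).__module__}.{type(self).__qualname__}",
            {
                "azure_service_bus_conn_id": self.connection_id,
                "queues": self.queues,
                "poll_interval": self.poll_interval,
                "max_wait_time": self.max_wait_time,
            },
        )


def check_round_trip(trigger: Any) -> bool:
    """Return True if serialize() output alone rebuilds an equivalent trigger."""
    classpath, kwargs = trigger.serialize()
    # The class path must name the trigger's own class ...
    if classpath != f"{type(trigger).__module__}.{type(trigger).__qualname__}":
        return False
    # ... and the kwargs must be accepted by the constructor and survive a second pass.
    restored = type(trigger)(**kwargs)
    return restored.serialize() == (classpath, kwargs)


trigger = DemoQueueTrigger(queues=["q1", "q2"], azure_service_bus_conn_id="my_asb_conn")
print(check_round_trip(trigger))  # True
```

A trigger whose `serialize()` omitted a required constructor argument, such as `subscription_name`, would make this check raise a `TypeError` when it is rebuilt, which the substring assertions would not catch.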