This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dd1095202b6 Add Azure Service Bus Queue and Subscription triggers for
async message processing (#53356)
dd1095202b6 is described below
commit dd1095202b6f8c715832e3c6425d63545e88fedc
Author: Ranuga <[email protected]>
AuthorDate: Mon Nov 24 16:12:21 2025 +0530
Add Azure Service Bus Queue and Subscription triggers for async message
processing (#53356)
* Add Azure Service Bus integration with triggers and documentation
* Implement Azure Service Bus Queue and Subscription triggers with async
message processing and unit tests
* feat(Azure): Add methods to read messages from Service Bus queue and
subscription
* Update
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
Co-authored-by: LIU ZHE YOU <[email protected]>
* Update
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
Co-authored-by: LIU ZHE YOU <[email protected]>
* Update
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
Co-authored-by: LIU ZHE YOU <[email protected]>
* Remove Queue URI format documentation for Azure Service Bus Queue and
Subscription Providers
* Enhance Azure Service Bus integration by adding new message types and
updating trigger initialization tests
* Fix Azure Service Bus message trigger test failures
- Simplified mocking to avoid connection and import issues
- Fixed mock assertions for MessageHook initialization
- Addressed compatibility issues with older Airflow versions
- Used context managers for proper mock lifecycle management
* Fix Azure Service Bus trigger test mocking issues
- Replace Mock object string representations with proper Mock(body=value)
pattern
- Fix str(message.body) calls returning mock object strings instead of
expected message content
- Apply standard Azure provider mocking conventions consistent with other
tests
- Remove unnecessary comments and custom MockMessage classes
* Update the Trigger Event Message to Include the Actual Content
* resolve review comments
* chore: change BaseTrigger to BaseEventTrigger
* Update
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
Co-authored-by: LIU ZHE YOU <[email protected]>
* chore: fix ci/cd
---------
Co-authored-by: LIU ZHE YOU <[email protected]>
---
docs/spelling_wordlist.txt | 2 +
.../azure/docs/connections/message_bus.rst | 99 +++++++
providers/microsoft/azure/provider.yaml | 3 +
.../providers/microsoft/azure/get_provider_info.py | 4 +
.../airflow/providers/microsoft/azure/hooks/asb.py | 54 ++++
.../microsoft/azure/triggers/message_bus.py | 223 ++++++++++++++++
.../microsoft/azure/triggers/test_message_bus.py | 286 +++++++++++++++++++++
7 files changed, 671 insertions(+)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 0c7db72e31d..cf0babd72e6 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1623,6 +1623,8 @@ serializer
serializers
serverless
ServiceAccount
+servicebus
+ServiceBusReceivedMessage
ServicePrincipalCredentials
ServiceResource
SES
diff --git a/providers/microsoft/azure/docs/connections/message_bus.rst
b/providers/microsoft/azure/docs/connections/message_bus.rst
new file mode 100644
index 00000000000..4611c921a3f
--- /dev/null
+++ b/providers/microsoft/azure/docs/connections/message_bus.rst
@@ -0,0 +1,99 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/trigger:azure_service_bus_queue:
+
+Microsoft Azure Service Bus Queue Trigger
+=========================================
+
+The Microsoft Azure Service Bus Queue trigger enables you to monitor Azure
Service Bus queues for new messages and trigger DAG runs when messages arrive.
+
+Authenticating to Azure
+-----------------------
+
+The trigger uses the connection you have configured in Airflow to connect to
Azure.
+Please refer to the :ref:`howto/connection:azure_service_bus` documentation
for details on how to configure your connection.
+
+Using the Trigger
+-----------------
+
+This example shows how to use the ``AzureServiceBusQueueTrigger``.
+
+.. code-block:: python
+
+ from airflow.providers.microsoft.azure.triggers.message_bus import
AzureServiceBusQueueTrigger
+
+ trigger = AzureServiceBusQueueTrigger(
+ queues=["my_queue"],
+ azure_service_bus_conn_id="azure_service_bus_default",
+ max_wait_time=10.0,
+ poll_interval=5.0,
+ )
+
+ # The trigger will fire when a message is available in the queue.
+
+.. _howto/trigger:azure_service_bus_subscription:
+
+Microsoft Azure Service Bus Subscription Trigger
+================================================
+
+The Microsoft Azure Service Bus Subscription trigger enables you to monitor
Azure Service Bus topic subscriptions for new messages and trigger DAG runs
when messages arrive.
+
+Authenticating to Azure
+-----------------------
+
+The trigger uses the connection you have configured in Airflow to connect to
Azure.
+Please refer to the :ref:`howto/connection:azure_service_bus` documentation
for details on how to configure your connection.
+
+Using the Trigger
+-----------------
+
+This example shows how to use the ``AzureServiceBusSubscriptionTrigger``.
+
+.. code-block:: python
+
+ from airflow.providers.microsoft.azure.triggers.message_bus import
AzureServiceBusSubscriptionTrigger
+
+ trigger = AzureServiceBusSubscriptionTrigger(
+ topics=["my_topic"],
+ subscription_name="my_subscription",
+ azure_service_bus_conn_id="azure_service_bus_default",
+ max_wait_time=10.0,
+ poll_interval=5.0,
+ )
+
+ # The trigger will fire when a message is available in the topic
subscription.
+
+Azure Service Bus Message Queue
+===============================
+
+Azure Service Bus Queue Provider
+--------------------------------
+
+Implemented by
:class:`~airflow.providers.microsoft.azure.triggers.message_bus.AzureServiceBusQueueTrigger`
+
+The Azure Service Bus Queue Provider is a message queue provider that uses
Azure Service Bus queues.
+It allows you to monitor Azure Service Bus queues for new messages and trigger
DAG runs when messages arrive.
+
+
+Azure Service Bus Subscription Provider
+---------------------------------------
+
+Implemented by
:class:`~airflow.providers.microsoft.azure.triggers.message_bus.AzureServiceBusSubscriptionTrigger`
+
+The Azure Service Bus Subscription Provider is a message queue provider that
uses Azure Service Bus topic subscriptions.
+It allows you to monitor Azure Service Bus topic subscriptions for new
messages and trigger DAG runs when messages arrive.
diff --git a/providers/microsoft/azure/provider.yaml
b/providers/microsoft/azure/provider.yaml
index 0a294fa6600..75a01269dbf 100644
--- a/providers/microsoft/azure/provider.yaml
+++ b/providers/microsoft/azure/provider.yaml
@@ -294,6 +294,9 @@ triggers:
- integration-name: Microsoft Power BI
python-modules:
- airflow.providers.microsoft.azure.triggers.powerbi
+ - integration-name: Microsoft Azure Service Bus
+ python-modules:
+ - airflow.providers.microsoft.azure.triggers.message_bus
transfers:
- source-integration-name: Local
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
index ff4c09f0c1b..5c9ef3ab15a 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py
@@ -281,6 +281,10 @@ def get_provider_info():
"integration-name": "Microsoft Power BI",
"python-modules":
["airflow.providers.microsoft.azure.triggers.powerbi"],
},
+ {
+ "integration-name": "Microsoft Azure Service Bus",
+ "python-modules":
["airflow.providers.microsoft.azure.triggers.message_bus"],
+ },
],
"transfers": [
{
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
index 5b91adcce62..8dd2d93578e 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
@@ -606,6 +606,60 @@ class MessageHook(BaseAzureServiceBusHook):
for msg in received_msgs:
self._process_message(msg, context, message_callback,
subscription_receiver)
def read_message(
    self,
    queue_name: str,
    max_wait_time: float | None = None,
) -> ServiceBusReceivedMessage | None:
    """
    Read a single message from a Service Bus queue without callback processing.

    The message is settled with ``complete_message`` before it is returned, so the
    caller always receives an already-completed message.

    :param queue_name: The name of the queue to read from.
    :param max_wait_time: Maximum time to wait for messages (seconds). Passed unchanged
        to ``receive_messages``.
    :return: The received message or None if no message is available.
    """
    # Each context manager is entered exactly once; the receiver must not be re-entered.
    with (
        self.get_conn() as service_bus_client,
        service_bus_client.get_queue_receiver(queue_name=queue_name) as receiver,
    ):
        received_msgs = receiver.receive_messages(max_message_count=1, max_wait_time=max_wait_time)
        if not received_msgs:
            return None
        msg = received_msgs[0]
        receiver.complete_message(msg)
        return msg
+
def read_subscription_message(
    self,
    topic_name: str,
    subscription_name: str,
    max_wait_time: float | None = None,
) -> ServiceBusReceivedMessage | None:
    """
    Read a single message from a Service Bus topic subscription without callback processing.

    The message is settled with ``complete_message`` before it is returned, so the
    caller always receives an already-completed message.

    :param topic_name: The name of the topic.
    :param subscription_name: The name of the subscription.
    :param max_wait_time: Maximum time to wait for messages (seconds). Passed unchanged
        to ``receive_messages``.
    :return: The received message or None if no message is available.
    """
    # Each context manager is entered exactly once; the receiver must not be re-entered.
    with (
        self.get_conn() as service_bus_client,
        service_bus_client.get_subscription_receiver(topic_name, subscription_name) as subscription_receiver,
    ):
        received_msgs = subscription_receiver.receive_messages(
            max_message_count=1, max_wait_time=max_wait_time
        )
        if not received_msgs:
            return None
        msg = received_msgs[0]
        subscription_receiver.complete_message(msg)
        return msg
+
def _process_message(
self,
msg: ServiceBusReceivedMessage,
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
new file mode 100644
index 00000000000..bafd1e677cf
--- /dev/null
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py
@@ -0,0 +1,223 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from abc import abstractmethod
+from collections.abc import AsyncIterator
+from typing import TYPE_CHECKING, Any
+
+from asgiref.sync import sync_to_async
+
+from airflow.providers.microsoft.azure.hooks.asb import MessageHook
+
# ``BaseEventTrigger`` only exists in Airflow 3.0+. Detect it by import instead of relying on
# ``tests_common`` (a test-support package that runtime provider code must not depend on),
# and fall back to ``BaseTrigger`` on older Airflow versions.
try:
    from airflow.triggers.base import BaseEventTrigger, TriggerEvent
except ImportError:
    from airflow.triggers.base import (  # type: ignore
        BaseTrigger as BaseEventTrigger,
        TriggerEvent,
    )
+
+if TYPE_CHECKING:
+ from azure.servicebus import ServiceBusReceivedMessage
+
+
class BaseAzureServiceBusTrigger(BaseEventTrigger):
    """
    Base trigger for Azure Service Bus message processing.

    This trigger provides common functionality for listening to Azure Service Bus
    queues and topics/subscriptions: it resolves connection and polling defaults,
    creates the ``MessageHook`` used by subclasses and decodes received message bodies.

    NOTE(review): ``max_wait_time`` is forwarded unchanged to the hook's
    ``receive_messages`` call; confirm how the SDK treats ``None`` there (it may wait
    until a message arrives, which would stall polling of any further queues/topics).

    :param poll_interval: Time interval between polling operations (seconds).
        ``None`` selects ``default_poll_interval``.
    :param azure_service_bus_conn_id: Connection ID for Azure Service Bus.
        ``None`` or an empty string selects ``default_conn_name``.
    :param max_wait_time: Maximum time to wait for messages (seconds).
        ``None`` selects ``default_max_wait_time``.
    """

    default_conn_name = "azure_service_bus_default"
    default_max_wait_time = None
    default_poll_interval = 60

    def __init__(
        self,
        poll_interval: float | None = None,
        azure_service_bus_conn_id: str | None = None,
        max_wait_time: float | None = None,
    ) -> None:
        # Let the Airflow trigger base class run its own initialisation.
        super().__init__()
        # An empty connection id is never usable, so any falsy value falls back to the default.
        self.connection_id = azure_service_bus_conn_id or self.default_conn_name
        # Compare against None (not truthiness) so explicit 0 values are honoured instead
        # of being silently replaced by the defaults.
        self.max_wait_time = max_wait_time if max_wait_time is not None else self.default_max_wait_time
        self.poll_interval = poll_interval if poll_interval is not None else self.default_poll_interval
        self.message_hook = MessageHook(azure_service_bus_conn_id=self.connection_id)

    @abstractmethod
    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger instance."""

    @abstractmethod
    def run(self) -> AsyncIterator[TriggerEvent]:
        """Run the trigger logic."""

    @classmethod
    def _get_message_body(cls, message: ServiceBusReceivedMessage) -> str:
        """
        Decode the body of a received message into a UTF-8 string.

        :param message: The received Service Bus message.
        :return: The decoded message body.
        :raises TypeError: If the body is neither bytes nor an iterable of bytes chunks.
        :raises UnicodeDecodeError: If the body bytes are not valid UTF-8.
        """
        message_body = message.body
        if isinstance(message_body, (bytes, bytearray)):
            return message_body.decode("utf-8")
        try:
            # Otherwise the body is expected to be an iterable of bytes chunks.
            return "".join(chunk.decode("utf-8") for chunk in message_body)
        except (TypeError, AttributeError) as err:
            # TypeError: the body is not iterable; AttributeError: a chunk has no ``decode``.
            # Decoding failures are deliberately not converted and propagate as-is.
            raise TypeError(
                f"Expected bytes or an iterator of bytes, but got {type(message_body).__name__}"
            ) from err
+
+
class AzureServiceBusQueueTrigger(BaseAzureServiceBusTrigger):
    """
    Poll one or more Azure Service Bus queues and emit an event per received message.

    Every polling round visits the configured queues in order. The first queue that
    returns a message produces a ``TriggerEvent`` with the decoded body and the queue
    name; the remainder of that round is skipped. Each round ends with a sleep of
    ``poll_interval`` seconds.

    Example:
    >>> trigger = AzureServiceBusQueueTrigger(
    ...     queues=["queue1", "queue2"],
    ...     azure_service_bus_conn_id="my_asb_conn",
    ...     poll_interval=30,
    ... )

    :param queues: List of queue names to monitor
    :param poll_interval: Time interval between polling operations (seconds)
    :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
    :param max_wait_time: Maximum time to wait for messages (seconds)
    """

    def __init__(
        self,
        queues: list[str],
        poll_interval: float | None = None,
        azure_service_bus_conn_id: str | None = None,
        max_wait_time: float | None = None,
    ) -> None:
        super().__init__(
            poll_interval=poll_interval,
            azure_service_bus_conn_id=azure_service_bus_conn_id,
            max_wait_time=max_wait_time,
        )
        self.queues = queues

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return this trigger's import path together with the kwargs that rebuild it."""
        trigger_cls = type(self)
        trigger_kwargs: dict[str, Any] = {
            "azure_service_bus_conn_id": self.connection_id,
            "queues": self.queues,
            "poll_interval": self.poll_interval,
            "max_wait_time": self.max_wait_time,
        }
        return f"{trigger_cls.__module__}.{trigger_cls.__qualname__}", trigger_kwargs

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Poll the configured queues forever, yielding one event per received message."""
        # The hook call is synchronous; wrap it so awaiting it does not block the event loop.
        fetch_message = sync_to_async(self.message_hook.read_message)

        while True:
            for queue_name in self.queues:
                received = await fetch_message(queue_name=queue_name, max_wait_time=self.max_wait_time)
                if not received:
                    continue
                body = BaseAzureServiceBusTrigger._get_message_body(received)
                yield TriggerEvent({"message": body, "queue": queue_name})
                # Only one message per round; the next round starts after the sleep below.
                break
            await asyncio.sleep(self.poll_interval)
+
+
class AzureServiceBusSubscriptionTrigger(BaseAzureServiceBusTrigger):
    """
    Poll a named subscription on one or more topics and emit an event per message.

    Every polling round visits the topics in order using the same subscription name.
    The first topic that returns a message produces a ``TriggerEvent`` with the decoded
    body, the topic and the subscription name; the remainder of that round is skipped.
    Each round ends with a sleep of ``poll_interval`` seconds.

    Example:
    >>> trigger = AzureServiceBusSubscriptionTrigger(
    ...     topics=["topic1", "topic2"],
    ...     subscription_name="my-subscription",
    ...     azure_service_bus_conn_id="my_asb_conn",
    ... )

    :param topics: List of topic names to monitor
    :param subscription_name: Name of the subscription to use
    :param poll_interval: Time interval between polling operations (seconds)
    :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
    :param max_wait_time: Maximum time to wait for messages (seconds)
    """

    def __init__(
        self,
        topics: list[str],
        subscription_name: str,
        poll_interval: float | None = None,
        azure_service_bus_conn_id: str | None = None,
        max_wait_time: float | None = None,
    ) -> None:
        super().__init__(
            poll_interval=poll_interval,
            azure_service_bus_conn_id=azure_service_bus_conn_id,
            max_wait_time=max_wait_time,
        )
        self.topics = topics
        self.subscription_name = subscription_name

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return this trigger's import path together with the kwargs that rebuild it."""
        trigger_cls = type(self)
        trigger_kwargs: dict[str, Any] = {
            "azure_service_bus_conn_id": self.connection_id,
            "topics": self.topics,
            "subscription_name": self.subscription_name,
            "poll_interval": self.poll_interval,
            "max_wait_time": self.max_wait_time,
        }
        return f"{trigger_cls.__module__}.{trigger_cls.__qualname__}", trigger_kwargs

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Poll the subscription on each topic forever, yielding one event per message."""
        # The hook call is synchronous; wrap it so awaiting it does not block the event loop.
        fetch_message = sync_to_async(self.message_hook.read_subscription_message)

        while True:
            for topic_name in self.topics:
                received = await fetch_message(
                    topic_name=topic_name,
                    subscription_name=self.subscription_name,
                    max_wait_time=self.max_wait_time,
                )
                if not received:
                    continue
                body = BaseAzureServiceBusTrigger._get_message_body(received)
                yield TriggerEvent(
                    {
                        "message": body,
                        "topic": topic_name,
                        "subscription": self.subscription_name,
                    }
                )
                # Only one message per round; the next round starts after the sleep below.
                break
            await asyncio.sleep(self.poll_interval)
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_message_bus.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_message_bus.py
new file mode 100644
index 00000000000..cdb99b3a423
--- /dev/null
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_message_bus.py
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from airflow.providers.microsoft.azure.triggers.message_bus import (
+ AzureServiceBusQueueTrigger,
+ AzureServiceBusSubscriptionTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+
class TestBaseAzureServiceBusTrigger:
    """Test the base trigger functionality."""

    def test_init_with_defaults(self):
        """Test initialization with default values using queue trigger."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(queues=["test_queue"])

        assert trigger.max_wait_time is None
        assert trigger.poll_interval == 60
        # No connection id given, so the default connection name must be used.
        assert trigger.connection_id == "azure_service_bus_default"
        assert hasattr(trigger, "message_hook")

    def test_init_with_custom_values(self):
        """Test initialization with custom values using queue trigger."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=["test_queue"],
                poll_interval=30,
                azure_service_bus_conn_id="custom_conn",
                max_wait_time=120,
            )

        assert trigger.poll_interval == 30
        assert trigger.max_wait_time == 120
        assert trigger.connection_id == "custom_conn"

    @pytest.mark.parametrize(
        ("body", "expected"),
        [
            pytest.param(b"plain bytes", "plain bytes", id="bytes"),
            pytest.param([b"list", b" body"], "list body", id="list-of-chunks"),
            pytest.param((b"tuple", b" body"), "tuple body", id="tuple-of-chunks"),
        ],
    )
    def test_get_message_body(self, body, expected):
        """Bytes and iterables of bytes chunks are decoded as UTF-8 text."""
        assert AzureServiceBusQueueTrigger._get_message_body(Mock(body=body)) == expected

    @pytest.mark.parametrize("body", [123, None, "text"], ids=["int", "none", "str"])
    def test_get_message_body_rejects_non_bytes(self, body):
        """Bodies that are neither bytes nor bytes chunks raise a TypeError."""
        with pytest.raises(TypeError, match="Expected bytes or an iterator of bytes"):
            AzureServiceBusQueueTrigger._get_message_body(Mock(body=body))
+
+
class TestAzureServiceBusQueueTrigger:
    """Test the queue trigger functionality."""

    def test_init(self):
        """Test queue trigger initialization."""
        queues = ["queue1", "queue2"]
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=queues,
                azure_service_bus_conn_id="test_conn",
            )

        assert trigger.queues == queues

    def test_serialize(self):
        """Test serialization of queue trigger."""
        queues = ["queue1", "queue2"]
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=queues,
                azure_service_bus_conn_id="test_conn",
            )

        class_path, config = trigger.serialize()

        assert class_path == (
            "airflow.providers.microsoft.azure.triggers.message_bus.AzureServiceBusQueueTrigger"
        )
        # The serialized kwargs must round-trip every constructor argument exactly.
        assert config == {
            "azure_service_bus_conn_id": "test_conn",
            "queues": queues,
            "poll_interval": 60,
            "max_wait_time": None,
        }

    @pytest.mark.asyncio
    async def test_run_with_message(self):
        """Test the main run method with a mock message as bytes."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=["test_queue"],
                poll_interval=0.01,  # Very short for testing
            )

        mock_message = Mock(body=b"test message")
        trigger.message_hook.read_message = Mock(return_value=mock_message)

        # Get one event from the generator
        events = []
        async for event in trigger.run():
            events.append(event)
            if len(events) >= 1:
                break

        assert len(events) == 1
        assert isinstance(events[0], TriggerEvent)
        assert events[0].payload["message"] == "test message"
        assert events[0].payload["queue"] == "test_queue"

    @pytest.mark.asyncio
    async def test_run_with_iterator_message(self):
        """Test the main run method with a mock message as an iterator."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=["test_queue"],
                poll_interval=0.01,  # Very short for testing
            )

        mock_message = Mock(body=iter([b"test", b" ", b"iterator", b" ", b"message"]))
        trigger.message_hook.read_message = Mock(return_value=mock_message)

        # Get one event from the generator
        events = []
        async for event in trigger.run():
            events.append(event)
            if len(events) >= 1:
                break

        assert len(events) == 1
        assert isinstance(events[0], TriggerEvent)
        assert events[0].payload["message"] == "test iterator message"
        assert events[0].payload["queue"] == "test_queue"

    @pytest.mark.asyncio
    async def test_run_skips_empty_queue(self):
        """An empty queue is skipped and the next configured queue is read in the same round."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=["empty_queue", "busy_queue"],
                poll_interval=0.01,  # Very short for testing
            )

        trigger.message_hook.read_message = Mock(side_effect=[None, Mock(body=b"found it")])

        events = []
        async for event in trigger.run():
            events.append(event)
            break

        assert events[0].payload == {"message": "found it", "queue": "busy_queue"}
        polled_queues = [call.kwargs["queue_name"] for call in trigger.message_hook.read_message.call_args_list]
        assert polled_queues == ["empty_queue", "busy_queue"]

    @pytest.mark.asyncio
    async def test_run_polls_again_after_empty_round(self):
        """When no message is available the trigger sleeps and polls again."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=["test_queue"],
                poll_interval=0.01,  # Very short for testing
            )

        trigger.message_hook.read_message = Mock(side_effect=[None, Mock(body=b"late message")])

        events = []
        async for event in trigger.run():
            events.append(event)
            break

        assert events[0].payload["message"] == "late message"
        assert trigger.message_hook.read_message.call_count == 2
+
+
class TestAzureServiceBusSubscriptionTrigger:
    """Test the subscription trigger functionality."""

    def test_init(self):
        """Test subscription trigger initialization."""
        topics = ["topic1", "topic2"]
        subscription = "test-subscription"
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusSubscriptionTrigger(
                topics=topics,
                subscription_name=subscription,
                azure_service_bus_conn_id="test_conn",
            )

        assert trigger.topics == topics
        assert trigger.subscription_name == subscription

    def test_serialize(self):
        """Test serialization of subscription trigger."""
        topics = ["topic1", "topic2"]
        subscription = "test-subscription"
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusSubscriptionTrigger(
                topics=topics,
                subscription_name=subscription,
                azure_service_bus_conn_id="test_conn",
            )

        class_path, config = trigger.serialize()

        assert class_path == (
            "airflow.providers.microsoft.azure.triggers.message_bus.AzureServiceBusSubscriptionTrigger"
        )
        # The serialized kwargs must round-trip every constructor argument exactly.
        assert config == {
            "azure_service_bus_conn_id": "test_conn",
            "topics": topics,
            "subscription_name": subscription,
            "poll_interval": 60,
            "max_wait_time": None,
        }

    @pytest.mark.asyncio
    async def test_run_subscription_with_message(self):
        """Test the main run method with a mock message as bytes."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusSubscriptionTrigger(
                topics=["test_topic"],
                subscription_name="test-sub",
                poll_interval=0.01,  # Very short for testing
                azure_service_bus_conn_id="test_conn",
            )

        mock_message = Mock(body=b"subscription test message")
        trigger.message_hook.read_subscription_message = Mock(return_value=mock_message)

        # Get one event from the generator
        events = []
        async for event in trigger.run():
            events.append(event)
            if len(events) >= 1:
                break

        assert len(events) == 1
        assert isinstance(events[0], TriggerEvent)
        assert events[0].payload["message"] == "subscription test message"
        assert events[0].payload["topic"] == "test_topic"
        assert events[0].payload["subscription"] == "test-sub"
        # The hook must be asked for the configured topic/subscription pair.
        trigger.message_hook.read_subscription_message.assert_called_once_with(
            topic_name="test_topic", subscription_name="test-sub", max_wait_time=None
        )

    @pytest.mark.asyncio
    async def test_run_subscription_with_iterator_message(self):
        """Test the main run method with a mock message as an iterator."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusSubscriptionTrigger(
                topics=["test_topic"],
                subscription_name="test-sub",
                poll_interval=0.01,  # Very short for testing
                azure_service_bus_conn_id="test_conn",
            )

        mock_message = Mock(body=iter([b"iterator", b" ", b"subscription"]))
        trigger.message_hook.read_subscription_message = Mock(return_value=mock_message)

        # Get one event from the generator
        events = []
        async for event in trigger.run():
            events.append(event)
            if len(events) >= 1:
                break

        assert len(events) == 1
        assert isinstance(events[0], TriggerEvent)
        assert events[0].payload["message"] == "iterator subscription"
        assert events[0].payload["topic"] == "test_topic"
        assert events[0].payload["subscription"] == "test-sub"
+
+
class TestIntegrationScenarios:
    """Test integration scenarios and edge cases."""

    @pytest.mark.asyncio
    async def test_multiple_messages_processing(self):
        """Test processing multiple messages in sequence."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(
                queues=["test_queue"],
                poll_interval=0.01,  # Very short for testing
            )

        messages_as_str = ["msg1", "msg2", "msg3"]
        mock_messages = [Mock(body=msg.encode("utf-8")) for msg in messages_as_str]
        trigger.message_hook.read_message = Mock(side_effect=mock_messages + [None])

        # Collect events
        events = []
        async for event in trigger.run():
            events.append(event)
            if len(events) >= 3:
                break

        assert len(events) == 3
        received_messages = [event.payload["message"] for event in events]
        assert received_messages == messages_as_str

    def test_queue_trigger_with_empty_queues_list(self):
        """Test queue trigger with empty queues list."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusQueueTrigger(queues=[])
        assert trigger.queues == []

    def test_subscription_trigger_with_empty_topics_list(self):
        """Test subscription trigger with empty topics list."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook"):
            trigger = AzureServiceBusSubscriptionTrigger(
                topics=[], subscription_name="test-sub", azure_service_bus_conn_id="test_conn"
            )
        assert trigger.topics == []

    def test_message_hook_initialization(self):
        """Test that MessageHook is initialized with an explicitly supplied connection ID."""
        conn_id = "test_connection"
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook") as mock_hook_class:
            trigger = AzureServiceBusQueueTrigger(queues=["test"], azure_service_bus_conn_id=conn_id)

        # Verify the hook was initialized with the correct connection ID
        mock_hook_class.assert_called_once_with(azure_service_bus_conn_id=conn_id)
        # The trigger must keep exactly the hook instance it created.
        assert trigger.message_hook is mock_hook_class.return_value
        assert trigger.connection_id == conn_id

    def test_message_hook_properly_configured(self):
        """Test that MessageHook falls back to the default connection when none is supplied."""
        with patch("airflow.providers.microsoft.azure.triggers.message_bus.MessageHook") as mock_hook_class:
            trigger = AzureServiceBusQueueTrigger(queues=["test"])

        mock_hook_class.assert_called_once_with(azure_service_bus_conn_id="azure_service_bus_default")
        assert trigger.message_hook is mock_hook_class.return_value
        assert trigger.connection_id == "azure_service_bus_default"