jason810496 commented on code in PR #53356:
URL: https://github.com/apache/airflow/pull/53356#discussion_r2337292918
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py:
##########
@@ -606,6 +606,74 @@ def receive_subscription_message(
for msg in received_msgs:
self._process_message(msg, context, message_callback,
subscription_receiver)
+ def read_message(
+ self,
+ queue_name: str,
+ max_message_count: int | None = 1,
+ max_wait_time: float | None = None,
+ ) -> ServiceBusReceivedMessage | None:
+ """
+ Read a single message from a Service Bus queue without callback
processing.
+
+ :param queue_name: The name of the queue to read from.
+ :param max_message_count: Maximum number of messages to retrieve
(defaults to 1).
+ :param max_wait_time: Maximum time to wait for messages (seconds).
+ :return: The received message or None if no message is available.
+ """
+ if queue_name is None:
+ raise TypeError("Queue name cannot be None.")
+
+ with (
+ self.get_conn() as service_bus_client,
+ service_bus_client.get_queue_receiver(queue_name=queue_name) as
receiver,
+ receiver,
+ ):
+ received_msgs = receiver.receive_messages(
+ max_message_count=max_message_count or 1,
max_wait_time=max_wait_time
+ )
+ if received_msgs:
+ msg = received_msgs[0]
+ receiver.complete_message(msg)
+ return msg
Review Comment:
I'm not sure I understand this correctly. Does this mean that even if we set
`max_message_count` to more than `1`, we will intentionally complete only `1`
message?
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py:
##########
@@ -606,6 +606,74 @@ def receive_subscription_message(
for msg in received_msgs:
self._process_message(msg, context, message_callback,
subscription_receiver)
+ def read_message(
+ self,
+ queue_name: str,
+ max_message_count: int | None = 1,
+ max_wait_time: float | None = None,
+ ) -> ServiceBusReceivedMessage | None:
+ """
+ Read a single message from a Service Bus queue without callback
processing.
+
+ :param queue_name: The name of the queue to read from.
+ :param max_message_count: Maximum number of messages to retrieve
(defaults to 1).
+ :param max_wait_time: Maximum time to wait for messages (seconds).
+ :return: The received message or None if no message is available.
+ """
+ if queue_name is None:
+ raise TypeError("Queue name cannot be None.")
+
+ with (
+ self.get_conn() as service_bus_client,
+ service_bus_client.get_queue_receiver(queue_name=queue_name) as
receiver,
+ receiver,
+ ):
+ received_msgs = receiver.receive_messages(
+ max_message_count=max_message_count or 1,
max_wait_time=max_wait_time
+ )
+ if received_msgs:
+ msg = received_msgs[0]
+ receiver.complete_message(msg)
+ return msg
+ return None
+
+ def read_subscription_message(
+ self,
+ topic_name: str,
+ subscription_name: str,
+ max_message_count: int | None = 1,
+ max_wait_time: float | None = None,
+ ) -> ServiceBusReceivedMessage | None:
+ """
+ Read a single message from a Service Bus topic subscription without
callback processing.
+
+ :param topic_name: The name of the topic.
+ :param subscription_name: The name of the subscription.
+ :param max_message_count: Maximum number of messages to retrieve
(defaults to 1).
+ :param max_wait_time: Maximum time to wait for messages (seconds).
+ :return: The received message or None if no message is available.
+ """
+ if subscription_name is None:
+ raise TypeError("Subscription name cannot be None.")
+ if topic_name is None:
+ raise TypeError("Topic name cannot be None.")
+
+ with (
+ self.get_conn() as service_bus_client,
+ service_bus_client.get_subscription_receiver(
+ topic_name, subscription_name
+ ) as subscription_receiver,
+ subscription_receiver,
+ ):
+ received_msgs = subscription_receiver.receive_messages(
+ max_message_count=max_message_count or 1,
max_wait_time=max_wait_time
+ )
+ if received_msgs:
+ msg = received_msgs[0]
+ subscription_receiver.complete_message(msg)
+ return msg
+ return None
Review Comment:
Same question here.
##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/message_bus.py:
##########
@@ -0,0 +1,196 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from abc import abstractmethod
+from collections.abc import AsyncIterator
+from typing import Any
+
+from asgiref.sync import sync_to_async
+
+from airflow.providers.microsoft.azure.hooks.asb import MessageHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BaseAzureServiceBusTrigger(BaseTrigger):
+ """
+ Base trigger for Azure Service Bus message processing.
+
+ This trigger provides common functionality for listening to Azure Service
Bus
+ queues and topics/subscriptions. It handles connection management and
+ async message processing.
+
+ :param poll_interval: Time interval between polling operations (seconds)
+ :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
+ :param max_wait_time: Maximum time to wait for messages (seconds)
+ """
+
+ default_conn_name = "azure_service_bus_default"
+ default_max_wait_time = None
+ default_poll_interval = 60
+
+ def __init__(
+ self,
+ poll_interval: float | None = None,
+ azure_service_bus_conn_id: str | None = None,
+ max_wait_time: float | None = None,
+ ) -> None:
+ self.connection_id = (
+ azure_service_bus_conn_id
+ if azure_service_bus_conn_id
+ else BaseAzureServiceBusTrigger.default_conn_name
+ )
+ self.max_wait_time = (
+ max_wait_time if max_wait_time else
BaseAzureServiceBusTrigger.default_max_wait_time
+ )
+ self.poll_interval = (
+ poll_interval if poll_interval else
BaseAzureServiceBusTrigger.default_poll_interval
+ )
+ self.message_hook =
MessageHook(azure_service_bus_conn_id=self.connection_id)
+
+ @abstractmethod
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize the trigger instance."""
+
+ @abstractmethod
+ def run(self) -> AsyncIterator[TriggerEvent]:
+ """Run the trigger logic."""
+
+
+class AzureServiceBusQueueTrigger(BaseAzureServiceBusTrigger):
+ """
+ Trigger for Azure Service Bus Queue message processing.
+
+ This trigger monitors one or more Azure Service Bus queues for incoming
messages.
+ When messages arrive, they are processed and yielded as trigger events
that can
+ be consumed by downstream tasks.
+
+ Example:
+ >>> trigger = AzureServiceBusQueueTrigger(
+ ... queues=["queue1", "queue2"],
+ ... azure_service_bus_conn_id="my_asb_conn",
+ ... poll_interval=30,
+ ... )
+
+ :param queues: List of queue names to monitor
+ :param poll_interval: Time interval between polling operations (seconds)
+ :param azure_service_bus_conn_id: Connection ID for Azure Service Bus
+ :param max_wait_time: Maximum time to wait for messages (seconds)
+ """
+
+ def __init__(
+ self,
+ queues: list[str],
+ poll_interval: float | None = None,
+ azure_service_bus_conn_id: str | None = None,
+ max_wait_time: float | None = None,
+ ) -> None:
+ super().__init__(poll_interval, azure_service_bus_conn_id,
max_wait_time)
+ self.queues = queues
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "azure_service_bus_conn_id": self.connection_id,
+ "queues": self.queues,
+ "poll_interval": self.poll_interval,
+ "max_wait_time": self.max_wait_time,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ read_queue_message_async =
sync_to_async(self.message_hook.read_message)
+
+ while True:
+ for queue_name in self.queues:
+ message = await read_queue_message_async(
+ queue_name=queue_name, max_wait_time=self.max_wait_time
+ )
+ if message:
+ yield TriggerEvent({"message": str(message.body), "queue":
queue_name})
Review Comment:
Will we be able to deserialize `str(message.body)` back into the original message body?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]