This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 88c20aeb5bb add message_id, reply_to, and message_headers to send 
message operator (#47522)
88c20aeb5bb is described below

commit 88c20aeb5bbebb0c4363e73aec4085595d72ab25
Author: perry2of5 <[email protected]>
AuthorDate: Sun Mar 30 20:30:33 2025 -0700

    add message_id, reply_to, and message_headers to send message operator 
(#47522)
---
 .../airflow/providers/microsoft/azure/hooks/asb.py | 51 +++++++++++-----
 .../providers/microsoft/azure/operators/asb.py     | 16 ++++-
 .../tests/unit/microsoft/azure/hooks/test_asb.py   | 37 ++++++++++++
 .../unit/microsoft/azure/operators/test_asb.py     | 68 +++++++++++++++-------
 4 files changed, 135 insertions(+), 37 deletions(-)

diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
index d5cb666411e..c98c99f928e 100644
--- 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/asb.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Callable
-from uuid import uuid4
+from uuid import UUID, uuid4
 
 from azure.core.exceptions import ResourceNotFoundError
 from azure.servicebus import (
@@ -468,7 +468,15 @@ class MessageHook(BaseAzureServiceBusHook):
         self.log.info("Create and returns ServiceBusClient")
         return client
 
-    def send_message(self, queue_name: str, messages: str | list[str], 
batch_message_flag: bool = False):
+    def send_message(
+        self,
+        queue_name: str,
+        messages: str | list[str],
+        batch_message_flag: bool = False,
+        message_id: str | None = None,
+        reply_to: str | None = None,
+        message_headers: dict[str | bytes, int | float | bytes | bool | str | 
UUID] | None = None,
+    ):
         """
         Use ServiceBusClient Send to send message(s) to a Service Bus Queue.
 
@@ -478,38 +486,49 @@ class MessageHook(BaseAzureServiceBusHook):
         :param messages: Message which needs to be sent to the queue. It can 
be string or list of string.
         :param batch_message_flag: bool flag, can be set to True if message 
needs to be
             sent as batch message.
+        :param message_id: Message ID to set on message being sent to the 
queue. Please note, message_id may only be
+            set when a single message is sent.
+        :param reply_to: Name of the queue or topic the receiver should reply to.
+        :param message_headers: Headers to add to the message's 
application_properties field for Azure Service Bus.
         """
         if queue_name is None:
             raise TypeError("Queue name cannot be None.")
         if not messages:
             raise ValueError("Messages list cannot be empty.")
+        if message_id and isinstance(messages, list) and len(messages) != 1:
+            raise TypeError("Message ID can only be set if a single message is 
sent.")
         with (
             self.get_conn() as service_bus_client,
             service_bus_client.get_queue_sender(queue_name=queue_name) as 
sender,
             sender,
         ):
-            if isinstance(messages, str):
-                if not batch_message_flag:
-                    msg = ServiceBusMessage(messages)
-                    sender.send_messages(msg)
-                else:
-                    self.send_batch_message(sender, [messages])
+            message_creator = lambda msg_body: ServiceBusMessage(
+                msg_body, message_id=message_id, reply_to=reply_to, 
application_properties=message_headers
+            )
+            message_list = [messages] if isinstance(messages, str) else 
messages
+            if not batch_message_flag:
+                self.send_list_messages(sender, message_list, message_creator)
             else:
-                if not batch_message_flag:
-                    self.send_list_messages(sender, messages)
-                else:
-                    self.send_batch_message(sender, messages)
+                self.send_batch_message(sender, message_list, message_creator)
 
     @staticmethod
-    def send_list_messages(sender: ServiceBusSender, messages: list[str]):
-        list_messages = [ServiceBusMessage(message) for message in messages]
+    def send_list_messages(
+        sender: ServiceBusSender,
+        messages: list[str],
+        message_creator: Callable[[str], ServiceBusMessage],
+    ):
+        list_messages = [message_creator(body) for body in messages]
         sender.send_messages(list_messages)  # type: ignore[arg-type]
 
     @staticmethod
-    def send_batch_message(sender: ServiceBusSender, messages: list[str]):
+    def send_batch_message(
+        sender: ServiceBusSender,
+        messages: list[str],
+        message_creator: Callable[[str], ServiceBusMessage],
+    ):
         batch_message = sender.create_message_batch()
         for message in messages:
-            batch_message.add_message(ServiceBusMessage(message))
+            batch_message.add_message(message_creator(message))
         sender.send_messages(batch_message)
 
     def receive_message(
diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/asb.py
 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/asb.py
index aa8eecb3f24..dfd0147c9c8 100644
--- 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/asb.py
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/asb.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Callable
+from uuid import UUID
 
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, 
MessageHook
@@ -100,6 +101,11 @@ class AzureServiceBusSendMessageOperator(BaseOperator):
         as batch message it can be set to True.
     :param azure_service_bus_conn_id: Reference to the
         :ref: `Azure Service Bus 
connection<howto/connection:azure_service_bus>`.
+    :param message_id: Message ID to set on message being sent to the queue. 
Please note, message_id may only be
+        set when a single message is sent.
+    :param reply_to: Name of queue or topic the receiver should reply to. 
Determination of if the reply will be sent to
+        a queue or a topic should be made out-of-band.
+    :param message_headers: Headers to add to the message's 
application_properties field for Azure Service Bus.
     """
 
     template_fields: Sequence[str] = ("queue_name",)
@@ -112,6 +118,9 @@ class AzureServiceBusSendMessageOperator(BaseOperator):
         message: str | list[str],
         batch: bool = False,
         azure_service_bus_conn_id: str = "azure_service_bus_default",
+        message_id: str | None = None,
+        reply_to: str | None = None,
+        message_headers: dict[str | bytes, int | float | bytes | bool | str | 
UUID] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -119,6 +128,9 @@ class AzureServiceBusSendMessageOperator(BaseOperator):
         self.batch = batch
         self.message = message
         self.azure_service_bus_conn_id = azure_service_bus_conn_id
+        self.message_id = message_id
+        self.reply_to = reply_to
+        self.message_headers = message_headers
 
     def execute(self, context: Context) -> None:
         """Send Message to the specific queue in Service Bus namespace."""
@@ -126,7 +138,9 @@ class AzureServiceBusSendMessageOperator(BaseOperator):
         hook = 
MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
 
         # send message
-        hook.send_message(self.queue_name, self.message, self.batch)
+        hook.send_message(
+            self.queue_name, self.message, self.batch, self.message_id, 
self.reply_to, self.message_headers
+        )
 
 
 class AzureServiceBusReceiveMessageOperator(BaseOperator):
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_asb.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_asb.py
index 7953facfdb8..a55b4b997fb 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_asb.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_asb.py
@@ -337,6 +337,43 @@ class TestMessageHook:
         ]
         mock_sb_client.assert_has_calls(expected_calls, any_order=False)
 
+    @mock.patch(f"{MODULE}.MessageHook.get_conn", autospec=True)
+    @mock.patch("azure.servicebus.ServiceBusSender", autospec=True)
+    def test_send_message_with_id_reply_to_and_headers(self, mock_q_sender, 
mock_sb_client):
+        """
+        Test `send_message` hook function with message_id, reply_to and message_headers set,
+        mock the azure service bus `send_messages` function and check the fields of the sent message
+        """
+        sent_messages = []
+
+        def mock_send_messages(messages):
+            nonlocal sent_messages
+            sent_messages.extend(messages)
+
+        
mock_sb_client.return_value.__enter__.return_value.get_queue_sender.return_value.__enter__.return_value
 = mock_q_sender
+        mock_q_sender.send_messages.side_effect = mock_send_messages
+
+        MSG_ID = "test_msg_id"
+        REPLY_TO = "test_reply_to"
+        HEADERS = {"test-key": "test-value"}
+        hook = 
MessageHook(azure_service_bus_conn_id="azure_service_bus_default")
+        hook.send_message(
+            queue_name="test_queue",
+            messages=MESSAGE,
+            batch_message_flag=False,
+            message_id=MSG_ID,
+            reply_to=REPLY_TO,
+            message_headers=HEADERS,
+        )
+
+        mock_q_sender.send_messages.assert_called_once()
+
+        assert len(sent_messages) == 1
+        assert str(sent_messages[0]) == MESSAGE
+        assert sent_messages[0].message_id == MSG_ID
+        assert sent_messages[0].reply_to == REPLY_TO
+        assert sent_messages[0].application_properties == HEADERS
+
     @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_send_message_exception(self, mock_sb_client):
         """
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_asb.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_asb.py
index 667f4097198..8cbca431cc4 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_asb.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_asb.py
@@ -131,52 +131,80 @@ class TestAzureServiceBusDeleteQueueOperator:
 
 class TestAzureServiceBusSendMessageOperator:
     @pytest.mark.parametrize(
-        "mock_message, mock_batch_flag",
+        "mock_message, mock_batch_flag, mock_message_id, mock_reply_to, 
mock_headers",
         [
-            (MESSAGE, True),
-            (MESSAGE, False),
-            (MESSAGE_LIST, True),
-            (MESSAGE_LIST, False),
+            (MESSAGE, True, None, None, None),
+            (MESSAGE, False, "test_message_id", "test_reply_to", 
{"test_header": "test_value"}),
+            (MESSAGE_LIST, True, None, None, None),
+            (MESSAGE_LIST, False, None, None, None),
         ],
     )
-    def test_init(self, mock_message, mock_batch_flag):
+    def test_init(self, mock_message, mock_batch_flag, mock_message_id, 
mock_reply_to, mock_headers):
         """
         Test init by creating AzureServiceBusSendMessageOperator with task id, 
queue_name, message,
-        batch and asserting with values
+        batch, message_id, reply_to, and message headers and asserting with 
values
         """
         asb_send_message_queue_operator = AzureServiceBusSendMessageOperator(
             task_id="asb_send_message_queue_without_batch",
             queue_name=QUEUE_NAME,
             message=mock_message,
             batch=mock_batch_flag,
+            message_id=mock_message_id,
+            reply_to=mock_reply_to,
+            message_headers=mock_headers,
         )
         assert asb_send_message_queue_operator.task_id == 
"asb_send_message_queue_without_batch"
         assert asb_send_message_queue_operator.queue_name == QUEUE_NAME
         assert asb_send_message_queue_operator.message == mock_message
         assert asb_send_message_queue_operator.batch is mock_batch_flag
+        assert asb_send_message_queue_operator.message_id == mock_message_id
+        assert asb_send_message_queue_operator.reply_to == mock_reply_to
+        assert asb_send_message_queue_operator.message_headers == mock_headers
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
-    def test_send_message_queue(self, mock_get_conn):
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_message")
+    def test_send_message_queue(self, mock_send_message):
         """
         Test AzureServiceBusSendMessageOperator with queue name, batch boolean 
flag, mock
         the send_messages of azure service bus function
         """
+        TASK_ID = "task-id"
+        MSG_BODY = "test message body"
+        MSG_ID = None
+        REPLY_TO = None
+        HDRS = None
         asb_send_message_queue_operator = AzureServiceBusSendMessageOperator(
-            task_id="asb_send_message_queue",
+            task_id=TASK_ID,
             queue_name=QUEUE_NAME,
-            message="Test message",
+            message=MSG_BODY,
             batch=False,
         )
         asb_send_message_queue_operator.execute(None)
-        expected_calls = [
-            mock.call()
-            .__enter__()
-            .get_queue_sender(QUEUE_NAME)
-            .__enter__()
-            .send_messages(ServiceBusMessage("Test message"))
-            .__exit__()
-        ]
-        mock_get_conn.assert_has_calls(expected_calls, any_order=False)
+        expected_calls = [mock.call(QUEUE_NAME, MSG_BODY, False, MSG_ID, 
REPLY_TO, HDRS)]
+        mock_send_message.assert_has_calls(expected_calls, any_order=False)
+
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_message")
+    def test_send_message_queue_with_id_hdrs_and_reply_to(self, 
mock_send_message):
+        """
+        Test AzureServiceBusSendMessageOperator with message_id, reply_to and message_headers set,
+        mock `MessageHook.send_message` and check the values are passed through to the hook
+        """
+        TASK_ID = "task-id"
+        MSG_ID = "test_message_id"
+        MSG_BODY = "test message body"
+        REPLY_TO = "test_reply_to"
+        HDRS = {"test_header": "test_value"}
+        asb_send_message_queue_operator = AzureServiceBusSendMessageOperator(
+            task_id=TASK_ID,
+            queue_name=QUEUE_NAME,
+            message=MSG_BODY,
+            batch=False,
+            message_id=MSG_ID,
+            reply_to=REPLY_TO,
+            message_headers=HDRS,
+        )
+        asb_send_message_queue_operator.execute(None)
+        expected_calls = [mock.call(QUEUE_NAME, MSG_BODY, False, MSG_ID, 
REPLY_TO, HDRS)]
+        mock_send_message.assert_has_calls(expected_calls, any_order=False)
 
 
 class TestAzureServiceBusReceiveMessageOperator:

Reply via email to