This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 93ad181684c Add context to Azure Service Bus Message callback (#43370)
93ad181684c is described below

commit 93ad181684c8314e199f2521f487cd4292e06b5c
Author: perry2of5 <[email protected]>
AuthorDate: Sat Oct 26 14:32:54 2024 -0700

    Add context to Azure Service Bus Message callback (#43370)
    
    * Add context to Azure Service Bus Message callback
    
    The original callback only took the message as a parameter. However,
    users may want to push status or location information into XComs and
    so callbacks need access to the context (or the XComs, but context is
    more general). This commit changes the code to pass the context to the
    callback as a second argument.
    
    NOTE: This is a BREAKING CHANGE.
    
    Fixes #43361
    
    * Add breaking change note to CHANGELOG
---
 .../providers/microsoft/azure/CHANGELOG.rst        |  6 ++++++
 .../airflow/providers/microsoft/azure/hooks/asb.py | 16 ++++++++++------
 .../providers/microsoft/azure/operators/asb.py     |  4 +++-
 providers/tests/microsoft/azure/hooks/test_asb.py  | 22 ++++++++++++++++------
 .../tests/microsoft/azure/operators/test_asb.py    | 14 +++++++++-----
 5 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst 
b/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
index 148fce2e115..457abcd9764 100644
--- a/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
+++ b/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
@@ -27,6 +27,12 @@
 Changelog
 ---------
 
+Breaking changes
+~~~~~~~~~~~~~~~~
+.. warning::
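+   * We changed the message callback for Azure Service Bus messages to take 
two parameters, the message and the context, rather than just the message. This 
allows pushing message information into XComs. To upgrade from the previous 
version, which only took the message, please update your callback to take the 
context as a second parameter. In addition, ``MessageHook.receive_message`` and 
``MessageHook.receive_subscription_message`` now require a ``context`` argument, 
passed positionally after the queue name or after the topic and subscription names.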
+
+
 10.5.1
 ......
 
diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py 
b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
index 317447d1117..1dafe3c7f3c 100644
--- a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
@@ -34,12 +34,13 @@ from airflow.providers.microsoft.azure.utils import (
     get_sync_default_azure_credential,
 )
 
-MessageCallback = Callable[[ServiceBusMessage], None]
-
-
 if TYPE_CHECKING:
     from azure.identity import DefaultAzureCredential
 
+    from airflow.utils.context import Context
+
+    MessageCallback = Callable[[ServiceBusMessage, Context], None]
+
 
 class BaseAzureServiceBusHook(BaseHook):
     """
@@ -283,6 +284,7 @@ class MessageHook(BaseAzureServiceBusHook):
     def receive_message(
         self,
         queue_name: str,
+        context: Context,
         max_message_count: int | None = 1,
         max_wait_time: float | None = None,
         message_callback: MessageCallback | None = None,
@@ -309,12 +311,13 @@ class MessageHook(BaseAzureServiceBusHook):
                 max_message_count=max_message_count, 
max_wait_time=max_wait_time
             )
             for msg in received_msgs:
-                self._process_message(msg, message_callback, receiver)
+                self._process_message(msg, context, message_callback, receiver)
 
     def receive_subscription_message(
         self,
         topic_name: str,
         subscription_name: str,
+        context: Context,
         max_message_count: int | None,
         max_wait_time: float | None,
         message_callback: MessageCallback | None = None,
@@ -350,11 +353,12 @@ class MessageHook(BaseAzureServiceBusHook):
                 max_message_count=max_message_count, 
max_wait_time=max_wait_time
             )
             for msg in received_msgs:
-                self._process_message(msg, message_callback, 
subscription_receiver)
+                self._process_message(msg, context, message_callback, 
subscription_receiver)
 
     def _process_message(
         self,
         msg: ServiceBusReceivedMessage,
+        context: Context,
         message_callback: MessageCallback | None,
         receiver: ServiceBusReceiver,
     ):
@@ -372,7 +376,7 @@ class MessageHook(BaseAzureServiceBusHook):
             receiver.complete_message(msg)
         else:
             try:
-                message_callback(msg)
+                message_callback(msg, context)
             except Exception as e:
                 self.log.error("Error processing message: %s", e)
                 receiver.abandon_message(msg)
diff --git a/providers/src/airflow/providers/microsoft/azure/operators/asb.py 
b/providers/src/airflow/providers/microsoft/azure/operators/asb.py
index 85619526cfb..7d6bab0d625 100644
--- a/providers/src/airflow/providers/microsoft/azure/operators/asb.py
+++ b/providers/src/airflow/providers/microsoft/azure/operators/asb.py
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
 
     from airflow.utils.context import Context
 
-    MessageCallback = Callable[[ServiceBusMessage], None]
+    MessageCallback = Callable[[ServiceBusMessage, Context], None]
 
 
 class AzureServiceBusCreateQueueOperator(BaseOperator):
@@ -176,6 +176,7 @@ class AzureServiceBusReceiveMessageOperator(BaseOperator):
         # Receive message
         hook.receive_message(
             self.queue_name,
+            context,
             max_message_count=self.max_message_count,
             max_wait_time=self.max_wait_time,
             message_callback=self.message_callback,
@@ -562,6 +563,7 @@ class ASBReceiveSubscriptionMessageOperator(BaseOperator):
         hook.receive_subscription_message(
             self.topic_name,
             self.subscription_name,
+            context,
             self.max_message_count,
             self.max_wait_time,
             message_callback=self.message_callback,
diff --git a/providers/tests/microsoft/azure/hooks/test_asb.py 
b/providers/tests/microsoft/azure/hooks/test_asb.py
index 83e04833bf0..6d090e6653d 100644
--- a/providers/tests/microsoft/azure/hooks/test_asb.py
+++ b/providers/tests/microsoft/azure/hooks/test_asb.py
@@ -33,6 +33,7 @@ except ImportError:
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, 
MessageHook
+from airflow.utils.context import Context
 
 MESSAGE = "Test Message"
 MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)]
@@ -256,7 +257,7 @@ class TestMessageHook:
         
mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value
 = [
             mock_service_bus_message
         ]
-        hook.receive_message(self.queue_name)
+        hook.receive_message(self.queue_name, Context())
         expected_calls = [
             mock.call()
             .__enter__()
@@ -285,12 +286,13 @@ class TestMessageHook:
 
         received_messages = []
 
-        def message_callback(msg: Any) -> None:
+        def message_callback(msg: Any, context: Context) -> None:
             nonlocal received_messages
             print("received message:", msg)
+            assert context is not None
             received_messages.append(msg)
 
-        hook.receive_message(self.queue_name, 
message_callback=message_callback)
+        hook.receive_message(self.queue_name, Context(), 
message_callback=message_callback)
 
         assert len(received_messages) == 1
         assert received_messages[0] == mock_service_bus_message
@@ -316,7 +318,9 @@ class TestMessageHook:
         max_message_count = 10
         max_wait_time = 5
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
-        hook.receive_subscription_message(topic_name, subscription_name, 
max_message_count, max_wait_time)
+        hook.receive_subscription_message(
+            topic_name, subscription_name, Context(), max_message_count, 
max_wait_time
+        )
         expected_calls = [
             mock.call()
             .__enter__()
@@ -350,13 +354,19 @@ class TestMessageHook:
 
         received_messages = []
 
-        def message_callback(msg: ServiceBusMessage) -> None:
+        def message_callback(msg: ServiceBusMessage, context: Context) -> None:
             nonlocal received_messages
             print("received message:", msg)
+            assert context is not None
             received_messages.append(msg)
 
         hook.receive_subscription_message(
-            topic_name, subscription_name, max_message_count, max_wait_time, 
message_callback=message_callback
+            topic_name,
+            subscription_name,
+            Context(),
+            max_message_count,
+            max_wait_time,
+            message_callback=message_callback,
         )
 
         assert len(received_messages) == 2
diff --git a/providers/tests/microsoft/azure/operators/test_asb.py 
b/providers/tests/microsoft/azure/operators/test_asb.py
index 42b770095b4..7e0c953890c 100644
--- a/providers/tests/microsoft/azure/operators/test_asb.py
+++ b/providers/tests/microsoft/azure/operators/test_asb.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
+from azure.servicebus import ServiceBusMessage
 
 try:
     from azure.servicebus import ServiceBusMessage
@@ -37,6 +38,7 @@ from airflow.providers.microsoft.azure.operators.asb import (
     AzureServiceBusTopicDeleteOperator,
     AzureServiceBusUpdateSubscriptionOperator,
 )
+from airflow.utils.context import Context
 
 QUEUE_NAME = "test_queue"
 MESSAGE = "Test Message"
@@ -216,21 +218,22 @@ class TestAzureServiceBusReceiveMessageOperator:
         Test AzureServiceBusReceiveMessageOperator by mock connection, values
         and the service bus receive message
         """
-        mock_service_bus_message = ServiceBusMessage("Test message")
+        mock_service_bus_message = ServiceBusMessage("Test message with 
context")
         
mock_get_conn.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value
 = [
             mock_service_bus_message
         ]
 
         messages_received = []
 
-        def message_callback(msg):
+        def message_callback(msg: ServiceBusMessage, context: Context):
             messages_received.append(msg)
+            assert context is not None
             print(msg)
 
         asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator(
             task_id="asb_receive_message_queue", queue_name=QUEUE_NAME, 
message_callback=message_callback
         )
-        asb_receive_queue_operator.execute(None)
+        asb_receive_queue_operator.execute(Context())
         assert len(messages_received) == 1
         assert messages_received[0] == mock_service_bus_message
 
@@ -470,8 +473,9 @@ class TestASBSubscriptionReceiveMessageOperator:
 
         messages_received = []
 
-        def message_callback(msg):
+        def message_callback(msg: ServiceBusMessage, context: Context):
             messages_received.append(msg)
+            assert context is not None
             print(msg)
 
         asb_subscription_receive_message = 
ASBReceiveSubscriptionMessageOperator(
@@ -482,7 +486,7 @@ class TestASBSubscriptionReceiveMessageOperator:
             message_callback=message_callback,
         )
 
-        asb_subscription_receive_message.execute(None)
+        asb_subscription_receive_message.execute(Context())
         expected_calls = [
             mock.call()
             .__enter__()

Reply via email to