This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 93ad181684c Add context to Azure Service Bus Message callback (#43370)
93ad181684c is described below
commit 93ad181684c8314e199f2521f487cd4292e06b5c
Author: perry2of5 <[email protected]>
AuthorDate: Sat Oct 26 14:32:54 2024 -0700
Add context to Azure Service Bus Message callback (#43370)
* Add context to Azure Service Bus Message callback
The original callback only took the message as a parameter. However,
users may want to push status or location information into XComs and
so callbacks need access to the context (or the XComs, but context is
more general). This commit changes the code to pass the context to the callback.
NOTE: This is a BREAKING CHANGE.
Fixes #43361
* Add breaking change note to CHANGELOG
---
.../providers/microsoft/azure/CHANGELOG.rst | 6 ++++++
.../airflow/providers/microsoft/azure/hooks/asb.py | 16 ++++++++++------
.../providers/microsoft/azure/operators/asb.py | 4 +++-
providers/tests/microsoft/azure/hooks/test_asb.py | 22 ++++++++++++++++------
.../tests/microsoft/azure/operators/test_asb.py | 14 +++++++++-----
5 files changed, 44 insertions(+), 18 deletions(-)
diff --git a/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
b/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
index 148fce2e115..457abcd9764 100644
--- a/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
+++ b/providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
@@ -27,6 +27,12 @@
Changelog
---------
+Breaking changes
+~~~~~~~~~~~~~~~~
+.. warning::
+ * We changed the message callback for Azure Service Bus messages to take
two parameters, the message and the context, rather than just the message. This
allows pushing message information into XComs. To upgrade from the previous
version, which only took the message, please update your callback to take the
context as a second parameter.
+
+
10.5.1
......
diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
index 317447d1117..1dafe3c7f3c 100644
--- a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py
@@ -34,12 +34,13 @@ from airflow.providers.microsoft.azure.utils import (
get_sync_default_azure_credential,
)
-MessageCallback = Callable[[ServiceBusMessage], None]
-
-
if TYPE_CHECKING:
from azure.identity import DefaultAzureCredential
+ from airflow.utils.context import Context
+
+ MessageCallback = Callable[[ServiceBusMessage, Context], None]
+
class BaseAzureServiceBusHook(BaseHook):
"""
@@ -283,6 +284,7 @@ class MessageHook(BaseAzureServiceBusHook):
def receive_message(
self,
queue_name: str,
+ context: Context,
max_message_count: int | None = 1,
max_wait_time: float | None = None,
message_callback: MessageCallback | None = None,
@@ -309,12 +311,13 @@ class MessageHook(BaseAzureServiceBusHook):
max_message_count=max_message_count,
max_wait_time=max_wait_time
)
for msg in received_msgs:
- self._process_message(msg, message_callback, receiver)
+ self._process_message(msg, context, message_callback, receiver)
def receive_subscription_message(
self,
topic_name: str,
subscription_name: str,
+ context: Context,
max_message_count: int | None,
max_wait_time: float | None,
message_callback: MessageCallback | None = None,
@@ -350,11 +353,12 @@ class MessageHook(BaseAzureServiceBusHook):
max_message_count=max_message_count,
max_wait_time=max_wait_time
)
for msg in received_msgs:
- self._process_message(msg, message_callback,
subscription_receiver)
+ self._process_message(msg, context, message_callback,
subscription_receiver)
def _process_message(
self,
msg: ServiceBusReceivedMessage,
+ context: Context,
message_callback: MessageCallback | None,
receiver: ServiceBusReceiver,
):
@@ -372,7 +376,7 @@ class MessageHook(BaseAzureServiceBusHook):
receiver.complete_message(msg)
else:
try:
- message_callback(msg)
+ message_callback(msg, context)
except Exception as e:
self.log.error("Error processing message: %s", e)
receiver.abandon_message(msg)
diff --git a/providers/src/airflow/providers/microsoft/azure/operators/asb.py
b/providers/src/airflow/providers/microsoft/azure/operators/asb.py
index 85619526cfb..7d6bab0d625 100644
--- a/providers/src/airflow/providers/microsoft/azure/operators/asb.py
+++ b/providers/src/airflow/providers/microsoft/azure/operators/asb.py
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
- MessageCallback = Callable[[ServiceBusMessage], None]
+ MessageCallback = Callable[[ServiceBusMessage, Context], None]
class AzureServiceBusCreateQueueOperator(BaseOperator):
@@ -176,6 +176,7 @@ class AzureServiceBusReceiveMessageOperator(BaseOperator):
# Receive message
hook.receive_message(
self.queue_name,
+ context,
max_message_count=self.max_message_count,
max_wait_time=self.max_wait_time,
message_callback=self.message_callback,
@@ -562,6 +563,7 @@ class ASBReceiveSubscriptionMessageOperator(BaseOperator):
hook.receive_subscription_message(
self.topic_name,
self.subscription_name,
+ context,
self.max_message_count,
self.max_wait_time,
message_callback=self.message_callback,
diff --git a/providers/tests/microsoft/azure/hooks/test_asb.py
b/providers/tests/microsoft/azure/hooks/test_asb.py
index 83e04833bf0..6d090e6653d 100644
--- a/providers/tests/microsoft/azure/hooks/test_asb.py
+++ b/providers/tests/microsoft/azure/hooks/test_asb.py
@@ -33,6 +33,7 @@ except ImportError:
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook,
MessageHook
+from airflow.utils.context import Context
MESSAGE = "Test Message"
MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)]
@@ -256,7 +257,7 @@ class TestMessageHook:
mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value
= [
mock_service_bus_message
]
- hook.receive_message(self.queue_name)
+ hook.receive_message(self.queue_name, Context())
expected_calls = [
mock.call()
.__enter__()
@@ -285,12 +286,13 @@ class TestMessageHook:
received_messages = []
- def message_callback(msg: Any) -> None:
+ def message_callback(msg: Any, context: Context) -> None:
nonlocal received_messages
print("received message:", msg)
+ assert context is not None
received_messages.append(msg)
- hook.receive_message(self.queue_name,
message_callback=message_callback)
+ hook.receive_message(self.queue_name, Context(),
message_callback=message_callback)
assert len(received_messages) == 1
assert received_messages[0] == mock_service_bus_message
@@ -316,7 +318,9 @@ class TestMessageHook:
max_message_count = 10
max_wait_time = 5
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
- hook.receive_subscription_message(topic_name, subscription_name,
max_message_count, max_wait_time)
+ hook.receive_subscription_message(
+ topic_name, subscription_name, Context(), max_message_count,
max_wait_time
+ )
expected_calls = [
mock.call()
.__enter__()
@@ -350,13 +354,19 @@ class TestMessageHook:
received_messages = []
- def message_callback(msg: ServiceBusMessage) -> None:
+ def message_callback(msg: ServiceBusMessage, context: Context) -> None:
nonlocal received_messages
print("received message:", msg)
+ assert context is not None
received_messages.append(msg)
hook.receive_subscription_message(
- topic_name, subscription_name, max_message_count, max_wait_time,
message_callback=message_callback
+ topic_name,
+ subscription_name,
+ Context(),
+ max_message_count,
+ max_wait_time,
+ message_callback=message_callback,
)
assert len(received_messages) == 2
diff --git a/providers/tests/microsoft/azure/operators/test_asb.py
b/providers/tests/microsoft/azure/operators/test_asb.py
index 42b770095b4..7e0c953890c 100644
--- a/providers/tests/microsoft/azure/operators/test_asb.py
+++ b/providers/tests/microsoft/azure/operators/test_asb.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from unittest import mock
import pytest
+from azure.servicebus import ServiceBusMessage
try:
from azure.servicebus import ServiceBusMessage
@@ -37,6 +38,7 @@ from airflow.providers.microsoft.azure.operators.asb import (
AzureServiceBusTopicDeleteOperator,
AzureServiceBusUpdateSubscriptionOperator,
)
+from airflow.utils.context import Context
QUEUE_NAME = "test_queue"
MESSAGE = "Test Message"
@@ -216,21 +218,22 @@ class TestAzureServiceBusReceiveMessageOperator:
Test AzureServiceBusReceiveMessageOperator by mock connection, values
and the service bus receive message
"""
- mock_service_bus_message = ServiceBusMessage("Test message")
+ mock_service_bus_message = ServiceBusMessage("Test message with
context")
mock_get_conn.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value
= [
mock_service_bus_message
]
messages_received = []
- def message_callback(msg):
+ def message_callback(msg: ServiceBusMessage, context: Context):
messages_received.append(msg)
+ assert context is not None
print(msg)
asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator(
task_id="asb_receive_message_queue", queue_name=QUEUE_NAME,
message_callback=message_callback
)
- asb_receive_queue_operator.execute(None)
+ asb_receive_queue_operator.execute(Context())
assert len(messages_received) == 1
assert messages_received[0] == mock_service_bus_message
@@ -470,8 +473,9 @@ class TestASBSubscriptionReceiveMessageOperator:
messages_received = []
- def message_callback(msg):
+ def message_callback(msg: ServiceBusMessage, context: Context):
messages_received.append(msg)
+ assert context is not None
print(msg)
asb_subscription_receive_message =
ASBReceiveSubscriptionMessageOperator(
@@ -482,7 +486,7 @@ class TestASBSubscriptionReceiveMessageOperator:
message_callback=message_callback,
)
- asb_subscription_receive_message.execute(None)
+ asb_subscription_receive_message.execute(Context())
expected_calls = [
mock.call()
.__enter__()