This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b47586eb554 Add async support for Amazon SNS Notifier (#56133)
b47586eb554 is described below
commit b47586eb55466d0ed05a4ebb67de19cfefd4cfc8
Author: D. Ferruzzi <[email protected]>
AuthorDate: Fri Sep 26 16:34:42 2025 -0700
Add async support for Amazon SNS Notifier (#56133)
* Add async support for Amazon SNS Notifier
---
.../src/airflow/providers/amazon/aws/hooks/sns.py | 92 ++++++++---
.../providers/amazon/aws/notifications/sns.py | 17 ++-
.../amazon/tests/unit/amazon/aws/hooks/test_sns.py | 168 +++++++++++++--------
.../unit/amazon/aws/notifications/test_sns.py | 46 ++++--
4 files changed, 227 insertions(+), 96 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
index 00315fbddc3..de309df4815 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
@@ -22,6 +22,7 @@ from __future__ import annotations
import json
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils.helpers import prune_dict
def _get_message_attribute(o):
@@ -38,6 +39,33 @@ def _get_message_attribute(o):
)
+def _build_publish_kwargs(
+ target_arn: str,
+ message: str,
+ subject: str | None = None,
+ message_attributes: dict | None = None,
+ message_deduplication_id: str | None = None,
+ message_group_id: str | None = None,
+) -> dict[str, str | dict]:
+ publish_kwargs: dict[str, str | dict] = prune_dict(
+ {
+ "TargetArn": target_arn,
+ "MessageStructure": "json",
+ "Message": json.dumps({"default": message}),
+ "Subject": subject,
+ "MessageDeduplicationId": message_deduplication_id,
+ "MessageGroupId": message_group_id,
+ }
+ )
+
+ if message_attributes:
+ publish_kwargs["MessageAttributes"] = {
+ key: _get_message_attribute(val) for key, val in message_attributes.items()
+ }
+
+ return publish_kwargs
+
+
class SnsHook(AwsBaseHook):
"""
Interact with Amazon Simple Notification Service.
@@ -84,22 +112,50 @@ class SnsHook(AwsBaseHook):
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
This parameter applies only to FIFO (first-in-first-out) topics.
"""
- publish_kwargs: dict[str, str | dict] = {
- "TargetArn": target_arn,
- "MessageStructure": "json",
- "Message": json.dumps({"default": message}),
- }
+ return self.get_conn().publish(
+ **_build_publish_kwargs(
+ target_arn, message, subject, message_attributes, message_deduplication_id, message_group_id
+ )
+ )
- # Construct args this way because boto3 distinguishes from missing args and those set to None
- if subject:
- publish_kwargs["Subject"] = subject
- if message_deduplication_id:
- publish_kwargs["MessageDeduplicationId"] = message_deduplication_id
- if message_group_id:
- publish_kwargs["MessageGroupId"] = message_group_id
- if message_attributes:
- publish_kwargs["MessageAttributes"] = {
- key: _get_message_attribute(val) for key, val in message_attributes.items()
- }
-
- return self.get_conn().publish(**publish_kwargs)
+ async def apublish_to_target(
+ self,
+ target_arn: str,
+ message: str,
+ subject: str | None = None,
+ message_attributes: dict | None = None,
+ message_deduplication_id: str | None = None,
+ message_group_id: str | None = None,
+ ):
+ """
+ Publish a message to a SNS topic or an endpoint.
+
+ .. seealso::
+ - :external+boto3:py:meth:`SNS.Client.publish`
+
+ :param target_arn: either a TopicArn or an EndpointArn
+ :param message: the default message you want to send
+ :param subject: subject of message
+ :param message_attributes: additional attributes to publish for message filtering. This should be
+ a flat dict; the DataType to be sent depends on the type of the value:
+
+ - bytes = Binary
+ - str = String
+ - int, float = Number
+ - iterable = String.Array
+ :param message_deduplication_id: Every message must have a unique message_deduplication_id.
+ This parameter applies only to FIFO (first-in-first-out) topics.
+ :param message_group_id: Tag that specifies that a message belongs to a specific message group.
+ This parameter applies only to FIFO (first-in-first-out) topics.
+ """
+ async with await self.get_async_conn() as async_client:
+ return await async_client.publish(
+ **_build_publish_kwargs(
+ target_arn,
+ message,
+ subject,
+ message_attributes,
+ message_deduplication_id,
+ message_group_id,
+ )
+ )
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py
index c73d52e85cc..62a3c037a7f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py
@@ -21,6 +21,7 @@ from collections.abc import Sequence
from functools import cached_property
from airflow.providers.amazon.aws.hooks.sns import SnsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.providers.common.compat.notifier import BaseNotifier
@@ -60,8 +61,13 @@ class SnsNotifier(BaseNotifier):
subject: str | None = None,
message_attributes: dict | None = None,
region_name: str | None = None,
+ **kwargs,
):
- super().__init__()
+ if AIRFLOW_V_3_1_PLUS:
+ # Support for passing context was added in 3.1.0
+ super().__init__(**kwargs)
+ else:
+ super().__init__()
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.target_arn = target_arn
@@ -83,5 +89,14 @@ class SnsNotifier(BaseNotifier):
message_attributes=self.message_attributes,
)
+ async def async_notify(self, context):
+ """Publish the notification message to Amazon SNS (async)."""
+ await self.hook.apublish_to_target(
+ target_arn=self.target_arn,
+ message=self.message,
+ subject=self.subject,
+ message_attributes=self.message_attributes,
+ )
+
send_sns_notification = SnsNotifier
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
index 19043cd5f45..e39157ca1f1 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
@@ -17,95 +17,75 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
import pytest
from moto import mock_aws
from airflow.providers.amazon.aws.hooks.sns import SnsHook
+DEDUPE_ID = "test-dedupe-id"
+GROUP_ID = "test-group-id"
MESSAGE = "Hello world"
-TOPIC_NAME = "test-topic"
SUBJECT = "test-subject"
+INVALID_ATTRIBUTES_MSG = r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable"
+TOPIC_NAME = "test-topic"
+TOPIC_ARN = f"arn:aws:sns:us-east-1:123456789012:{TOPIC_NAME}"
-@mock_aws
-class TestSnsHook:
- def test_get_conn_returns_a_boto3_connection(self):
- hook = SnsHook(aws_conn_id="aws_default")
- assert hook.get_conn() is not None
-
- def test_publish_to_target_with_subject(self):
- hook = SnsHook(aws_conn_id="aws_default")
-
- message = MESSAGE
- topic_name = TOPIC_NAME
- subject = SUBJECT
- target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
-
- response = hook.publish_to_target(target, message, subject)
+INVALID_ATTRIBUTES = {"test-non-iterable": object()}
+VALID_ATTRIBUTES = {
+ "test-string": "string-value",
+ "test-number": 123456,
+ "test-array": ["first", "second", "third"],
+ "test-binary": b"binary-value",
+}
- assert "MessageId" in response
+MESSAGE_ID_KEY = "MessageId"
+TOPIC_ARN_KEY = "TopicArn"
- def test_publish_to_target_with_attributes(self):
- hook = SnsHook(aws_conn_id="aws_default")
- message = MESSAGE
- topic_name = TOPIC_NAME
- target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
+class TestSnsHook:
+ @pytest.fixture(autouse=True)
+ def setup_moto(self):
+ with mock_aws():
+ yield
- response = hook.publish_to_target(
- target,
- message,
- message_attributes={
- "test-string": "string-value",
- "test-number": 123456,
- "test-array": ["first", "second", "third"],
- "test-binary": b"binary-value",
- },
- )
+ @pytest.fixture
+ def hook(self):
+ return SnsHook(aws_conn_id="aws_default")
- assert "MessageId" in response
+ @pytest.fixture
+ def target(self, hook):
+ return hook.get_conn().create_topic(Name=TOPIC_NAME).get(TOPIC_ARN_KEY)
- def test_publish_to_target_plain(self):
- hook = SnsHook(aws_conn_id="aws_default")
+ def test_get_conn_returns_a_boto3_connection(self, hook):
+ assert hook.get_conn() is not None
- message = MESSAGE
- topic_name = "test-topic"
- target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
+ def test_publish_to_target_with_subject(self, hook, target):
+ response = hook.publish_to_target(target, MESSAGE, SUBJECT)
- response = hook.publish_to_target(target, message)
+ assert MESSAGE_ID_KEY in response
- assert "MessageId" in response
+ def test_publish_to_target_with_attributes(self, hook, target):
+ response = hook.publish_to_target(target, MESSAGE, message_attributes=VALID_ATTRIBUTES)
- def test_publish_to_target_error(self):
- hook = SnsHook(aws_conn_id="aws_default")
+ assert MESSAGE_ID_KEY in response
- message = "Hello world"
- topic_name = "test-topic"
- target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
+ def test_publish_to_target_plain(self, hook, target):
+ response = hook.publish_to_target(target, MESSAGE)
- error_message = (
- r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got .*"
- )
- with pytest.raises(TypeError, match=error_message):
- hook.publish_to_target(
- target,
- message,
- message_attributes={
- "test-non-iterable": object(),
- },
- )
+ assert MESSAGE_ID_KEY in response
- def test_publish_to_target_with_deduplication(self):
- hook = SnsHook(aws_conn_id="aws_default")
+ def test_publish_to_target_error(self, hook, target):
+ with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
+ hook.publish_to_target(target, MESSAGE, message_attributes=INVALID_ATTRIBUTES)
- message = MESSAGE
- topic_name = TOPIC_NAME + ".fifo"
- deduplication_id = "abc"
- group_id = "a"
- target = (
+ def test_publish_to_target_with_deduplication(self, hook):
+ fifo_target = (
hook.get_conn()
.create_topic(
- Name=topic_name,
+ Name=f"{TOPIC_NAME}.fifo",
Attributes={
"FifoTopic": "true",
"ContentBasedDeduplication": "false",
@@ -115,7 +95,63 @@ class TestSnsHook:
)
response = hook.publish_to_target(
- target, message, message_deduplication_id=deduplication_id, message_group_id=group_id
+ fifo_target, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
+ )
+ assert MESSAGE_ID_KEY in response
+
+
+@pytest.mark.asyncio
+class TestAsyncSnsHook:
+ """The mock_aws decorator uses `moto` which does not currently support async SNS so we mock it manually."""
+
+ @pytest.fixture
+ def hook(self):
+ return SnsHook(aws_conn_id="aws_default")
+
+ @pytest.fixture
+ def mock_async_client(self):
+ mock_client = mock.AsyncMock()
+ mock_client.publish.return_value = {MESSAGE_ID_KEY: "test-message-id"}
+ return mock_client
+
+ @pytest.fixture
+ def mock_get_async_conn(self, mock_async_client):
+ with mock.patch.object(SnsHook, "get_async_conn") as mocked_conn:
+ mocked_conn.return_value = mock_async_client
+ mocked_conn.return_value.__aenter__.return_value = mock_async_client
+ yield mocked_conn
+
+ async def test_get_async_conn(self, hook, mock_get_async_conn, mock_async_client):
+ # Test context manager access
+ async with await hook.get_async_conn() as async_conn:
+ assert async_conn is mock_async_client
+
+ # Test direct access
+ async_conn = await hook.get_async_conn()
+ assert async_conn is mock_async_client
+
+ async def test_apublish_to_target_with_subject(self, hook, mock_get_async_conn, mock_async_client):
+ response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, SUBJECT)
+
+ assert MESSAGE_ID_KEY in response
+
+ async def test_apublish_to_target_with_attributes(self, hook, mock_get_async_conn, mock_async_client):
+ response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=VALID_ATTRIBUTES)
+
+ assert MESSAGE_ID_KEY in response
+
+ async def test_publish_to_target_plain(self, hook, mock_get_async_conn, mock_async_client):
+ response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE)
+
+ assert MESSAGE_ID_KEY in response
+
+ async def test_publish_to_target_error(self, hook, mock_get_async_conn, mock_async_client):
+ with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
+ await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=INVALID_ATTRIBUTES)
+
+ async def test_apublish_to_target_with_deduplication(self, hook, mock_get_async_conn, mock_async_client):
+ response = await hook.apublish_to_target(
+ TOPIC_ARN, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
)
- assert "MessageId" in response
+ assert MESSAGE_ID_KEY in response
diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
index 87f1bfb94cc..b09098d0d49 100644
--- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
+++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
@@ -23,30 +23,44 @@ import pytest
from airflow.providers.amazon.aws.notifications.sns import SnsNotifier, send_sns_notification
from airflow.utils.types import NOTSET
-PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value")
+PUBLISH_KWARGS = {
+ "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName",
+ "message": "foo-bar",
+ "subject": "spam-egg",
+ "message_attributes": {},
+}
class TestSnsNotifier:
def test_class_and_notifier_are_same(self):
assert send_sns_notification is SnsNotifier
- @pytest.mark.parametrize("aws_conn_id", ["aws_test_conn_id", None, PARAM_DEFAULT_VALUE])
- @pytest.mark.parametrize("region_name", ["eu-west-2", None, PARAM_DEFAULT_VALUE])
+ @pytest.mark.parametrize(
+ "aws_conn_id",
+ [
+ pytest.param("aws_test_conn_id", id="custom-conn"),
+ pytest.param(None, id="none-conn"),
+ pytest.param(NOTSET, id="default-value"),
+ ],
+ )
+ @pytest.mark.parametrize(
+ "region_name",
+ [
+ pytest.param("eu-west-2", id="custom-region"),
+ pytest.param(None, id="no-region"),
+ pytest.param(NOTSET, id="default-value"),
+ ],
+ )
def test_parameters_propagate_to_hook(self, aws_conn_id, region_name):
"""Test notifier attributes propagate to SnsHook."""
- publish_kwargs = {
- "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName",
- "message": "foo-bar",
- "subject": "spam-egg",
- "message_attributes": {},
- }
+
notifier_kwargs = {}
if aws_conn_id is not NOTSET:
notifier_kwargs["aws_conn_id"] = aws_conn_id
if region_name is not NOTSET:
notifier_kwargs["region_name"] = region_name
- notifier = SnsNotifier(**notifier_kwargs, **publish_kwargs)
+ notifier = SnsNotifier(**notifier_kwargs, **PUBLISH_KWARGS)
with mock.patch("airflow.providers.amazon.aws.notifications.sns.SnsHook") as mock_hook:
hook = notifier.hook
assert hook is notifier.hook, "Hook property not cached"
@@ -57,7 +71,17 @@ class TestSnsNotifier:
# Basic check for notifier
notifier.notify({})
- mock_hook.return_value.publish_to_target.assert_called_once_with(**publish_kwargs)
+ mock_hook.return_value.publish_to_target.assert_called_once_with(**PUBLISH_KWARGS)
+
+ @pytest.mark.asyncio
+ async def test_async_notify(self):
+ notifier = SnsNotifier(**PUBLISH_KWARGS)
+ with mock.patch("airflow.providers.amazon.aws.notifications.sns.SnsHook") as mock_hook:
+ mock_hook.return_value.apublish_to_target = mock.AsyncMock()
+
+ await notifier.async_notify({})
+
+ mock_hook.return_value.apublish_to_target.assert_called_once_with(**PUBLISH_KWARGS)
def test_sns_notifier_templated(self, create_dag_without_db):
notifier = SnsNotifier(