This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b47586eb554 Add async support for Amazon SNS Notifier (#56133)
b47586eb554 is described below

commit b47586eb55466d0ed05a4ebb67de19cfefd4cfc8
Author: D. Ferruzzi <[email protected]>
AuthorDate: Fri Sep 26 16:34:42 2025 -0700

    Add async support for Amazon SNS Notifier (#56133)
    
    * Add async support for Amazon SNS Notifier
---
 .../src/airflow/providers/amazon/aws/hooks/sns.py  |  92 ++++++++---
 .../providers/amazon/aws/notifications/sns.py      |  17 ++-
 .../amazon/tests/unit/amazon/aws/hooks/test_sns.py | 168 +++++++++++++--------
 .../unit/amazon/aws/notifications/test_sns.py      |  46 ++++--
 4 files changed, 227 insertions(+), 96 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
index 00315fbddc3..de309df4815 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
@@ -22,6 +22,7 @@ from __future__ import annotations
 import json
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils.helpers import prune_dict
 
 
 def _get_message_attribute(o):
@@ -38,6 +39,33 @@ def _get_message_attribute(o):
     )
 
 
+def _build_publish_kwargs(
+    target_arn: str,
+    message: str,
+    subject: str | None = None,
+    message_attributes: dict | None = None,
+    message_deduplication_id: str | None = None,
+    message_group_id: str | None = None,
+) -> dict[str, str | dict]:
+    publish_kwargs: dict[str, str | dict] = prune_dict(
+        {
+            "TargetArn": target_arn,
+            "MessageStructure": "json",
+            "Message": json.dumps({"default": message}),
+            "Subject": subject,
+            "MessageDeduplicationId": message_deduplication_id,
+            "MessageGroupId": message_group_id,
+        }
+    )
+
+    if message_attributes:
+        publish_kwargs["MessageAttributes"] = {
+            key: _get_message_attribute(val) for key, val in message_attributes.items()
+        }
+
+    return publish_kwargs
+
+
 class SnsHook(AwsBaseHook):
     """
     Interact with Amazon Simple Notification Service.
@@ -84,22 +112,50 @@ class SnsHook(AwsBaseHook):
         :param message_group_id: Tag that specifies that a message belongs to a specific message group.
             This parameter applies only to FIFO (first-in-first-out) topics.
         """
-        publish_kwargs: dict[str, str | dict] = {
-            "TargetArn": target_arn,
-            "MessageStructure": "json",
-            "Message": json.dumps({"default": message}),
-        }
+        return self.get_conn().publish(
+            **_build_publish_kwargs(
+                target_arn, message, subject, message_attributes, message_deduplication_id, message_group_id
+            )
+        )
 
-        # Construct args this way because boto3 distinguishes from missing args and those set to None
-        if subject:
-            publish_kwargs["Subject"] = subject
-        if message_deduplication_id:
-            publish_kwargs["MessageDeduplicationId"] = message_deduplication_id
-        if message_group_id:
-            publish_kwargs["MessageGroupId"] = message_group_id
-        if message_attributes:
-            publish_kwargs["MessageAttributes"] = {
-                key: _get_message_attribute(val) for key, val in message_attributes.items()
-            }
-
-        return self.get_conn().publish(**publish_kwargs)
+    async def apublish_to_target(
+        self,
+        target_arn: str,
+        message: str,
+        subject: str | None = None,
+        message_attributes: dict | None = None,
+        message_deduplication_id: str | None = None,
+        message_group_id: str | None = None,
+    ):
+        """
+        Publish a message to a SNS topic or an endpoint.
+
+        .. seealso::
+            - :external+boto3:py:meth:`SNS.Client.publish`
+
+        :param target_arn: either a TopicArn or an EndpointArn
+        :param message: the default message you want to send
+        :param subject: subject of message
+        :param message_attributes: additional attributes to publish for message filtering. This should be
+            a flat dict; the DataType to be sent depends on the type of the value:
+
+            - bytes = Binary
+            - str = String
+            - int, float = Number
+            - iterable = String.Array
+        :param message_deduplication_id: Every message must have a unique message_deduplication_id.
+            This parameter applies only to FIFO (first-in-first-out) topics.
+        :param message_group_id: Tag that specifies that a message belongs to a specific message group.
+            This parameter applies only to FIFO (first-in-first-out) topics.
+        """
+        async with await self.get_async_conn() as async_client:
+            return await async_client.publish(
+                **_build_publish_kwargs(
+                    target_arn,
+                    message,
+                    subject,
+                    message_attributes,
+                    message_deduplication_id,
+                    message_group_id,
+                )
+            )
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py
index c73d52e85cc..62a3c037a7f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py
@@ -21,6 +21,7 @@ from collections.abc import Sequence
 from functools import cached_property
 
 from airflow.providers.amazon.aws.hooks.sns import SnsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS
 from airflow.providers.common.compat.notifier import BaseNotifier
 
 
@@ -60,8 +61,13 @@ class SnsNotifier(BaseNotifier):
         subject: str | None = None,
         message_attributes: dict | None = None,
         region_name: str | None = None,
+        **kwargs,
     ):
-        super().__init__()
+        if AIRFLOW_V_3_1_PLUS:
+            #  Support for passing context was added in 3.1.0
+            super().__init__(**kwargs)
+        else:
+            super().__init__()
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
         self.target_arn = target_arn
@@ -83,5 +89,14 @@ class SnsNotifier(BaseNotifier):
             message_attributes=self.message_attributes,
         )
 
+    async def async_notify(self, context):
+        """Publish the notification message to Amazon SNS (async)."""
+        await self.hook.apublish_to_target(
+            target_arn=self.target_arn,
+            message=self.message,
+            subject=self.subject,
+            message_attributes=self.message_attributes,
+        )
+
 
 send_sns_notification = SnsNotifier
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
index 19043cd5f45..e39157ca1f1 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
@@ -17,95 +17,75 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
+
 import pytest
 from moto import mock_aws
 
 from airflow.providers.amazon.aws.hooks.sns import SnsHook
 
+DEDUPE_ID = "test-dedupe-id"
+GROUP_ID = "test-group-id"
 MESSAGE = "Hello world"
-TOPIC_NAME = "test-topic"
 SUBJECT = "test-subject"
+INVALID_ATTRIBUTES_MSG = r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable"
 
+TOPIC_NAME = "test-topic"
+TOPIC_ARN = f"arn:aws:sns:us-east-1:123456789012:{TOPIC_NAME}"
 
-@mock_aws
-class TestSnsHook:
-    def test_get_conn_returns_a_boto3_connection(self):
-        hook = SnsHook(aws_conn_id="aws_default")
-        assert hook.get_conn() is not None
-
-    def test_publish_to_target_with_subject(self):
-        hook = SnsHook(aws_conn_id="aws_default")
-
-        message = MESSAGE
-        topic_name = TOPIC_NAME
-        subject = SUBJECT
-        target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
-
-        response = hook.publish_to_target(target, message, subject)
+INVALID_ATTRIBUTES = {"test-non-iterable": object()}
+VALID_ATTRIBUTES = {
+    "test-string": "string-value",
+    "test-number": 123456,
+    "test-array": ["first", "second", "third"],
+    "test-binary": b"binary-value",
+}
 
-        assert "MessageId" in response
+MESSAGE_ID_KEY = "MessageId"
+TOPIC_ARN_KEY = "TopicArn"
 
-    def test_publish_to_target_with_attributes(self):
-        hook = SnsHook(aws_conn_id="aws_default")
 
-        message = MESSAGE
-        topic_name = TOPIC_NAME
-        target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
+class TestSnsHook:
+    @pytest.fixture(autouse=True)
+    def setup_moto(self):
+        with mock_aws():
+            yield
 
-        response = hook.publish_to_target(
-            target,
-            message,
-            message_attributes={
-                "test-string": "string-value",
-                "test-number": 123456,
-                "test-array": ["first", "second", "third"],
-                "test-binary": b"binary-value",
-            },
-        )
+    @pytest.fixture
+    def hook(self):
+        return SnsHook(aws_conn_id="aws_default")
 
-        assert "MessageId" in response
+    @pytest.fixture
+    def target(self, hook):
+        return hook.get_conn().create_topic(Name=TOPIC_NAME).get(TOPIC_ARN_KEY)
 
-    def test_publish_to_target_plain(self):
-        hook = SnsHook(aws_conn_id="aws_default")
+    def test_get_conn_returns_a_boto3_connection(self, hook):
+        assert hook.get_conn() is not None
 
-        message = MESSAGE
-        topic_name = "test-topic"
-        target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
+    def test_publish_to_target_with_subject(self, hook, target):
+        response = hook.publish_to_target(target, MESSAGE, SUBJECT)
 
-        response = hook.publish_to_target(target, message)
+        assert MESSAGE_ID_KEY in response
 
-        assert "MessageId" in response
+    def test_publish_to_target_with_attributes(self, hook, target):
+        response = hook.publish_to_target(target, MESSAGE, message_attributes=VALID_ATTRIBUTES)
 
-    def test_publish_to_target_error(self):
-        hook = SnsHook(aws_conn_id="aws_default")
+        assert MESSAGE_ID_KEY in response
 
-        message = "Hello world"
-        topic_name = "test-topic"
-        target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
+    def test_publish_to_target_plain(self, hook, target):
+        response = hook.publish_to_target(target, MESSAGE)
 
-        error_message = (
-            r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got .*"
-        )
-        with pytest.raises(TypeError, match=error_message):
-            hook.publish_to_target(
-                target,
-                message,
-                message_attributes={
-                    "test-non-iterable": object(),
-                },
-            )
+        assert MESSAGE_ID_KEY in response
 
-    def test_publish_to_target_with_deduplication(self):
-        hook = SnsHook(aws_conn_id="aws_default")
+    def test_publish_to_target_error(self, hook, target):
+        with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
+            hook.publish_to_target(target, MESSAGE, message_attributes=INVALID_ATTRIBUTES)
 
-        message = MESSAGE
-        topic_name = TOPIC_NAME + ".fifo"
-        deduplication_id = "abc"
-        group_id = "a"
-        target = (
+    def test_publish_to_target_with_deduplication(self, hook):
+        fifo_target = (
             hook.get_conn()
             .create_topic(
-                Name=topic_name,
+                Name=f"{TOPIC_NAME}.fifo",
                 Attributes={
                     "FifoTopic": "true",
                     "ContentBasedDeduplication": "false",
@@ -115,7 +95,63 @@ class TestSnsHook:
         )
 
         response = hook.publish_to_target(
-            target, message, message_deduplication_id=deduplication_id, message_group_id=group_id
+            fifo_target, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
+        )
+        assert MESSAGE_ID_KEY in response
+
+
+@pytest.mark.asyncio
+class TestAsyncSnsHook:
+    """The mock_aws decorator uses `moto` which does not currently support async SNS so we mock it manually."""
+
+    @pytest.fixture
+    def hook(self):
+        return SnsHook(aws_conn_id="aws_default")
+
+    @pytest.fixture
+    def mock_async_client(self):
+        mock_client = mock.AsyncMock()
+        mock_client.publish.return_value = {MESSAGE_ID_KEY: "test-message-id"}
+        return mock_client
+
+    @pytest.fixture
+    def mock_get_async_conn(self, mock_async_client):
+        with mock.patch.object(SnsHook, "get_async_conn") as mocked_conn:
+            mocked_conn.return_value = mock_async_client
+            mocked_conn.return_value.__aenter__.return_value = mock_async_client
+            yield mocked_conn
+
+    async def test_get_async_conn(self, hook, mock_get_async_conn, mock_async_client):
+        # Test context manager access
+        async with await hook.get_async_conn() as async_conn:
+            assert async_conn is mock_async_client
+
+        # Test direct access
+        async_conn = await hook.get_async_conn()
+        assert async_conn is mock_async_client
+
+    async def test_apublish_to_target_with_subject(self, hook, mock_get_async_conn, mock_async_client):
+        response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, SUBJECT)
+
+        assert MESSAGE_ID_KEY in response
+
+    async def test_apublish_to_target_with_attributes(self, hook, mock_get_async_conn, mock_async_client):
+        response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=VALID_ATTRIBUTES)
+
+        assert MESSAGE_ID_KEY in response
+
+    async def test_publish_to_target_plain(self, hook, mock_get_async_conn, mock_async_client):
+        response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE)
+
+        assert MESSAGE_ID_KEY in response
+
+    async def test_publish_to_target_error(self, hook, mock_get_async_conn, mock_async_client):
+        with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
+            await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=INVALID_ATTRIBUTES)
+
+    async def test_apublish_to_target_with_deduplication(self, hook, mock_get_async_conn, mock_async_client):
+        response = await hook.apublish_to_target(
+            TOPIC_ARN, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
         )
 
-        assert "MessageId" in response
+        assert MESSAGE_ID_KEY in response
diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
index 87f1bfb94cc..b09098d0d49 100644
--- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
+++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
@@ -23,30 +23,44 @@ import pytest
 from airflow.providers.amazon.aws.notifications.sns import SnsNotifier, send_sns_notification
 from airflow.utils.types import NOTSET
 
-PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value")
+PUBLISH_KWARGS = {
+    "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName",
+    "message": "foo-bar",
+    "subject": "spam-egg",
+    "message_attributes": {},
+}
 
 
 class TestSnsNotifier:
     def test_class_and_notifier_are_same(self):
         assert send_sns_notification is SnsNotifier
 
-    @pytest.mark.parametrize("aws_conn_id", ["aws_test_conn_id", None, PARAM_DEFAULT_VALUE])
-    @pytest.mark.parametrize("region_name", ["eu-west-2", None, PARAM_DEFAULT_VALUE])
+    @pytest.mark.parametrize(
+        "aws_conn_id",
+        [
+            pytest.param("aws_test_conn_id", id="custom-conn"),
+            pytest.param(None, id="none-conn"),
+            pytest.param(NOTSET, id="default-value"),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "region_name",
+        [
+            pytest.param("eu-west-2", id="custom-region"),
+            pytest.param(None, id="no-region"),
+            pytest.param(NOTSET, id="default-value"),
+        ],
+    )
     def test_parameters_propagate_to_hook(self, aws_conn_id, region_name):
         """Test notifier attributes propagate to SnsHook."""
-        publish_kwargs = {
-            "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName",
-            "message": "foo-bar",
-            "subject": "spam-egg",
-            "message_attributes": {},
-        }
+
         notifier_kwargs = {}
         if aws_conn_id is not NOTSET:
             notifier_kwargs["aws_conn_id"] = aws_conn_id
         if region_name is not NOTSET:
             notifier_kwargs["region_name"] = region_name
 
-        notifier = SnsNotifier(**notifier_kwargs, **publish_kwargs)
+        notifier = SnsNotifier(**notifier_kwargs, **PUBLISH_KWARGS)
         with mock.patch("airflow.providers.amazon.aws.notifications.sns.SnsHook") as mock_hook:
             hook = notifier.hook
             assert hook is notifier.hook, "Hook property not cached"
@@ -57,7 +71,17 @@ class TestSnsNotifier:
 
             # Basic check for notifier
             notifier.notify({})
-            mock_hook.return_value.publish_to_target.assert_called_once_with(**publish_kwargs)
+            mock_hook.return_value.publish_to_target.assert_called_once_with(**PUBLISH_KWARGS)
+
+    @pytest.mark.asyncio
+    async def test_async_notify(self):
+        notifier = SnsNotifier(**PUBLISH_KWARGS)
+        with mock.patch("airflow.providers.amazon.aws.notifications.sns.SnsHook") as mock_hook:
+            mock_hook.return_value.apublish_to_target = mock.AsyncMock()
+
+            await notifier.async_notify({})
+
+            mock_hook.return_value.apublish_to_target.assert_called_once_with(**PUBLISH_KWARGS)
 
     def test_sns_notifier_templated(self, create_dag_without_db):
         notifier = SnsNotifier(

Reply via email to