This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 17d031df66 Add Amazon SQS Notifier (#33962)
17d031df66 is described below

commit 17d031df66ce99943aa7e7272e24c8e6d3b3ebd6
Author: Andrey Anshin <[email protected]>
AuthorDate: Thu Aug 31 22:26:16 2023 +0400

    Add Amazon SQS Notifier (#33962)
    
    ---------
    
    Co-authored-by: Elad Kalif <[email protected]>
    Co-authored-by: Vincent <[email protected]>
---
 airflow/providers/amazon/aws/notifications/sqs.py  | 100 +++++++++++++++++++++
 airflow/providers/amazon/provider.yaml             |   1 +
 .../notifications/sqs.rst                          |  63 +++++++++++++
 .../providers/amazon/aws/notifications/test_sqs.py |  61 +++++++++++++
 4 files changed, 225 insertions(+)

diff --git a/airflow/providers/amazon/aws/notifications/sqs.py 
b/airflow/providers/amazon/aws/notifications/sqs.py
new file mode 100644
index 0000000000..ea79bc1de6
--- /dev/null
+++ b/airflow/providers/amazon/aws/notifications/sqs.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import Sequence
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+
try:
    from airflow.notifications.basenotifier import BaseNotifier
except ImportError as e:
    # BaseNotifier is not importable on Airflow versions older than 2.6.0 (per the message below).
    # Re-raise as the provider's optional-feature exception, explicitly chaining the
    # original ImportError so its cause stays visible in the traceback.
    raise AirflowOptionalProviderFeatureException(
        "Failed to import BaseNotifier. This feature is only available in Airflow versions >= 2.6.0"
    ) from e
+
+
class SqsNotifier(BaseNotifier):
    """
    Amazon SQS (Simple Queue Service) Notifier.

    .. seealso::
        For more information on how to use this notifier, take a look at the guide:
        :ref:`howto/notifier:SqsNotifier`

    :param aws_conn_id: The :ref:`Amazon Web Services Connection id <howto/connection:aws>`
        used for AWS credentials. If this is None or empty then the default boto3 behaviour is used.
    :param queue_url: The URL of the Amazon SQS queue to which a message is sent.
    :param message_body: The message to send.
    :param message_attributes: Additional attributes for the message; ``None`` is replaced
        by an empty dict. For details of the attributes parameter see
        :py:meth:`botocore.client.SQS.send_message`.
    :param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues.
        For details of the message group ID parameter see
        :py:meth:`botocore.client.SQS.send_message`.
    :param delay_seconds: The length of time, in seconds, for which to delay a message.
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    """

    template_fields: Sequence[str] = (
        "queue_url",
        "message_body",
        "message_attributes",
        "message_group_id",
        "delay_seconds",
        "aws_conn_id",
        "region_name",
    )

    def __init__(
        self,
        *,
        aws_conn_id: str | None = SqsHook.default_conn_name,
        queue_url: str,
        message_body: str,
        message_attributes: dict | None = None,
        message_group_id: str | None = None,
        delay_seconds: int = 0,
        region_name: str | None = None,
    ):
        super().__init__()
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.queue_url = queue_url
        self.message_body = message_body
        # Normalise None to a fresh empty dict per instance (no shared mutable default).
        self.message_attributes = message_attributes or {}
        self.message_group_id = message_group_id
        self.delay_seconds = delay_seconds

    @cached_property
    def hook(self) -> SqsHook:
        """Amazon SQS Hook built from ``aws_conn_id`` and ``region_name``; created once per instance."""
        return SqsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

    def notify(self, context) -> None:
        """
        Publish the notification message to the Amazon SQS queue.

        :param context: The callback context. It is not read here; the message is built
            solely from the attributes stored on this notifier.
        """
        self.hook.send_message(
            queue_url=self.queue_url,
            message_body=self.message_body,
            delay_seconds=self.delay_seconds,
            message_attributes=self.message_attributes,
            message_group_id=self.message_group_id,
        )
+
+
# Function-style alias, so callbacks can be written as
# ``on_failure_callback=[send_sqs_notification(...)]``; it is the very same class object.
send_sqs_notification: type[SqsNotifier] = SqsNotifier
diff --git a/airflow/providers/amazon/provider.yaml 
b/airflow/providers/amazon/provider.yaml
index d31a4bc1bb..0487fd1e6d 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -688,6 +688,7 @@ connection-types:
 notifications:
   - airflow.providers.amazon.aws.notifications.chime.ChimeNotifier
   - airflow.providers.amazon.aws.notifications.sns.SnsNotifier
+  - airflow.providers.amazon.aws.notifications.sqs.SqsNotifier
 
 secrets-backends:
   - airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
diff --git a/docs/apache-airflow-providers-amazon/notifications/sqs.rst 
b/docs/apache-airflow-providers-amazon/notifications/sqs.rst
new file mode 100644
index 0000000000..4a2232b006
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/notifications/sqs.rst
@@ -0,0 +1,63 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/notifier:SqsNotifier:
+
+How-to Guide for Amazon Simple Queue Service (Amazon SQS) notifications
+=======================================================================
+
+Introduction
+------------
+`Amazon SQS <https://aws.amazon.com/sqs/>`__ notifier :class:`~airflow.providers.amazon.aws.notifications.sqs.SqsNotifier`
+allows users to push messages to an Amazon SQS queue using the various ``on_*_callback`` parameters at both the DAG level and the task level.
+
+You can also use a notifier with ``sla_miss_callback``.
+
+.. note::
+    When notifiers are used with ``sla_miss_callback``, the context will contain only the values passed to the callback;
+    refer to :ref:`sla_miss_callback<concepts:sla_miss_callback>`.
+
+Example Code:
+-------------
+
+.. code-block:: python
+
+    from datetime import datetime, timezone
+    from airflow import DAG
+    from airflow.operators.bash import BashOperator
+    from airflow.providers.amazon.aws.notifications.sqs import send_sqs_notification
+
+    dag_failure_sqs_notification = send_sqs_notification(
+        aws_conn_id="aws_default",
+        queue_url="https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue",
+        message_body="The DAG {{ dag.dag_id }} failed",
+    )
+    task_failure_sqs_notification = send_sqs_notification(
+        aws_conn_id="aws_default",
+        region_name="eu-west-1",
+        queue_url="https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue",
+        message_body="The task {{ ti.task_id }} failed",
+    )
+
+    with DAG(
+        dag_id="mydag",
+        schedule="@once",
+        start_date=datetime(2023, 1, 1, tzinfo=timezone.utc),
+        on_failure_callback=[dag_failure_sqs_notification],
+        catchup=False,
+    ):
+        BashOperator(task_id="mytask", on_failure_callback=[task_failure_sqs_notification], bash_command="fail")
diff --git a/tests/providers/amazon/aws/notifications/test_sqs.py 
b/tests/providers/amazon/aws/notifications/test_sqs.py
new file mode 100644
index 0000000000..356e46428a
--- /dev/null
+++ b/tests/providers/amazon/aws/notifications/test_sqs.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.notifications.sqs import SqsNotifier, 
send_sqs_notification
+from airflow.utils.types import NOTSET
+
# Sentinel parameter: NOTSET means "do not pass this argument", exercising the constructor default.
PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value")


class TestSqsNotifier:
    """Unit tests for ``SqsNotifier`` and its ``send_sqs_notification`` alias."""

    def test_class_and_notifier_are_same(self):
        assert send_sqs_notification is SqsNotifier

    @pytest.mark.parametrize("aws_conn_id", ["aws_test_conn_id", None, PARAM_DEFAULT_VALUE])
    @pytest.mark.parametrize("region_name", ["eu-west-2", None, PARAM_DEFAULT_VALUE])
    def test_parameters_propagate_to_hook(self, aws_conn_id, region_name):
        """Test notifier attributes propagate to SqsHook."""
        send_message_kwargs = {
            "queue_url": "https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue",
            "message_body": "foo-bar",
            "delay_seconds": 42,
            "message_attributes": {},
            "message_group_id": "foo-bar",
        }
        notifier_kwargs = {}
        if aws_conn_id is not NOTSET:
            notifier_kwargs["aws_conn_id"] = aws_conn_id
        if region_name is not NOTSET:
            notifier_kwargs["region_name"] = region_name

        notifier = SqsNotifier(**notifier_kwargs, **send_message_kwargs)
        with mock.patch("airflow.providers.amazon.aws.notifications.sqs.SqsHook") as mock_hook:
            hook = notifier.hook
            assert hook is notifier.hook, "Hook property not cached"
            mock_hook.assert_called_once_with(
                aws_conn_id=(aws_conn_id if aws_conn_id is not NOTSET else "aws_default"),
                region_name=(region_name if region_name is not NOTSET else None),
            )

            # Basic check for notifier
            notifier.notify({})
            mock_hook.return_value.send_message.assert_called_once_with(**send_message_kwargs)

    def test_optional_message_parameters_defaults(self):
        """Omitted optional message parameters are sent with the constructor defaults."""
        queue_url = "https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue"
        notifier = SqsNotifier(queue_url=queue_url, message_body="foo-bar")
        with mock.patch("airflow.providers.amazon.aws.notifications.sqs.SqsHook") as mock_hook:
            notifier.notify({})
            mock_hook.return_value.send_message.assert_called_once_with(
                queue_url=queue_url,
                message_body="foo-bar",
                delay_seconds=0,
                # None is normalised to an empty dict by the constructor.
                message_attributes={},
                message_group_id=None,
            )

Reply via email to