This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 17d031df66 Add Amazon SQS Notifier (#33962)
17d031df66 is described below
commit 17d031df66ce99943aa7e7272e24c8e6d3b3ebd6
Author: Andrey Anshin <[email protected]>
AuthorDate: Thu Aug 31 22:26:16 2023 +0400
Add Amazon SQS Notifier (#33962)
---------
Co-authored-by: Elad Kalif <[email protected]>
Co-authored-by: Vincent <[email protected]>
---
airflow/providers/amazon/aws/notifications/sqs.py | 100 +++++++++++++++++++++
airflow/providers/amazon/provider.yaml | 1 +
.../notifications/sqs.rst | 63 +++++++++++++
.../providers/amazon/aws/notifications/test_sqs.py | 61 +++++++++++++
4 files changed, 225 insertions(+)
diff --git a/airflow/providers/amazon/aws/notifications/sqs.py
b/airflow/providers/amazon/aws/notifications/sqs.py
new file mode 100644
index 0000000000..ea79bc1de6
--- /dev/null
+++ b/airflow/providers/amazon/aws/notifications/sqs.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import Sequence
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+
+try:
+ from airflow.notifications.basenotifier import BaseNotifier
+except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "Failed to import BaseNotifier. This feature is only available in
Airflow versions >= 2.6.0"
+ )
+
+
class SqsNotifier(BaseNotifier):
    """
    Amazon SQS (Simple Queue Service) Notifier.

    .. seealso::
        For more information on how to use this notifier, take a look at the guide:
        :ref:`howto/notifier:SqsNotifier`

    All of ``queue_url``, ``message_body``, ``message_attributes``, ``message_group_id``,
    ``delay_seconds``, ``aws_conn_id`` and ``region_name`` are template fields.

    :param aws_conn_id: The :ref:`Amazon Web Services Connection id <howto/connection:aws>`
        used for AWS credentials. If this is None or empty then the default boto3 behaviour is used.
    :param queue_url: The URL of the Amazon SQS queue to which a message is sent.
    :param message_body: The message to send.
    :param message_attributes: additional attributes for the message. ``None`` is replaced by an
        empty dict. For details of the attributes parameter see
        :py:meth:`botocore.client.SQS.send_message`.
    :param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues.
        For details of this parameter see :py:meth:`botocore.client.SQS.send_message`.
    :param delay_seconds: The length of time, in seconds, for which to delay a message.
        Defaults to 0.
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    """

    template_fields: Sequence[str] = (
        "queue_url",
        "message_body",
        "message_attributes",
        "message_group_id",
        "delay_seconds",
        "aws_conn_id",
        "region_name",
    )

    def __init__(
        self,
        *,
        aws_conn_id: str | None = SqsHook.default_conn_name,
        queue_url: str,
        message_body: str,
        message_attributes: dict | None = None,
        message_group_id: str | None = None,
        delay_seconds: int = 0,
        region_name: str | None = None,
    ):
        super().__init__()
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.queue_url = queue_url
        self.message_body = message_body
        # Normalise ``None`` to an empty dict so ``send_message`` always receives a mapping.
        self.message_attributes = message_attributes or {}
        self.message_group_id = message_group_id
        self.delay_seconds = delay_seconds

    @cached_property
    def hook(self) -> SqsHook:
        """Amazon SQS Hook (cached), built from ``aws_conn_id`` and ``region_name``."""
        return SqsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

    def notify(self, context) -> None:
        """
        Publish the notification message to Amazon SQS queue.

        :param context: The callback context; it is not read directly here.
        """
        self.hook.send_message(
            queue_url=self.queue_url,
            message_body=self.message_body,
            delay_seconds=self.delay_seconds,
            message_attributes=self.message_attributes,
            message_group_id=self.message_group_id,
        )


# Function-style alias, so callbacks can be written as ``send_sqs_notification(...)``.
send_sqs_notification = SqsNotifier
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index d31a4bc1bb..0487fd1e6d 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -688,6 +688,7 @@ connection-types:
notifications:
- airflow.providers.amazon.aws.notifications.chime.ChimeNotifier
- airflow.providers.amazon.aws.notifications.sns.SnsNotifier
+ - airflow.providers.amazon.aws.notifications.sqs.SqsNotifier
secrets-backends:
- airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
diff --git a/docs/apache-airflow-providers-amazon/notifications/sqs.rst
b/docs/apache-airflow-providers-amazon/notifications/sqs.rst
new file mode 100644
index 0000000000..4a2232b006
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/notifications/sqs.rst
@@ -0,0 +1,63 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/notifier:SqsNotifier:
+
+How-to Guide for Amazon Simple Queue Service (Amazon SQS) notifications
+=======================================================================
+
+Introduction
+------------
+`Amazon SQS <https://aws.amazon.com/sqs/>`__ notifier :class:`~airflow.providers.amazon.aws.notifications.sqs.SqsNotifier`
+allows users to push messages to an Amazon SQS queue using the various ``on_*_callbacks`` at both the DAG level and the task level.
+
+You can also use a notifier with ``sla_miss_callback``.
+
+.. note::
+    When notifiers are used with ``sla_miss_callback``, the context will contain only the values passed to the callback;
+    refer to :ref:`sla_miss_callback<concepts:sla_miss_callback>`.
+
+Example Code:
+-------------
+
+.. code-block:: python
+
+ from datetime import datetime, timezone
+ from airflow import DAG
+ from airflow.operators.bash import BashOperator
+    from airflow.providers.amazon.aws.notifications.sqs import send_sqs_notification
+
+ dag_failure_sqs_notification = send_sqs_notification(
+ aws_conn_id="aws_default",
+ queue_url="https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue",
+ message_body="The DAG {{ dag.dag_id }} failed",
+ )
+ task_failure_sqs_notification = send_sqs_notification(
+ aws_conn_id="aws_default",
+ region_name="eu-west-1",
+ queue_url="https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue",
+ message_body="The task {{ ti.task_id }} failed",
+ )
+
+ with DAG(
+ dag_id="mydag",
+ schedule="@once",
+ start_date=datetime(2023, 1, 1, tzinfo=timezone.utc),
+ on_failure_callback=[dag_failure_sqs_notification],
+ catchup=False,
+ ):
+        BashOperator(task_id="mytask", on_failure_callback=[task_failure_sqs_notification], bash_command="fail")
diff --git a/tests/providers/amazon/aws/notifications/test_sqs.py
b/tests/providers/amazon/aws/notifications/test_sqs.py
new file mode 100644
index 0000000000..356e46428a
--- /dev/null
+++ b/tests/providers/amazon/aws/notifications/test_sqs.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.notifications.sqs import SqsNotifier,
send_sqs_notification
+from airflow.utils.types import NOTSET
+
PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value")


class TestSqsNotifier:
    def test_class_and_notifier_are_same(self):
        assert send_sqs_notification is SqsNotifier

    @pytest.mark.parametrize("aws_conn_id", ["aws_test_conn_id", None, PARAM_DEFAULT_VALUE])
    @pytest.mark.parametrize("region_name", ["eu-west-2", None, PARAM_DEFAULT_VALUE])
    def test_parameters_propagate_to_hook(self, aws_conn_id, region_name):
        """Test notifier attributes propagate to SqsHook."""
        send_message_kwargs = {
            "queue_url": "https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue",
            "message_body": "foo-bar",
            "delay_seconds": 42,
            "message_attributes": {},
            "message_group_id": "foo-bar",
        }
        notifier_kwargs = {}
        if aws_conn_id is not NOTSET:
            notifier_kwargs["aws_conn_id"] = aws_conn_id
        if region_name is not NOTSET:
            notifier_kwargs["region_name"] = region_name

        notifier = SqsNotifier(**notifier_kwargs, **send_message_kwargs)
        with mock.patch("airflow.providers.amazon.aws.notifications.sqs.SqsHook") as mock_hook:
            hook = notifier.hook
            assert hook is notifier.hook, "Hook property not cached"
            mock_hook.assert_called_once_with(
                aws_conn_id=(aws_conn_id if aws_conn_id is not NOTSET else "aws_default"),
                region_name=(region_name if region_name is not NOTSET else None),
            )

            # Basic check for notifier
            notifier.notify({})
            mock_hook.return_value.send_message.assert_called_once_with(**send_message_kwargs)

    def test_default_send_message_arguments(self):
        """Optional ``send_message`` arguments fall back to the notifier defaults."""
        queue_url = "https://sqs.eu-west-1.amazonaws.com/123456789098/MyQueue"
        notifier = SqsNotifier(queue_url=queue_url, message_body="foo-bar")
        with mock.patch("airflow.providers.amazon.aws.notifications.sqs.SqsHook") as mock_hook:
            notifier.notify({})
            # ``message_attributes=None`` must be normalised to ``{}``; the other optional
            # arguments use the constructor defaults.
            mock_hook.return_value.send_message.assert_called_once_with(
                queue_url=queue_url,
                message_body="foo-bar",
                delay_seconds=0,
                message_attributes={},
                message_group_id=None,
            )