This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9b4a053bc6 Add batch option to `SqsSensor` (#24554)
9b4a053bc6 is described below
commit 9b4a053bc6496e5e35caabb3f68ef64c1381e48b
Author: TungHoang <[email protected]>
AuthorDate: Tue Jun 28 14:47:46 2022 +0200
Add batch option to `SqsSensor` (#24554)
Add batch option to `SqsSensor`
Co-authored-by: TungHoang <[email protected]>
Co-authored-by: D. Ferruzzi <[email protected]>
---
airflow/providers/amazon/aws/sensors/sqs.py | 77 ++++++++++++++++--------
tests/providers/amazon/aws/sensors/test_sqs.py | 28 +++++++++
tests/system/providers/amazon/aws/example_sqs.py | 23 ++++++-
3 files changed, 99 insertions(+), 29 deletions(-)
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py
b/airflow/providers/amazon/aws/sensors/sqs.py
index cc026ec5e8..c1b6c4c747 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -18,12 +18,13 @@
"""Reads and then deletes the message from SQS queue"""
import json
import warnings
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Collection, List, Optional, Sequence
from jsonpath_ng import parse
from typing_extensions import Literal
from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.sensors.base import BaseSensorOperator
@@ -34,9 +35,13 @@ if TYPE_CHECKING:
class SqsSensor(BaseSensorOperator):
"""
Get messages from an Amazon SQS queue and then delete the messages from
the queue.
- If deletion of messages fails an AirflowException is thrown. Otherwise,
the messages
+ If deletion of messages fails, an AirflowException is thrown. Otherwise,
the messages
are pushed through XCom with the key ``messages``.
+ By default, the sensor performs one and only one SQS call per poke, which
limits the result to
+ a maximum of 10 messages. However, the total number of SQS API calls per
poke can be controlled
+ by num_batches param.
+
.. seealso::
For more information on how to use this sensor, take a look at the
guide:
:ref:`howto/sensor:SqsSensor`
@@ -44,6 +49,7 @@ class SqsSensor(BaseSensorOperator):
:param aws_conn_id: AWS connection id
:param sqs_queue: The SQS queue url (templated)
:param max_messages: The maximum number of messages to retrieve for each
poke (templated)
+ :param num_batches: The number of times the sensor will call the SQS API
to receive messages (default: 1)
:param wait_time_seconds: The time in seconds to wait for receiving
messages (default: 1 second)
:param visibility_timeout: Visibility timeout, a period of time during
which
Amazon SQS prevents other consumers from receiving and processing the
message.
@@ -72,6 +78,7 @@ class SqsSensor(BaseSensorOperator):
sqs_queue,
aws_conn_id: str = 'aws_default',
max_messages: int = 5,
+ num_batches: int = 1,
wait_time_seconds: int = 1,
visibility_timeout: Optional[int] = None,
message_filtering: Optional[Literal["literal", "jsonpath"]] = None,
@@ -84,6 +91,7 @@ class SqsSensor(BaseSensorOperator):
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
+ self.num_batches = num_batches
self.wait_time_seconds = wait_time_seconds
self.visibility_timeout = visibility_timeout
@@ -104,15 +112,13 @@ class SqsSensor(BaseSensorOperator):
self.hook: Optional[SqsHook] = None
- def poke(self, context: 'Context'):
+ def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
"""
- Check for message on subscribed queue and write to xcom the message
with key ``messages``
+ Poll SQS queue to retrieve messages.
- :param context: the context object
- :return: ``True`` if message is available or ``False``
+ :param sqs_conn: SQS connection
+ :return: A list of messages retrieved from SQS
"""
- sqs_conn = self.get_hook().get_conn()
-
self.log.info('SqsSensor checking for message on queue: %s',
self.sqs_queue)
receive_message_kwargs = {
@@ -126,7 +132,7 @@ class SqsSensor(BaseSensorOperator):
response = sqs_conn.receive_message(**receive_message_kwargs)
if "Messages" not in response:
- return False
+ return []
messages = response['Messages']
num_messages = len(messages)
@@ -136,28 +142,47 @@ class SqsSensor(BaseSensorOperator):
messages = self.filter_messages(messages)
num_messages = len(messages)
self.log.info("There are %d messages left after filtering",
num_messages)
+ return messages
- if not num_messages:
- return False
+ def poke(self, context: 'Context'):
+ """
+ Check subscribed queue for messages and write them to xcom with the
``messages`` key.
- if not self.delete_message_on_reception:
- context['ti'].xcom_push(key='messages', value=messages)
- return True
+ :param context: the context object
+ :return: ``True`` if message is available or ``False``
+ """
+ sqs_conn = self.get_hook().get_conn()
- self.log.info("Deleting %d messages", num_messages)
+ message_batch: List[Any] = []
- entries = [
- {'Id': message['MessageId'], 'ReceiptHandle':
message['ReceiptHandle']} for message in messages
- ]
- response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue,
Entries=entries)
+ # perform multiple SQS calls to retrieve messages in series
+ for _ in range(self.num_batches):
+ messages = self.poll_sqs(sqs_conn=sqs_conn)
- if 'Successful' in response:
- context['ti'].xcom_push(key='messages', value=messages)
- return True
- else:
- raise AirflowException(
- 'Delete SQS Messages failed ' + str(response) + ' for messages
' + str(messages)
- )
+ if not len(messages):
+ continue
+
+ message_batch.extend(messages)
+
+ if self.delete_message_on_reception:
+
+ self.log.info("Deleting %d messages", len(messages))
+
+ entries = [
+ {'Id': message['MessageId'], 'ReceiptHandle':
message['ReceiptHandle']}
+ for message in messages
+ ]
+ response =
sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+
+ if 'Successful' not in response:
+ raise AirflowException(
+ 'Delete SQS Messages failed ' + str(response) + ' for
messages ' + str(messages)
+ )
+ if not len(message_batch):
+ return False
+
+ context['ti'].xcom_push(key='messages', value=message_batch)
+ return True
def get_hook(self) -> SqsHook:
"""Create and return an SqsHook"""
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py
b/tests/providers/amazon/aws/sensors/test_sqs.py
index c2a6b3aa2f..2254779535 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -305,3 +305,31 @@ class TestSqsSensor(unittest.TestCase):
)
self.sensor.poke(self.mock_context)
assert mock_conn.delete_message_batch.called is False
+
+ @mock_sqs
+ def test_poke_batch_messages(self):
+ messages = ["hello", "brave", "world"]
+
+ self.sqs_hook.create_queue(QUEUE_NAME)
+ # Publish 3 messages
+ for message in messages:
+ self.sqs_hook.send_message(queue_url=QUEUE_URL,
message_body=message)
+
+ # Init batch sensor to get 1 message for each SQS poll
+ # and perform 3 polls
+ self.sensor = SqsSensor(
+ task_id='test_task3',
+ dag=self.dag,
+ sqs_queue=QUEUE_URL,
+ aws_conn_id='aws_default',
+ max_messages=1,
+ num_batches=3,
+ )
+ result = self.sensor.poke(self.mock_context)
+ assert result
+
+ # expect all messages are retrieved
+ for message in messages:
+ assert f"'Body': '{message}'" in str(
+ self.mock_context['ti'].method_calls
+ ), "context call should contain message '{message}'"
diff --git a/tests/system/providers/amazon/aws/example_sqs.py
b/tests/system/providers/amazon/aws/example_sqs.py
index 09f697d472..0aabe9e118 100644
--- a/tests/system/providers/amazon/aws/example_sqs.py
+++ b/tests/system/providers/amazon/aws/example_sqs.py
@@ -52,8 +52,13 @@ with DAG(
sqs_queue = create_queue()
# [START howto_operator_sqs]
- publish_to_queue = SqsPublishOperator(
- task_id='publish_to_queue',
+ publish_to_queue_1 = SqsPublishOperator(
+ task_id='publish_to_queue_1',
+ sqs_queue=sqs_queue,
+ message_content='{{ task_instance }}-{{ logical_date }}',
+ )
+ publish_to_queue_2 = SqsPublishOperator(
+ task_id='publish_to_queue_2',
sqs_queue=sqs_queue,
message_content='{{ task_instance }}-{{ logical_date }}',
)
@@ -64,14 +69,26 @@ with DAG(
task_id='read_from_queue',
sqs_queue=sqs_queue,
)
+ # Retrieve multiple batches of messages from SQS.
+ # The SQS API only returns a maximum of 10 messages per poll.
+ read_from_queue_in_batch = SqsSensor(
+ task_id='read_from_queue_in_batch',
+ sqs_queue=create_queue,
+ # Get maximum 10 messages each poll
+ max_messages=10,
+ # Combine 3 polls before returning results
+ num_batches=3,
+ )
# [END howto_sensor_sqs]
chain(
# TEST SETUP
sqs_queue,
# TEST BODY
- publish_to_queue,
+ publish_to_queue_1,
read_from_queue,
+ publish_to_queue_2,
+ read_from_queue_in_batch,
# TEST TEARDOWN
delete_queue(sqs_queue),
)