This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b4a053bc6 Add batch option to `SqsSensor` (#24554)
9b4a053bc6 is described below

commit 9b4a053bc6496e5e35caabb3f68ef64c1381e48b
Author: TungHoang <[email protected]>
AuthorDate: Tue Jun 28 14:47:46 2022 +0200

    Add batch option to `SqsSensor` (#24554)
    
    Add batch option to `SqsSensor`
    
    Co-authored-by: TungHoang <[email protected]>
    Co-authored-by: D. Ferruzzi <[email protected]>
---
 airflow/providers/amazon/aws/sensors/sqs.py      | 77 ++++++++++++++++--------
 tests/providers/amazon/aws/sensors/test_sqs.py   | 28 +++++++++
 tests/system/providers/amazon/aws/example_sqs.py | 23 ++++++-
 3 files changed, 99 insertions(+), 29 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/sqs.py 
b/airflow/providers/amazon/aws/sensors/sqs.py
index cc026ec5e8..c1b6c4c747 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -18,12 +18,13 @@
 """Reads and then deletes the message from SQS queue"""
 import json
 import warnings
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Collection, List, Optional, Sequence
 
 from jsonpath_ng import parse
 from typing_extensions import Literal
 
 from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.sensors.base import BaseSensorOperator
 
@@ -34,9 +35,13 @@ if TYPE_CHECKING:
 class SqsSensor(BaseSensorOperator):
     """
     Get messages from an Amazon SQS queue and then delete the messages from 
the queue.
-    If deletion of messages fails an AirflowException is thrown. Otherwise, 
the messages
+    If deletion of messages fails, an AirflowException is thrown. Otherwise, 
the messages
     are pushed through XCom with the key ``messages``.
 
+    By default, the sensor performs one and only one SQS call per poke, which 
limits the result to
+    a maximum of 10 messages. However, the total number of SQS API calls per 
poke can be controlled
+    by num_batches param.
+
     .. seealso::
         For more information on how to use this sensor, take a look at the 
guide:
         :ref:`howto/sensor:SqsSensor`
@@ -44,6 +49,7 @@ class SqsSensor(BaseSensorOperator):
     :param aws_conn_id: AWS connection id
     :param sqs_queue: The SQS queue url (templated)
     :param max_messages: The maximum number of messages to retrieve for each 
poke (templated)
+    :param num_batches: The number of times the sensor will call the SQS API 
to receive messages (default: 1)
     :param wait_time_seconds: The time in seconds to wait for receiving 
messages (default: 1 second)
     :param visibility_timeout: Visibility timeout, a period of time during 
which
         Amazon SQS prevents other consumers from receiving and processing the 
message.
@@ -72,6 +78,7 @@ class SqsSensor(BaseSensorOperator):
         sqs_queue,
         aws_conn_id: str = 'aws_default',
         max_messages: int = 5,
+        num_batches: int = 1,
         wait_time_seconds: int = 1,
         visibility_timeout: Optional[int] = None,
         message_filtering: Optional[Literal["literal", "jsonpath"]] = None,
@@ -84,6 +91,7 @@ class SqsSensor(BaseSensorOperator):
         self.sqs_queue = sqs_queue
         self.aws_conn_id = aws_conn_id
         self.max_messages = max_messages
+        self.num_batches = num_batches
         self.wait_time_seconds = wait_time_seconds
         self.visibility_timeout = visibility_timeout
 
@@ -104,15 +112,13 @@ class SqsSensor(BaseSensorOperator):
 
         self.hook: Optional[SqsHook] = None
 
-    def poke(self, context: 'Context'):
+    def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
         """
-        Check for message on subscribed queue and write to xcom the message 
with key ``messages``
+        Poll SQS queue to retrieve messages.
 
-        :param context: the context object
-        :return: ``True`` if message is available or ``False``
+        :param sqs_conn: SQS connection
+        :return: A list of messages retrieved from SQS
         """
-        sqs_conn = self.get_hook().get_conn()
-
         self.log.info('SqsSensor checking for message on queue: %s', 
self.sqs_queue)
 
         receive_message_kwargs = {
@@ -126,7 +132,7 @@ class SqsSensor(BaseSensorOperator):
         response = sqs_conn.receive_message(**receive_message_kwargs)
 
         if "Messages" not in response:
-            return False
+            return []
 
         messages = response['Messages']
         num_messages = len(messages)
@@ -136,28 +142,47 @@ class SqsSensor(BaseSensorOperator):
             messages = self.filter_messages(messages)
             num_messages = len(messages)
             self.log.info("There are %d messages left after filtering", 
num_messages)
+        return messages
 
-        if not num_messages:
-            return False
+    def poke(self, context: 'Context'):
+        """
+        Check subscribed queue for messages and write them to xcom with the 
``messages`` key.
 
-        if not self.delete_message_on_reception:
-            context['ti'].xcom_push(key='messages', value=messages)
-            return True
+        :param context: the context object
+        :return: ``True`` if message is available or ``False``
+        """
+        sqs_conn = self.get_hook().get_conn()
 
-        self.log.info("Deleting %d messages", num_messages)
+        message_batch: List[Any] = []
 
-        entries = [
-            {'Id': message['MessageId'], 'ReceiptHandle': 
message['ReceiptHandle']} for message in messages
-        ]
-        response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, 
Entries=entries)
+        # perform multiple SQS calls to retrieve messages in series
+        for _ in range(self.num_batches):
+            messages = self.poll_sqs(sqs_conn=sqs_conn)
 
-        if 'Successful' in response:
-            context['ti'].xcom_push(key='messages', value=messages)
-            return True
-        else:
-            raise AirflowException(
-                'Delete SQS Messages failed ' + str(response) + ' for messages 
' + str(messages)
-            )
+            if not len(messages):
+                continue
+
+            message_batch.extend(messages)
+
+            if self.delete_message_on_reception:
+
+                self.log.info("Deleting %d messages", len(messages))
+
+                entries = [
+                    {'Id': message['MessageId'], 'ReceiptHandle': 
message['ReceiptHandle']}
+                    for message in messages
+                ]
+                response = 
sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+
+                if 'Successful' not in response:
+                    raise AirflowException(
+                        'Delete SQS Messages failed ' + str(response) + ' for 
messages ' + str(messages)
+                    )
+        if not len(message_batch):
+            return False
+
+        context['ti'].xcom_push(key='messages', value=message_batch)
+        return True
 
     def get_hook(self) -> SqsHook:
         """Create and return an SqsHook"""
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py 
b/tests/providers/amazon/aws/sensors/test_sqs.py
index c2a6b3aa2f..2254779535 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -305,3 +305,31 @@ class TestSqsSensor(unittest.TestCase):
         )
         self.sensor.poke(self.mock_context)
         assert mock_conn.delete_message_batch.called is False
+
+    @mock_sqs
+    def test_poke_batch_messages(self):
+        messages = ["hello", "brave", "world"]
+
+        self.sqs_hook.create_queue(QUEUE_NAME)
+        # Do publish 3 messages
+        for message in messages:
+            self.sqs_hook.send_message(queue_url=QUEUE_URL, 
message_body=message)
+
+        # Init batch sensor to get 1 message for each SQS poll
+        # and perform 3 polls
+        self.sensor = SqsSensor(
+            task_id='test_task3',
+            dag=self.dag,
+            sqs_queue=QUEUE_URL,
+            aws_conn_id='aws_default',
+            max_messages=1,
+            num_batches=3,
+        )
+        result = self.sensor.poke(self.mock_context)
+        assert result
+
+        # expect all messages are retrieved
+        for message in messages:
+            assert f"'Body': '{message}'" in str(
+                self.mock_context['ti'].method_calls
+            ), f"context call should contain message '{message}'"
diff --git a/tests/system/providers/amazon/aws/example_sqs.py 
b/tests/system/providers/amazon/aws/example_sqs.py
index 09f697d472..0aabe9e118 100644
--- a/tests/system/providers/amazon/aws/example_sqs.py
+++ b/tests/system/providers/amazon/aws/example_sqs.py
@@ -52,8 +52,13 @@ with DAG(
     sqs_queue = create_queue()
 
     # [START howto_operator_sqs]
-    publish_to_queue = SqsPublishOperator(
-        task_id='publish_to_queue',
+    publish_to_queue_1 = SqsPublishOperator(
+        task_id='publish_to_queue_1',
+        sqs_queue=sqs_queue,
+        message_content='{{ task_instance }}-{{ logical_date }}',
+    )
+    publish_to_queue_2 = SqsPublishOperator(
+        task_id='publish_to_queue_2',
         sqs_queue=sqs_queue,
         message_content='{{ task_instance }}-{{ logical_date }}',
     )
@@ -64,14 +69,26 @@ with DAG(
         task_id='read_from_queue',
         sqs_queue=sqs_queue,
     )
+    # Retrieve multiple batches of messages from SQS.
+    # The SQS API only returns a maximum of 10 messages per poll.
+    read_from_queue_in_batch = SqsSensor(
+        task_id='read_from_queue_in_batch',
+        sqs_queue=sqs_queue,
+        # Get maximum 10 messages each poll
+        max_messages=10,
+        # Combine 3 polls before returning results
+        num_batches=3,
+    )
     # [END howto_sensor_sqs]
 
     chain(
         # TEST SETUP
         sqs_queue,
         # TEST BODY
-        publish_to_queue,
+        publish_to_queue_1,
         read_from_queue,
+        publish_to_queue_2,
+        read_from_queue_in_batch,
         # TEST TEARDOWN
         delete_queue(sqs_queue),
     )

Reply via email to