This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6473fac3542 Support HA schedulers for the Lambda Executor (#53396)
6473fac3542 is described below

commit 6473fac354298f6906a7e5aface1cab04bd5a4f9
Author: Niko Oliveira <oniko...@amazon.com>
AuthorDate: Wed Jul 16 10:59:49 2025 -0700

    Support HA schedulers for the Lambda Executor (#53396)
    
    The Lambda executor uses a single shared SQS queue as a results backend
    (of sorts) to receive results from the Lambda invocations. Previously, if
    Lambda Executor A read an event which contained a task started by Lambda
    Executor B, it would delete the message as unrecognized.
    Now properly formatted messages are returned to the queue for the correct
    executor to pull (ill-formatted messages are still deleted). UAT testing
    has shown that this simple solution actually stabilizes quite quickly,
    especially since executors with no running tasks do not query the queue.
    If we see any further scaling issues in the future we can revisit this
    with a more complex solution.
---
 .../aws/executors/aws_lambda/lambda_executor.py    |  43 +++++++--
 .../executors/aws_lambda/test_lambda_executor.py   | 105 ++++++++++++++++++++-
 2 files changed, 137 insertions(+), 11 deletions(-)

diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
index 2e338789362..e4676c915e8 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -362,26 +362,57 @@ class AwsLambdaExecutor(BaseExecutor):
             MaxNumberOfMessages=10,
         )
 
+        # Pagination? Maybe we don't need it. But we don't always delete 
messages after viewing them so we
+        # could possibly accumulate a lot of messages in the queue and get 
stuck if we don't read bigger
+        # chunks and paginate.
         messages = response.get("Messages", [])
-        # Pagination? Maybe we don't need it. Since we always delete messages 
after looking at them.
-        # But then that may delete messages that could have been adopted. 
Let's leave it for now and see how it goes.
+        # The keys that we validate in the messages below will be different 
depending on whether or not
+        # the message is from the dead letter queue or the main results queue.
+        message_keys = ("return_code", "task_key")
         if messages and queue_url == self.dlq_url:
             self.log.warning("%d messages received from the dead letter 
queue", len(messages))
+            message_keys = ("command", "task_key")
 
         for message in messages:
+            delete_message = False
             receipt_handle = message["ReceiptHandle"]
-            body = json.loads(message["Body"])
+            try:
+                body = json.loads(message["Body"])
+            except json.JSONDecodeError:
+                self.log.warning(
+                    "Received a message from the queue that could not be 
parsed as JSON: %s",
+                    message["Body"],
+                )
+                delete_message = True
+            # If the message is not already marked for deletion, check if it 
has the required keys.
+            if not delete_message and not all(key in body for key in 
message_keys):
+                self.log.warning(
+                    "Message is not formatted correctly, %s and/or %s are 
missing: %s", *message_keys, body
+                )
+                delete_message = True
+            if delete_message:
+                self.log.warning("Deleting the message to avoid processing it 
again.")
+                self.sqs_client.delete_message(QueueUrl=queue_url, 
ReceiptHandle=receipt_handle)
+                continue
             return_code = body.get("return_code")
             ser_task_key = body.get("task_key")
             # Fetch the real task key from the running_tasks dict, using the 
serialized task key.
             try:
                 task_key = self.running_tasks[ser_task_key]
             except KeyError:
-                self.log.warning(
-                    "Received task %s from the queue which is not found in 
running tasks. Removing message.",
+                self.log.debug(
+                    "Received task %s from the queue which is not found in 
running tasks, it is likely "
+                    "from another Lambda Executor sharing this queue or might 
be a stale message that needs "
+                    "deleting manually. Marking the message as visible again.",
                     ser_task_key,
                 )
-                task_key = None
+                # Mark task as visible again in SQS so that another executor 
can pick it up.
+                self.sqs_client.change_message_visibility(
+                    QueueUrl=queue_url,
+                    ReceiptHandle=receipt_handle,
+                    VisibilityTimeout=0,
+                )
+                continue
 
             if task_key:
                 if return_code == 0:
diff --git 
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
 
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
index f1123ca5e07..a2def0b86a8 100644
--- 
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
+++ 
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
@@ -395,7 +395,7 @@ class TestAwsLambdaExecutor:
         mock_executor.running_tasks.clear()
         mock_executor.running_tasks[ser_airflow_key] = airflow_key
         mock_executor.sqs_client.receive_message.side_effect = [
-            {},  # First request from the results queue will be empt
+            {},  # First request from the results queue will be empty
             {
                 # Second request from the DLQ will have a message
                 "Messages": [
@@ -510,6 +510,87 @@ class TestAwsLambdaExecutor:
         fail_mock.assert_called_once()
         assert mock_executor.sqs_client.delete_message.call_count == 1
 
+    def test_sync_running_fail_bad_json(self, mock_executor, mock_airflow_key):
+        airflow_key = mock_airflow_key()
+        ser_airflow_key = json.dumps(airflow_key._asdict())
+
+        mock_executor.running_tasks.clear()
+        mock_executor.running_tasks[ser_airflow_key] = airflow_key
+        mock_executor.sqs_client.receive_message.side_effect = [
+            {
+                "Messages": [
+                    {
+                        "ReceiptHandle": "receipt_handle",
+                        "Body": "Banana",  # Body not json format
+                    }
+                ]
+            },
+            {},  # Second request from the DLQ will be empty
+        ]
+
+        mock_executor.sync_running_tasks()
+        # Assert that the message is deleted if the message is not formatted 
as json
+        assert mock_executor.sqs_client.receive_message.call_count == 2
+        assert mock_executor.sqs_client.delete_message.call_count == 1
+
+    def test_sync_running_fail_bad_format(self, mock_executor, 
mock_airflow_key):
+        airflow_key = mock_airflow_key()
+        ser_airflow_key = json.dumps(airflow_key._asdict())
+
+        mock_executor.running_tasks.clear()
+        mock_executor.running_tasks[ser_airflow_key] = airflow_key
+        mock_executor.sqs_client.receive_message.side_effect = [
+            {
+                "Messages": [
+                    {
+                        "ReceiptHandle": "receipt_handle",
+                        "Body": json.dumps(
+                            {
+                                "foo": "bar",  # Missing expected keys like 
"task_key"
+                                "return_code": 1,  # Non-zero return code, 
task failed
+                            }
+                        ),
+                    }
+                ]
+            },
+            {},  # Second request from the DLQ will be empty
+        ]
+
+        mock_executor.sync_running_tasks()
+        # Assert that the message is deleted if the message does not contain 
the expected keys
+        assert mock_executor.sqs_client.receive_message.call_count == 2
+        assert mock_executor.sqs_client.delete_message.call_count == 1
+
+    def test_sync_running_fail_bad_format_dlq(self, mock_executor, 
mock_airflow_key):
+        airflow_key = mock_airflow_key()
+        ser_airflow_key = json.dumps(airflow_key._asdict())
+
+        mock_executor.running_tasks.clear()
+        mock_executor.running_tasks[ser_airflow_key] = airflow_key
+        # Failure message
+        mock_executor.sqs_client.receive_message.side_effect = [
+            {},  # First request from the results queue will be empty
+            {
+                # Second request from the DLQ will have a message
+                "Messages": [
+                    {
+                        "ReceiptHandle": "receipt_handle",
+                        "Body": json.dumps(
+                            {
+                                "foo": "bar",  # Missing expected keys like 
"task_key"
+                                "return_code": 1,
+                            }
+                        ),
+                    }
+                ]
+            },
+        ]
+
+        mock_executor.sync_running_tasks()
+        # Assert that the message is deleted if the message does not contain 
the expected keys
+        assert mock_executor.sqs_client.receive_message.call_count == 2
+        assert mock_executor.sqs_client.delete_message.call_count == 1
+
     @mock.patch.object(BaseExecutor, "fail")
     @mock.patch.object(BaseExecutor, "success")
     def test_sync_running_short_circuit(self, success_mock, fail_mock, 
mock_executor, mock_airflow_key):
@@ -605,10 +686,12 @@ class TestAwsLambdaExecutor:
         mock_executor.running_tasks[ser_airflow_key] = airflow_key
 
         # Receive the known task and unknown task
+        known_task_receipt = "receipt_handle_known"
+        unknown_task_receipt = "receipt_handle_unknown"
         mock_executor.sqs_client.receive_message.return_value = {
             "Messages": [
                 {
-                    "ReceiptHandle": "receipt_handle",
+                    "ReceiptHandle": known_task_receipt,
                     "Body": json.dumps(
                         {
                             "task_key": ser_airflow_key,
@@ -617,7 +700,7 @@ class TestAwsLambdaExecutor:
                     ),
                 },
                 {
-                    "ReceiptHandle": "receipt_handle",
+                    "ReceiptHandle": unknown_task_receipt,
                     "Body": json.dumps(
                         {
                             "task_key": ser_airflow_key_2,
@@ -635,8 +718,20 @@ class TestAwsLambdaExecutor:
         assert len(mock_executor.running_tasks) == 0
         success_mock.assert_called_once()
         fail_mock.assert_not_called()
-        # Both messages from the queue should be deleted, both known and 
unknown
-        assert mock_executor.sqs_client.delete_message.call_count == 2
+        # Only the known message from the queue should be deleted, the other 
should be marked as visible again
+        assert mock_executor.sqs_client.delete_message.call_count == 1
+        assert mock_executor.sqs_client.change_message_visibility.call_count 
== 1
+        # The argument to delete_message should be the known task
+        assert 
mock_executor.sqs_client.delete_message.call_args_list[0].kwargs == {
+            "QueueUrl": DEFAULT_QUEUE_URL,
+            "ReceiptHandle": known_task_receipt,
+        }
+        # The change_message_visibility should be called with the unknown task
+        assert 
mock_executor.sqs_client.change_message_visibility.call_args_list[0].kwargs == {
+            "QueueUrl": DEFAULT_QUEUE_URL,
+            "ReceiptHandle": unknown_task_receipt,
+            "VisibilityTimeout": 0,
+        }
 
     def test_start_no_check_health(self, mock_executor):
         mock_executor.check_health = mock.Mock()

Reply via email to