This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c72befcfd Fix `LambdaInvokeFunctionOperator` payload parameter type (#32259)
5c72befcfd is described below

commit 5c72befcfde63ade2870491cfeb708675399d9d6
Author: Elad Galili <[email protected]>
AuthorDate: Mon Jul 3 09:45:24 2023 +0300

    Fix `LambdaInvokeFunctionOperator` payload parameter type (#32259)
    
    * Fix payload parameter type of Amazon LambdaInvokeFunctionOperator (accept bytes as well as str)
    
    ---------
    
    Co-authored-by: Elad Galili <[email protected]>
---
 .../providers/amazon/aws/hooks/lambda_function.py  |  5 ++++-
 .../amazon/aws/operators/lambda_function.py        |  2 +-
 .../amazon/aws/hooks/test_lambda_function.py       | 11 ++++++++---
 .../amazon/aws/operators/test_lambda_function.py   | 23 +++++++++++++++-------
 4 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py
index 2d61f0751f..58ecac8bcc 100644
--- a/airflow/providers/amazon/aws/hooks/lambda_function.py
+++ b/airflow/providers/amazon/aws/hooks/lambda_function.py
@@ -48,7 +48,7 @@ class LambdaHook(AwsBaseHook):
         invocation_type: str | None = None,
         log_type: str | None = None,
         client_context: str | None = None,
-        payload: str | None = None,
+        payload: bytes | str | None = None,
         qualifier: str | None = None,
     ):
         """
@@ -65,6 +65,9 @@ class LambdaHook(AwsBaseHook):
         :param payload: The JSON that you want to provide to your Lambda function as input.
         :param qualifier: AWS Lambda Function Version or Alias Name
         """
+        if isinstance(payload, str):
+            payload = payload.encode()
+
         invoke_args = {
             "FunctionName": function_name,
             "InvocationType": invocation_type,
diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py
index 93907634c1..28b6313204 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -150,7 +150,7 @@ class LambdaInvokeFunctionOperator(BaseOperator):
         qualifier: str | None = None,
         invocation_type: str | None = None,
         client_context: str | None = None,
-        payload: str | None = None,
+        payload: bytes | str | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
diff --git a/tests/providers/amazon/aws/hooks/test_lambda_function.py b/tests/providers/amazon/aws/hooks/test_lambda_function.py
index f21c000ea6..caaf164be4 100644
--- a/tests/providers/amazon/aws/hooks/test_lambda_function.py
+++ b/tests/providers/amazon/aws/hooks/test_lambda_function.py
@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 
 FUNCTION_NAME = "test_function"
 PAYLOAD = '{"hello": "airflow"}'
+BYTES_PAYLOAD = b'{"hello": "airflow"}'
 RUNTIME = "python3.9"
 ROLE = "role"
 HANDLER = "handler"
@@ -48,13 +49,17 @@ class TestLambdaHook:
     @mock.patch(
         "airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook.conn", new_callable=mock.PropertyMock
     )
-    def test_invoke_lambda(self, mock_conn):
+    @pytest.mark.parametrize(
+        "payload, invoke_payload",
+        [(PAYLOAD, BYTES_PAYLOAD), (BYTES_PAYLOAD, BYTES_PAYLOAD)],
+    )
+    def test_invoke_lambda(self, mock_conn, payload, invoke_payload):
         hook = LambdaHook()
-        hook.invoke_lambda(function_name=FUNCTION_NAME, payload=PAYLOAD)
+        hook.invoke_lambda(function_name=FUNCTION_NAME, payload=payload)
 
         mock_conn().invoke.assert_called_once_with(
             FunctionName=FUNCTION_NAME,
-            Payload=PAYLOAD,
+            Payload=invoke_payload,
         )
 
     @pytest.mark.parametrize(
diff --git a/tests/providers/amazon/aws/operators/test_lambda_function.py b/tests/providers/amazon/aws/operators/test_lambda_function.py
index 6f1d98e8ac..f0b4b834eb 100644
--- a/tests/providers/amazon/aws/operators/test_lambda_function.py
+++ b/tests/providers/amazon/aws/operators/test_lambda_function.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 from unittest.mock import Mock, patch
 
@@ -30,6 +29,8 @@ from airflow.providers.amazon.aws.operators.lambda_function import (
 )
 
 FUNCTION_NAME = "function_name"
+PAYLOAD = '{"hello": "airflow"}'
+BYTES_PAYLOAD = b'{"hello": "airflow"}'
 ROLE_ARN = "role_arn"
 IMAGE_URI = "image_uri"
 
@@ -70,29 +71,37 @@ class TestLambdaCreateFunctionOperator:
 
 
 class TestLambdaInvokeFunctionOperator:
-    def test_init(self):
+    @pytest.mark.parametrize(
+        "payload",
+        [PAYLOAD, BYTES_PAYLOAD],
+    )
+    def test_init(self, payload):
         lambda_operator = LambdaInvokeFunctionOperator(
             task_id="test",
             function_name="test",
-            payload=json.dumps({"TestInput": "Testdata"}),
+            payload=payload,
             log_type="None",
             aws_conn_id="aws_conn_test",
         )
         assert lambda_operator.task_id == "test"
         assert lambda_operator.function_name == "test"
-        assert lambda_operator.payload == json.dumps({"TestInput": "Testdata"})
+        assert lambda_operator.payload == payload
         assert lambda_operator.log_type == "None"
         assert lambda_operator.aws_conn_id == "aws_conn_test"
 
     @patch.object(LambdaInvokeFunctionOperator, "hook", new_callable=mock.PropertyMock)
-    def test_invoke_lambda(self, hook_mock):
+    @pytest.mark.parametrize(
+        "payload",
+        [PAYLOAD, BYTES_PAYLOAD],
+    )
+    def test_invoke_lambda(self, hook_mock, payload):
         operator = LambdaInvokeFunctionOperator(
             task_id="task_test",
             function_name="a",
             invocation_type="b",
             log_type="c",
             client_context="d",
-            payload="e",
+            payload=payload,
             qualifier="f",
         )
         returned_payload = Mock()
@@ -111,7 +120,7 @@ class TestLambdaInvokeFunctionOperator:
             invocation_type="b",
             log_type="c",
             client_context="d",
-            payload="e",
+            payload=payload,
             qualifier="f",
         )
 

Reply via email to