This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d99705f69 Add `deferrable` option to `LambdaCreateFunctionOperator` 
(#33327)
4d99705f69 is described below

commit 4d99705f69114d37bb3e85d7723602e71bd023c1
Author: Vincent <[email protected]>
AuthorDate: Mon Aug 14 17:31:13 2023 -0400

    Add `deferrable` option to `LambdaCreateFunctionOperator` (#33327)
---
 .../amazon/aws/operators/lambda_function.py        | 36 ++++++++++++++-
 airflow/providers/amazon/aws/triggers/athena.py    |  2 +-
 airflow/providers/amazon/aws/triggers/base.py      |  1 -
 .../aws/triggers/{athena.py => lambda_function.py} | 46 ++++++++++--------
 airflow/providers/amazon/provider.yaml             |  3 ++
 .../operators/lambda.rst                           |  2 +
 .../amazon/aws/operators/test_lambda_function.py   | 15 ++++++
 .../amazon/aws/triggers/test_lambda_function.py    | 54 ++++++++++++++++++++++
 8 files changed, 136 insertions(+), 23 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py 
b/airflow/providers/amazon/aws/operators/lambda_function.py
index e472deaaac..5d7e980bb5 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -18,11 +18,15 @@
 from __future__ import annotations
 
 import json
+from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow import AirflowException
+from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.triggers.lambda_function import 
LambdaCreateFunctionCompleteTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -50,6 +54,11 @@ class LambdaCreateFunctionOperator(BaseOperator):
     :param timeout: The amount of time (in seconds) that Lambda allows a 
function to run before stopping it.
     :param config: Optional dictionary for arbitrary parameters to the boto 
API create_lambda call.
     :param wait_for_completion: If True, the operator will wait until the 
function is active.
+    :param waiter_max_attempts: Maximum number of attempts to poll the 
creation.
+    :param waiter_delay: Number of seconds between polling the state of the 
creation.
+    :param deferrable: If True, the operator will wait asynchronously for the 
creation to complete.
+        This implies waiting for creation complete. This mode requires 
aiobotocore module to be installed.
+        (default: False, but can be overridden in config file by setting 
default_deferrable to True)
     :param aws_conn_id: The AWS connection ID to use
     """
 
@@ -75,6 +84,9 @@ class LambdaCreateFunctionOperator(BaseOperator):
         timeout: int | None = None,
         config: dict = {},
         wait_for_completion: bool = False,
+        waiter_max_attempts: int = 60,
+        waiter_delay: int = 15,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
@@ -88,6 +100,9 @@ class LambdaCreateFunctionOperator(BaseOperator):
         self.timeout = timeout
         self.config = config
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
         self.aws_conn_id = aws_conn_id
 
     @cached_property
@@ -108,6 +123,18 @@ class LambdaCreateFunctionOperator(BaseOperator):
         )
         self.log.info("Lambda response: %r", response)
 
+        if self.deferrable:
+            self.defer(
+                trigger=LambdaCreateFunctionCompleteTrigger(
+                    function_name=self.function_name,
+                    function_arn=response["FunctionArn"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay),
+            )
         if self.wait_for_completion:
             self.log.info("Wait for Lambda function to be active")
             waiter = self.hook.conn.get_waiter("function_active_v2")
@@ -117,6 +144,13 @@ class LambdaCreateFunctionOperator(BaseOperator):
 
         return response.get("FunctionArn")
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
+        if not event or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+
+        self.log.info("Lambda function created successfully")
+        return event["function_arn"]
+
 
 class LambdaInvokeFunctionOperator(BaseOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/athena.py 
b/airflow/providers/amazon/aws/triggers/athena.py
index 636c135059..fcd09dae19 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -23,7 +23,7 @@ from airflow.providers.amazon.aws.triggers.base import 
AwsBaseWaiterTrigger
 
 class AthenaTrigger(AwsBaseWaiterTrigger):
     """
-    Trigger for RedshiftCreateClusterOperator.
+    Trigger for AthenaOperator.
 
     The trigger will asynchronously poll the boto3 API and wait for the
     Redshift cluster to be in the `available` state.
diff --git a/airflow/providers/amazon/aws/triggers/base.py 
b/airflow/providers/amazon/aws/triggers/base.py
index 41f7d2dc33..d2d664d97f 100644
--- a/airflow/providers/amazon/aws/triggers/base.py
+++ b/airflow/providers/amazon/aws/triggers/base.py
@@ -112,7 +112,6 @@ class AwsBaseWaiterTrigger(BaseTrigger):
     @abstractmethod
     def hook(self) -> AwsGenericHook:
         """Override in subclasses to return the right hook."""
-        ...
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
diff --git a/airflow/providers/amazon/aws/triggers/athena.py 
b/airflow/providers/amazon/aws/triggers/lambda_function.py
similarity index 55%
copy from airflow/providers/amazon/aws/triggers/athena.py
copy to airflow/providers/amazon/aws/triggers/lambda_function.py
index 636c135059..f0f6a40551 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/lambda_function.py
@@ -16,19 +16,17 @@
 # under the License.
 from __future__ import annotations
 
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 
 
-class AthenaTrigger(AwsBaseWaiterTrigger):
+class LambdaCreateFunctionCompleteTrigger(AwsBaseWaiterTrigger):
     """
-    Trigger for RedshiftCreateClusterOperator.
+    Trigger to poll for the completion of a Lambda function creation.
 
-    The trigger will asynchronously poll the boto3 API and wait for the
-    Redshift cluster to be in the `available` state.
-
-    :param query_execution_id:  ID of the Athena query execution to watch
+    :param function_name: The function name
+    :param function_arn: The function ARN
     :param waiter_delay: The amount of time in seconds to wait between 
attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
@@ -36,23 +34,31 @@ class AthenaTrigger(AwsBaseWaiterTrigger):
 
     def __init__(
         self,
-        query_execution_id: str,
-        waiter_delay: int,
-        waiter_max_attempts: int,
-        aws_conn_id: str,
-    ):
+        *,
+        function_name: str,
+        function_arn: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 30,
+        aws_conn_id: str | None = None,
+    ) -> None:
+
         super().__init__(
-            serialized_fields={"query_execution_id": query_execution_id},
-            waiter_name="query_complete",
-            waiter_args={"QueryExecutionId": query_execution_id},
-            failure_message=f"Error while waiting for query 
{query_execution_id} to complete",
-            status_message=f"Query execution id: {query_execution_id}",
-            status_queries=["QueryExecution.Status"],
-            return_value=query_execution_id,
+            serialized_fields={"function_name": function_name, "function_arn": 
function_arn},
+            waiter_name="function_active_v2",
+            waiter_args={"FunctionName": function_name},
+            failure_message="Lambda function creation failed",
+            status_message="Status of Lambda function creation is",
+            status_queries=[
+                "Configuration.LastUpdateStatus",
+                "Configuration.LastUpdateStatusReason",
+                "Configuration.LastUpdateStatusReasonCode",
+            ],
+            return_key="function_arn",
+            return_value=function_arn,
             waiter_delay=waiter_delay,
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
 
     def hook(self) -> AwsGenericHook:
-        return AthenaHook(self.aws_conn_id)
+        return LambdaHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/provider.yaml 
b/airflow/providers/amazon/provider.yaml
index 284cc429fa..6aa3282510 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -544,6 +544,9 @@ triggers:
   - integration-name: Amazon EC2
     python-modules:
       - airflow.providers.amazon.aws.triggers.ec2
+  - integration-name: AWS Lambda
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.lambda_function
   - integration-name: Amazon Redshift
     python-modules:
       - airflow.providers.amazon.aws.triggers.redshift_cluster
diff --git a/docs/apache-airflow-providers-amazon/operators/lambda.rst 
b/docs/apache-airflow-providers-amazon/operators/lambda.rst
index 79649f106f..0149b5b620 100644
--- a/docs/apache-airflow-providers-amazon/operators/lambda.rst
+++ b/docs/apache-airflow-providers-amazon/operators/lambda.rst
@@ -40,6 +40,8 @@ Create an AWS Lambda function
 
 To create an AWS lambda function you can use
 
:class:`~airflow.providers.amazon.aws.operators.lambda_function.LambdaCreateFunctionOperator`.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as 
a parameter. This requires
+the aiobotocore module to be installed.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_lambda.py
     :language: python
diff --git a/tests/providers/amazon/aws/operators/test_lambda_function.py 
b/tests/providers/amazon/aws/operators/test_lambda_function.py
index f0b4b834eb..6fc3a3b64f 100644
--- a/tests/providers/amazon/aws/operators/test_lambda_function.py
+++ b/tests/providers/amazon/aws/operators/test_lambda_function.py
@@ -22,6 +22,7 @@ from unittest.mock import Mock, patch
 
 import pytest
 
+from airflow.exceptions import TaskDeferred
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.operators.lambda_function import (
     LambdaCreateFunctionOperator,
@@ -69,6 +70,20 @@ class TestLambdaCreateFunctionOperator:
         mock_hook_create_lambda.assert_called_once()
         mock_hook_conn.get_waiter.assert_called_once_with("function_active_v2")
 
+    @mock.patch.object(LambdaHook, "create_lambda")
+    def test_create_lambda_deferrable(self, _):
+        operator = LambdaCreateFunctionOperator(
+            task_id="task_test",
+            function_name=FUNCTION_NAME,
+            role=ROLE_ARN,
+            code={
+                "ImageUri": IMAGE_URI,
+            },
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred):
+            operator.execute(None)
+
 
 class TestLambdaInvokeFunctionOperator:
     @pytest.mark.parametrize(
diff --git a/tests/providers/amazon/aws/triggers/test_lambda_function.py 
b/tests/providers/amazon/aws/triggers/test_lambda_function.py
new file mode 100644
index 0000000000..c06a99d42e
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_lambda_function.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.lambda_function import 
LambdaCreateFunctionCompleteTrigger
+
+TEST_FUNCTION_NAME = "test-function-name"
+TEST_FUNCTION_ARN = "test-function-arn"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-conn-id"
+TEST_REGION_NAME = "test-region-name"
+
+
+class TestLambdaFunctionTriggers:
+    @pytest.mark.parametrize(
+        "trigger",
+        [
+            LambdaCreateFunctionCompleteTrigger(
+                function_name=TEST_FUNCTION_NAME,
+                function_arn=TEST_FUNCTION_ARN,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            )
+        ],
+    )
+    def test_serialize_recreate(self, trigger):
+        class_path, args = trigger.serialize()
+
+        class_name = class_path.split(".")[-1]
+        clazz = globals()[class_name]
+        instance = clazz(**args)
+
+        class_path2, args2 = instance.serialize()
+
+        assert class_path == class_path2
+        assert args == args2

Reply via email to the commits mailing list.