This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4d99705f69 Add `deferrable` option to `LambdaCreateFunctionOperator`
(#33327)
4d99705f69 is described below
commit 4d99705f69114d37bb3e85d7723602e71bd023c1
Author: Vincent <[email protected]>
AuthorDate: Mon Aug 14 17:31:13 2023 -0400
Add `deferrable` option to `LambdaCreateFunctionOperator` (#33327)
---
.../amazon/aws/operators/lambda_function.py | 36 ++++++++++++++-
airflow/providers/amazon/aws/triggers/athena.py | 2 +-
airflow/providers/amazon/aws/triggers/base.py | 1 -
.../aws/triggers/{athena.py => lambda_function.py} | 46 ++++++++++--------
airflow/providers/amazon/provider.yaml | 3 ++
.../operators/lambda.rst | 2 +
.../amazon/aws/operators/test_lambda_function.py | 15 ++++++
.../amazon/aws/triggers/test_lambda_function.py | 54 ++++++++++++++++++++++
8 files changed, 136 insertions(+), 23 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py
b/airflow/providers/amazon/aws/operators/lambda_function.py
index e472deaaac..5d7e980bb5 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -18,11 +18,15 @@
from __future__ import annotations
import json
+from datetime import timedelta
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow import AirflowException
+from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.triggers.lambda_function import
LambdaCreateFunctionCompleteTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -50,6 +54,11 @@ class LambdaCreateFunctionOperator(BaseOperator):
:param timeout: The amount of time (in seconds) that Lambda allows a
function to run before stopping it.
:param config: Optional dictionary for arbitrary parameters to the boto
API create_lambda call.
:param wait_for_completion: If True, the operator will wait until the
function is active.
+ :param waiter_max_attempts: Maximum number of attempts to poll the
creation.
+ :param waiter_delay: Number of seconds between polling the state of the
creation.
+ :param deferrable: If True, the operator will wait asynchronously for the
creation to complete.
+ This implies waiting for completion. This mode requires the
aiobotocore module to be installed.
+ (default: False, but can be overridden in the config file by setting
default_deferrable to True)
:param aws_conn_id: The AWS connection ID to use
"""
@@ -75,6 +84,9 @@ class LambdaCreateFunctionOperator(BaseOperator):
timeout: int | None = None,
config: dict = {},
wait_for_completion: bool = False,
+ waiter_max_attempts: int = 60,
+ waiter_delay: int = 15,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -88,6 +100,9 @@ class LambdaCreateFunctionOperator(BaseOperator):
self.timeout = timeout
self.config = config
self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
self.aws_conn_id = aws_conn_id
@cached_property
@@ -108,6 +123,18 @@ class LambdaCreateFunctionOperator(BaseOperator):
)
self.log.info("Lambda response: %r", response)
+ if self.deferrable:
+ self.defer(
+ trigger=LambdaCreateFunctionCompleteTrigger(
+ function_name=self.function_name,
+ function_arn=response["FunctionArn"],
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
+ )
if self.wait_for_completion:
self.log.info("Wait for Lambda function to be active")
waiter = self.hook.conn.get_waiter("function_active_v2")
@@ -117,6 +144,13 @@ class LambdaCreateFunctionOperator(BaseOperator):
return response.get("FunctionArn")
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ if not event or event["status"] != "success":
+ raise AirflowException(f"Trigger error: event is {event}")
+
+ self.log.info("Lambda function created successfully")
+ return event["function_arn"]
+
class LambdaInvokeFunctionOperator(BaseOperator):
"""
diff --git a/airflow/providers/amazon/aws/triggers/athena.py
b/airflow/providers/amazon/aws/triggers/athena.py
index 636c135059..fcd09dae19 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -23,7 +23,7 @@ from airflow.providers.amazon.aws.triggers.base import
AwsBaseWaiterTrigger
class AthenaTrigger(AwsBaseWaiterTrigger):
"""
- Trigger for RedshiftCreateClusterOperator.
+ Trigger for AthenaOperator.
The trigger will asynchronously poll the boto3 API and wait for the
Redshift cluster to be in the `available` state.
diff --git a/airflow/providers/amazon/aws/triggers/base.py
b/airflow/providers/amazon/aws/triggers/base.py
index 41f7d2dc33..d2d664d97f 100644
--- a/airflow/providers/amazon/aws/triggers/base.py
+++ b/airflow/providers/amazon/aws/triggers/base.py
@@ -112,7 +112,6 @@ class AwsBaseWaiterTrigger(BaseTrigger):
@abstractmethod
def hook(self) -> AwsGenericHook:
"""Override in subclasses to return the right hook."""
- ...
async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self.hook()
diff --git a/airflow/providers/amazon/aws/triggers/athena.py
b/airflow/providers/amazon/aws/triggers/lambda_function.py
similarity index 55%
copy from airflow/providers/amazon/aws/triggers/athena.py
copy to airflow/providers/amazon/aws/triggers/lambda_function.py
index 636c135059..f0f6a40551 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/lambda_function.py
@@ -16,19 +16,17 @@
# under the License.
from __future__ import annotations
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
-class AthenaTrigger(AwsBaseWaiterTrigger):
+class LambdaCreateFunctionCompleteTrigger(AwsBaseWaiterTrigger):
"""
- Trigger for RedshiftCreateClusterOperator.
+ Trigger to poll for the completion of a Lambda function creation.
- The trigger will asynchronously poll the boto3 API and wait for the
- Redshift cluster to be in the `available` state.
-
- :param query_execution_id: ID of the Athena query execution to watch
+ :param function_name: The function name
+ :param function_arn: The function ARN
:param waiter_delay: The amount of time in seconds to wait between
attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
@@ -36,23 +34,31 @@ class AthenaTrigger(AwsBaseWaiterTrigger):
def __init__(
self,
- query_execution_id: str,
- waiter_delay: int,
- waiter_max_attempts: int,
- aws_conn_id: str,
- ):
+ *,
+ function_name: str,
+ function_arn: str,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 30,
+ aws_conn_id: str | None = None,
+ ) -> None:
+
super().__init__(
- serialized_fields={"query_execution_id": query_execution_id},
- waiter_name="query_complete",
- waiter_args={"QueryExecutionId": query_execution_id},
- failure_message=f"Error while waiting for query
{query_execution_id} to complete",
- status_message=f"Query execution id: {query_execution_id}",
- status_queries=["QueryExecution.Status"],
- return_value=query_execution_id,
+ serialized_fields={"function_name": function_name, "function_arn":
function_arn},
+ waiter_name="function_active_v2",
+ waiter_args={"FunctionName": function_name},
+ failure_message="Lambda function creation failed",
+ status_message="Status of Lambda function creation is",
+ status_queries=[
+ "Configuration.LastUpdateStatus",
+ "Configuration.LastUpdateStatusReason",
+ "Configuration.LastUpdateStatusReasonCode",
+ ],
+ return_key="function_arn",
+ return_value=function_arn,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
def hook(self) -> AwsGenericHook:
- return AthenaHook(self.aws_conn_id)
+ return LambdaHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index 284cc429fa..6aa3282510 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -544,6 +544,9 @@ triggers:
- integration-name: Amazon EC2
python-modules:
- airflow.providers.amazon.aws.triggers.ec2
+ - integration-name: AWS Lambda
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.lambda_function
- integration-name: Amazon Redshift
python-modules:
- airflow.providers.amazon.aws.triggers.redshift_cluster
diff --git a/docs/apache-airflow-providers-amazon/operators/lambda.rst
b/docs/apache-airflow-providers-amazon/operators/lambda.rst
index 79649f106f..0149b5b620 100644
--- a/docs/apache-airflow-providers-amazon/operators/lambda.rst
+++ b/docs/apache-airflow-providers-amazon/operators/lambda.rst
@@ -40,6 +40,8 @@ Create an AWS Lambda function
To create an AWS lambda function you can use
:class:`~airflow.providers.amazon.aws.operators.lambda_function.LambdaCreateFunctionOperator`.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter. This requires
+the aiobotocore module to be installed.
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_lambda.py
:language: python
diff --git a/tests/providers/amazon/aws/operators/test_lambda_function.py
b/tests/providers/amazon/aws/operators/test_lambda_function.py
index f0b4b834eb..6fc3a3b64f 100644
--- a/tests/providers/amazon/aws/operators/test_lambda_function.py
+++ b/tests/providers/amazon/aws/operators/test_lambda_function.py
@@ -22,6 +22,7 @@ from unittest.mock import Mock, patch
import pytest
+from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.operators.lambda_function import (
LambdaCreateFunctionOperator,
@@ -69,6 +70,20 @@ class TestLambdaCreateFunctionOperator:
mock_hook_create_lambda.assert_called_once()
mock_hook_conn.get_waiter.assert_called_once_with("function_active_v2")
+ @mock.patch.object(LambdaHook, "create_lambda")
+ def test_create_lambda_deferrable(self, _):
+ operator = LambdaCreateFunctionOperator(
+ task_id="task_test",
+ function_name=FUNCTION_NAME,
+ role=ROLE_ARN,
+ code={
+ "ImageUri": IMAGE_URI,
+ },
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred):
+ operator.execute(None)
+
class TestLambdaInvokeFunctionOperator:
@pytest.mark.parametrize(
diff --git a/tests/providers/amazon/aws/triggers/test_lambda_function.py
b/tests/providers/amazon/aws/triggers/test_lambda_function.py
new file mode 100644
index 0000000000..c06a99d42e
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_lambda_function.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.lambda_function import
LambdaCreateFunctionCompleteTrigger
+
+TEST_FUNCTION_NAME = "test-function-name"
+TEST_FUNCTION_ARN = "test-function-arn"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-conn-id"
+TEST_REGION_NAME = "test-region-name"
+
+
+class TestLambdaFunctionTriggers:
+ @pytest.mark.parametrize(
+ "trigger",
+ [
+ LambdaCreateFunctionCompleteTrigger(
+ function_name=TEST_FUNCTION_NAME,
+ function_arn=TEST_FUNCTION_ARN,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ )
+ ],
+ )
+ def test_serialize_recreate(self, trigger):
+ class_path, args = trigger.serialize()
+
+ class_name = class_path.split(".")[-1]
+ clazz = globals()[class_name]
+ instance = clazz(**args)
+
+ class_path2, args2 = instance.serialize()
+
+ assert class_path == class_path2
+ assert args == args2