This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 815655101b Add Deferrable mode to StepFunctionStartExecutionOperator
(#32563)
815655101b is described below
commit 815655101b7457d60be08648e6cd02af30e0d695
Author: Syed Hussain <[email protected]>
AuthorDate: Fri Jul 21 10:53:17 2023 -0700
Add Deferrable mode to StepFunctionStartExecutionOperator (#32563)
* Add stepfunction trigger to provider.yaml
* Update docstring about default value of deferrable
* Change file name from test_stepfunction.py to test_step_function.py to
match other tests
---
.../amazon/aws/operators/step_function.py | 36 ++++++++++++-
.../providers/amazon/aws/triggers/step_function.py | 59 ++++++++++++++++++++++
.../amazon/aws/waiters/stepfunctions.json | 36 +++++++++++++
airflow/providers/amazon/provider.yaml | 3 ++
.../operators/step_functions.rst | 1 +
.../amazon/aws/operators/test_step_function.py | 17 +++++++
.../amazon/aws/triggers/test_step_function.py | 53 +++++++++++++++++++
7 files changed, 203 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/step_function.py
b/airflow/providers/amazon/aws/operators/step_function.py
index 2aa8bdd8e2..68324df731 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -17,11 +17,14 @@
from __future__ import annotations
import json
-from typing import TYPE_CHECKING, Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.triggers.step_function import
StepFunctionsExecutionCompleteTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator):
     :param state_machine_input: JSON data input to pass to the State Machine
     :param aws_conn_id: aws connection to uses
     :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
+    :param waiter_max_attempts: Maximum number of attempts to poll the execution (deferrable mode only).
+    :param waiter_delay: Number of seconds to wait between polls of the execution (deferrable mode only).
+    :param deferrable: If True, the operator will wait asynchronously for the execution to complete
+        (for at most ``waiter_max_attempts * waiter_delay`` seconds). Requires aiobotocore to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
     """
     template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
@@ -56,6 +64,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         state_machine_input: dict | str | None = None,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        waiter_max_attempts: int = 30,  # with the default 60 s delay: 30 minutes max when deferred
+        waiter_delay: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -64,6 +75,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         self.input = state_machine_input
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
     def execute(self, context: Context):
         hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
@@ -74,9 +88,29 @@ class StepFunctionStartExecutionOperator(BaseOperator):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
         self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
-
+        if self.deferrable:
+            # Polling moves to the trigger; the task resumes in execute_complete() once the trigger fires.
+            self.defer(
+                trigger=StepFunctionsExecutionCompleteTrigger(
+                    execution_arn=execution_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            )
         return execution_arn
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        """Return the execution ARN from a successful trigger event; raise AirflowException otherwise."""
+        if event is None or event.get("status") != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+
+        self.log.info("State Machine execution completed successfully")
+        return event["execution_arn"]
+
class StepFunctionGetExecutionOutputOperator(BaseOperator):
"""
diff --git a/airflow/providers/amazon/aws/triggers/step_function.py
b/airflow/providers/amazon/aws/triggers/step_function.py
new file mode 100644
index 0000000000..c4875f078f
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/step_function.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
+class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to poll for the completion of a Step Functions execution.
+
+    :param execution_arn: ARN of the Step Functions execution to poll.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: The AWS region of the execution; used to build the hook.
+    """
+
+    def __init__(
+        self,
+        *,
+        execution_arn: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 30,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+    ) -> None:
+
+        super().__init__(
+            serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
+            waiter_name="step_function_succeeded",
+            waiter_args={"executionArn": execution_arn},
+            failure_message="Step function failed",
+            status_message="Status of step function execution is",
+            status_queries=["status", "error", "cause"],
+            return_key="execution_arn",
+            return_value=execution_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            # hook() reads self.region_name, so the region must be handed to the base class as well.
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        """Build the Step Functions hook used by the waiter, bound to this trigger's connection and region."""
+        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
diff --git a/airflow/providers/amazon/aws/waiters/stepfunctions.json
b/airflow/providers/amazon/aws/waiters/stepfunctions.json
new file mode 100644
index 0000000000..7a7af36786
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/stepfunctions.json
@@ -0,0 +1,42 @@
+{
+    "version": 2,
+    "waiters": {
+        "step_function_succeeded": {
+            "operation": "DescribeExecution",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "SUCCEEDED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "RUNNING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "TIMED_OUT",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "ABORTED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index ff7ed91525..66d331b5cd 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -566,6 +566,9 @@ triggers:
- integration-name: Amazon RDS
python-modules:
- airflow.providers.amazon.aws.triggers.rds
+ - integration-name: AWS Step Functions
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.step_function
transfers:
- source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index 1f207b4576..7736fa9b16 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -38,6 +38,7 @@ Start an AWS Step Functions state machine execution
To start a new AWS Step Functions state machine execution you can use
:class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`.
+You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.
.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_step_functions.py
:language: python
diff --git a/tests/providers/amazon/aws/operators/test_step_function.py
b/tests/providers/amazon/aws/operators/test_step_function.py
index 566e134a86..91ccebf7c6 100644
--- a/tests/providers/amazon/aws/operators/test_step_function.py
+++ b/tests/providers/amazon/aws/operators/test_step_function.py
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
import pytest
+from airflow.exceptions import TaskDeferred
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.providers.amazon.aws.operators.step_function import (
StepFunctionGetExecutionOutputOperator,
StepFunctionStartExecutionOperator,
@@ -132,3 +134,26 @@ class TestStepFunctionStartExecutionOperator:
         # Then
         assert hook_response == result
+
+    @mock.patch.object(StepFunctionHook, "start_execution")
+    def test_step_function_start_execution_deferrable(self, mock_start_execution):
+        mock_start_execution.return_value = "test-execution-arn"
+        operator = StepFunctionStartExecutionOperator(
+            task_id=self.TASK_ID,
+            state_machine_arn=STATE_MACHINE_ARN,
+            name=NAME,
+            state_machine_input=INPUT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred) as deferred:
+            operator.execute(None)
+
+        # The execution is started once, then the task defers to the completion trigger for that
+        # execution, with a timeout of the default 30 attempts * 60 s delay.
+        mock_start_execution.assert_called_once()
+        assert deferred.value.method_name == "execute_complete"
+        assert deferred.value.timeout.total_seconds() == 30 * 60
+        class_path, trigger_kwargs = deferred.value.trigger.serialize()
+        assert class_path.endswith(".StepFunctionsExecutionCompleteTrigger")
+        assert trigger_kwargs["execution_arn"] == "test-execution-arn"
diff --git a/tests/providers/amazon/aws/triggers/test_step_function.py
b/tests/providers/amazon/aws/triggers/test_step_function.py
new file mode 100644
index 0000000000..d0c25e096f
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_step_function.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+
+TEST_EXECUTION_ARN = "test-execution-arn"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-conn-id"
+TEST_REGION_NAME = "test-region-name"
+
+
+class TestStepFunctionsTriggers:
+    @pytest.mark.parametrize(
+        "trigger",
+        [
+            StepFunctionsExecutionCompleteTrigger(
+                execution_arn=TEST_EXECUTION_ARN,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+                region_name=TEST_REGION_NAME,
+            )
+        ],
+    )
+    def test_serialize_recreate(self, trigger):
+        """A trigger rebuilt from its own serialized kwargs must serialize to the same result."""
+        first_pass = trigger.serialize()
+        class_path, kwargs = first_pass
+
+        # Resolve the trigger class by its short name from this module's namespace.
+        trigger_cls = globals()[class_path.rpartition(".")[2]]
+        rebuilt = trigger_cls(**kwargs)
+
+        assert rebuilt.serialize() == first_pass