This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 815655101b Add Deferrable mode to StepFunctionStartExecutionOperator 
(#32563)
815655101b is described below

commit 815655101b7457d60be08648e6cd02af30e0d695
Author: Syed Hussain <[email protected]>
AuthorDate: Fri Jul 21 10:53:17 2023 -0700

    Add Deferrable mode to StepFunctionStartExecutionOperator (#32563)
    
    * Add stepfunction trigger to provider.yaml
    * Update docstring about default value of deferrable
    * Change file name from test_stepfunction.py to test_step_function.py to 
match other tests
---
 .../amazon/aws/operators/step_function.py          | 36 ++++++++++++-
 .../providers/amazon/aws/triggers/step_function.py | 59 ++++++++++++++++++++++
 .../amazon/aws/waiters/stepfunctions.json          | 36 +++++++++++++
 airflow/providers/amazon/provider.yaml             |  3 ++
 .../operators/step_functions.rst                   |  1 +
 .../amazon/aws/operators/test_step_function.py     | 17 +++++++
 .../amazon/aws/triggers/test_step_function.py      | 53 +++++++++++++++++++
 7 files changed, 203 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/step_function.py 
b/airflow/providers/amazon/aws/operators/step_function.py
index 2aa8bdd8e2..68324df731 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -17,11 +17,14 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.triggers.step_function import 
StepFunctionsExecutionCompleteTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator):
     :param state_machine_input: JSON data input to pass to the State Machine
     :param aws_conn_id: aws connection to use
     :param do_xcom_push: if True, execution_arn is pushed to XCom with key 
execution_arn.
+    :param waiter_max_attempts: Maximum number of attempts to poll the 
execution.
+    :param waiter_delay: Number of seconds between polling the state of the 
execution.
+    :param deferrable: If True, the operator will wait asynchronously for the 
job to complete.
+        This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
+        (default: False, but can be overridden in config file by setting 
default_deferrable to True)
     """
 
     template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
@@ -56,6 +64,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         state_machine_input: dict | str | None = None,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        waiter_max_attempts: int = 30,
+        waiter_delay: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -64,6 +75,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         self.input = state_machine_input
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
@@ -74,9 +88,27 @@ class StepFunctionStartExecutionOperator(BaseOperator):
             raise AirflowException(f"Failed to start State Machine execution 
for: {self.state_machine_arn}")
 
         self.log.info("Started State Machine execution for %s: %s", 
self.state_machine_arn, execution_arn)
-
+        if self.deferrable:
+            self.defer(
+                trigger=StepFunctionsExecutionCompleteTrigger(
+                    execution_arn=execution_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay),
+            )
         return execution_arn
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+
+        self.log.info("State Machine execution completed successfully")
+        return event["execution_arn"]
+
 
 class StepFunctionGetExecutionOutputOperator(BaseOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/step_function.py 
b/airflow/providers/amazon/aws/triggers/step_function.py
new file mode 100644
index 0000000000..c4875f078f
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/step_function.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
    """
    Trigger to poll for the completion of a Step Functions execution.

    :param execution_arn: ARN of the state machine execution to poll.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: AWS region where the execution is running; serialized so
        the triggerer rebuilds the hook against the same region.
    """

    def __init__(
        self,
        *,
        execution_arn: str,
        waiter_delay: int = 60,
        waiter_max_attempts: int = 30,
        aws_conn_id: str | None = None,
        region_name: str | None = None,
    ) -> None:
        super().__init__(
            # Fields that must round-trip through serialize() in addition to
            # the waiter parameters handled by the base class.
            serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
            waiter_name="step_function_succeeded",
            waiter_args={"executionArn": execution_arn},
            failure_message="Step function failed",
            status_message="Status of step function execution is",
            status_queries=["status", "error", "cause"],
            # The success event exposes the ARN under the "execution_arn" key,
            # which the operator's execute_complete() reads back.
            return_key="execution_arn",
            return_value=execution_arn,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return a StepFunctionHook bound to this trigger's connection and region."""
        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
diff --git a/airflow/providers/amazon/aws/waiters/stepfunctions.json 
b/airflow/providers/amazon/aws/waiters/stepfunctions.json
new file mode 100644
index 0000000000..7a7af36786
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/stepfunctions.json
@@ -0,0 +1,36 @@
+{
+    "version": 2,
+    "waiters": {
+        "step_function_succeeded": {
+            "operation": "DescribeExecution",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "SUCCEEDED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "RUNNING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "ABORTED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
diff --git a/airflow/providers/amazon/provider.yaml 
b/airflow/providers/amazon/provider.yaml
index ff7ed91525..66d331b5cd 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -566,6 +566,9 @@ triggers:
   - integration-name: Amazon RDS
     python-modules:
       - airflow.providers.amazon.aws.triggers.rds
+  - integration-name: AWS Step Functions
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.step_function
 
 transfers:
   - source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst 
b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index 1f207b4576..7736fa9b16 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -38,6 +38,7 @@ Start an AWS Step Functions state machine execution
 
 To start a new AWS Step Functions state machine execution you can use
 
:class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`.
+You can also run this operator in deferrable mode by setting ``deferrable`` 
param to ``True``.
 
 .. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
diff --git a/tests/providers/amazon/aws/operators/test_step_function.py 
b/tests/providers/amazon/aws/operators/test_step_function.py
index 566e134a86..91ccebf7c6 100644
--- a/tests/providers/amazon/aws/operators/test_step_function.py
+++ b/tests/providers/amazon/aws/operators/test_step_function.py
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
 
 import pytest
 
+from airflow.exceptions import TaskDeferred
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
 from airflow.providers.amazon.aws.operators.step_function import (
     StepFunctionGetExecutionOutputOperator,
     StepFunctionStartExecutionOperator,
@@ -132,3 +134,18 @@ class TestStepFunctionStartExecutionOperator:
 
         # Then
         assert hook_response == result
+
+    @mock.patch.object(StepFunctionHook, "start_execution")
+    def test_step_function_start_execution_deferrable(self, 
mock_start_execution):
+        mock_start_execution.return_value = "test-execution-arn"
+        operator = StepFunctionStartExecutionOperator(
+            task_id=self.TASK_ID,
+            state_machine_arn=STATE_MACHINE_ARN,
+            name=NAME,
+            state_machine_input=INPUT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred):
+            operator.execute(None)
diff --git a/tests/providers/amazon/aws/triggers/test_step_function.py 
b/tests/providers/amazon/aws/triggers/test_step_function.py
new file mode 100644
index 0000000000..d0c25e096f
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_step_function.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.step_function import 
StepFunctionsExecutionCompleteTrigger
+
+TEST_EXECUTION_ARN = "test-execution-arn"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-conn-id"
+TEST_REGION_NAME = "test-region-name"
+
+
class TestStepFunctionsTriggers:
    """Serialization round-trip tests for the Step Functions triggers."""

    @pytest.mark.parametrize(
        "trigger",
        [
            StepFunctionsExecutionCompleteTrigger(
                execution_arn=TEST_EXECUTION_ARN,
                aws_conn_id=TEST_AWS_CONN_ID,
                waiter_delay=TEST_WAITER_DELAY,
                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
                region_name=TEST_REGION_NAME,
            )
        ],
    )
    def test_serialize_recreate(self, trigger):
        """A trigger rebuilt from its serialized form must serialize identically."""
        original_path, original_kwargs = trigger.serialize()

        # Resolve the trigger class by its bare name and rebuild it from kwargs.
        trigger_class = globals()[original_path.rsplit(".", 1)[-1]]
        recreated = trigger_class(**original_kwargs)
        recreated_path, recreated_kwargs = recreated.serialize()

        assert original_path == recreated_path
        assert original_kwargs == recreated_kwargs

Reply via email to