This is an automated email from the ASF dual-hosted git repository.

turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 564192c  Add AWS StepFunctions integrations to the aws provider (#8749)
564192c is described below

commit 564192c1625a552456cebb3751978c08eebdb2a1
Author: chamcca <[email protected]>
AuthorDate: Wed Jul 8 03:25:16 2020 -0600

    Add AWS StepFunctions integrations to the aws provider (#8749)
---
 .../providers/amazon/aws/hooks/step_function.py    |  79 +++++++++++++++
 .../step_function_get_execution_output.py          |  58 +++++++++++
 .../aws/operators/step_function_start_execution.py |  72 ++++++++++++++
 .../amazon/aws/sensors/step_function_execution.py  |  77 +++++++++++++++
 docs/operators-and-hooks-ref.rst                   |   7 ++
 .../amazon/aws/hooks/test_step_function.py         |  63 ++++++++++++
 .../test_step_function_get_execution_output.py     |  76 +++++++++++++++
 .../test_step_function_start_execution.py          |  82 ++++++++++++++++
 .../aws/sensors/test_step_function_execution.py    | 107 +++++++++++++++++++++
 9 files changed, 621 insertions(+)

diff --git a/airflow/providers/amazon/aws/hooks/step_function.py 
b/airflow/providers/amazon/aws/hooks/step_function.py
new file mode 100644
index 0000000..f0e1040
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/step_function.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from typing import Optional, Union
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class StepFunctionHook(AwsBaseHook):
+    """
+    Interact with an AWS Step Functions State Machine.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, region_name=None, *args, **kwargs):
+        super().__init__(client_type='stepfunctions', *args, **kwargs)
+
+    def start_execution(self, state_machine_arn: str, name: Optional[str] = 
None,
+                        state_machine_input: Union[dict, str, None] = None) -> 
str:
+        """
+        Start Execution of the State Machine.
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution
+
+        :param state_machine_arn: AWS Step Function State Machine ARN
+        :type state_machine_arn: str
+        :param name: The name of the execution.
+        :type name: Optional[str]
+        :param state_machine_input: JSON data input to pass to the State 
Machine
+        :type state_machine_input: Union[Dict[str, any], str, None]
+        :return: Execution ARN
+        :rtype: str
+        """
+        execution_args = {
+            'stateMachineArn': state_machine_arn
+        }
+        if name is not None:
+            execution_args['name'] = name
+        if state_machine_input is not None:
+            if isinstance(state_machine_input, str):
+                execution_args['input'] = state_machine_input
+            elif isinstance(state_machine_input, dict):
+                execution_args['input'] = json.dumps(state_machine_input)
+
+        self.log.info('Executing Step Function State Machine: %s', 
state_machine_arn)
+
+        response = self.conn.start_execution(**execution_args)
+        return response.get('executionArn', None)
+
+    def describe_execution(self, execution_arn: str) -> dict:
+        """
+        Describes a State Machine Execution
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution
+
+        :param execution_arn: ARN of the State Machine Execution
+        :type execution_arn: str
+        :return: Dict with Execution details
+        :rtype: dict
+        """
+        return self.get_conn().describe_execution(executionArn=execution_arn)
diff --git 
a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py 
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
new file mode 100644
index 0000000..2ef531c
--- /dev/null
+++ 
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionGetExecutionOutputOperator(BaseOperator):
+    """
+    An Operator that begins execution of an Step Function State Machine
+
+    Additional arguments may be specified and are passed down to the 
underlying BaseOperator.
+
+    .. seealso::
+        :class:`~airflow.models.BaseOperator`
+
+    :param execution_arn: ARN of the Step Function State Machine Execution
+    :type execution_arn: str
+    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :type aws_conn_id: str
+    """
+    template_fields = ['execution_arn']
+    template_ext = ()
+    ui_color = '#f9c915'
+
+    @apply_defaults
+    def __init__(self, execution_arn: str, aws_conn_id='aws_default', 
region_name=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.execution_arn = execution_arn
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def execute(self, context):
+        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+
+        execution_status = hook.describe_execution(self.execution_arn)
+        execution_output = json.loads(execution_status['output']) if 'output' 
in execution_status else None
+
+        self.log.info('Got State Machine Execution output for %s', 
self.execution_arn)
+
+        return execution_output
diff --git 
a/airflow/providers/amazon/aws/operators/step_function_start_execution.py 
b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
new file mode 100644
index 0000000..f5ea75c
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionStartExecutionOperator(BaseOperator):
+    """
+    An Operator that begins execution of an Step Function State Machine
+
+    Additional arguments may be specified and are passed down to the 
underlying BaseOperator.
+
+    .. seealso::
+        :class:`~airflow.models.BaseOperator`
+
+    :param state_machine_arn: ARN of the Step Function State Machine
+    :type state_machine_arn: str
+    :param name: The name of the execution.
+    :type name: Optional[str]
+    :param state_machine_input: JSON data input to pass to the State Machine
+    :type state_machine_input: Union[Dict[str, any], str, None]
+    :param aws_conn_id: aws connection to uses
+    :type aws_conn_id: str
+    :param do_xcom_push: if True, execution_arn is pushed to XCom with key 
execution_arn.
+    :type do_xcom_push: bool
+    """
+    template_fields = ['state_machine_arn', 'name', 'input']
+    template_ext = ()
+    ui_color = '#f9c915'
+
+    @apply_defaults
+    def __init__(self, state_machine_arn: str, name: Optional[str] = None,
+                 state_machine_input: Union[dict, str, None] = None,
+                 aws_conn_id='aws_default', region_name=None,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.state_machine_arn = state_machine_arn
+        self.name = name
+        self.input = state_machine_input
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def execute(self, context):
+        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+
+        execution_arn = hook.start_execution(self.state_machine_arn, 
self.name, self.input)
+
+        if execution_arn is None:
+            raise AirflowException(f'Failed to start State Machine execution 
for: {self.state_machine_arn}')
+
+        self.log.info('Started State Machine execution for %s: %s', 
self.state_machine_arn, execution_arn)
+
+        return execution_arn
diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py 
b/airflow/providers/amazon/aws/sensors/step_function_execution.py
new file mode 100644
index 0000000..0cc3caf
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionExecutionSensor(BaseSensorOperator):
+    """
+    Asks for the state of the Step Function State Machine Execution until it
+    reaches a failure state or success state.
+    If it fails, failing the task.
+
+    On successful completion of the Execution the Sensor will do an XCom Push
+    of the State Machine's output to `output`
+
+    :param execution_arn: execution_arn to check the state of
+    :type execution_arn: str
+    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :type aws_conn_id: str
+    """
+
+    INTERMEDIATE_STATES = ('RUNNING',)
+    FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED',)
+    SUCCESS_STATES = ('SUCCEEDED',)
+
+    template_fields = ['execution_arn']
+    template_ext = ()
+    ui_color = '#66c3ff'
+
+    @apply_defaults
+    def __init__(self, execution_arn: str, aws_conn_id='aws_default', 
region_name=None,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.execution_arn = execution_arn
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.hook = None
+
+    def poke(self, context):
+        execution_status = 
self.get_hook().describe_execution(self.execution_arn)
+        state = execution_status['status']
+        output = json.loads(execution_status['output']) if 'output' in 
execution_status else None
+
+        if state in self.FAILURE_STATES:
+            raise AirflowException(f'Step Function sensor failed. State 
Machine Output: {output}')
+
+        if state in self.INTERMEDIATE_STATES:
+            return False
+
+        self.log.info('Doing xcom_push of output')
+        self.xcom_push(context, 'output', output)
+        return True
+
+    def get_hook(self):
+        """Create and return a StepFunctionHook"""
+        if not self.hook:
+            self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+        return self.hook
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 9c791d3..5c25882 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -495,6 +495,13 @@ These integrations allow you to perform various operations 
within the Amazon Web
      - :mod:`airflow.providers.amazon.aws.sensors.s3_key`,
        :mod:`airflow.providers.amazon.aws.sensors.s3_prefix`
 
+   * - `AWS Step Functions <https://aws.amazon.com/step-functions/>`__
+     -
+     - :mod:`airflow.providers.amazon.aws.hooks.step_function`
+     - :mod:`airflow.providers.amazon.aws.operators.step_function_start_execution`,
+       :mod:`airflow.providers.amazon.aws.operators.step_function_get_execution_output`
+     - :mod:`airflow.providers.amazon.aws.sensors.step_function_execution`
+
 Transfer operators and hooks
 ''''''''''''''''''''''''''''
 
diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py 
b/tests/providers/amazon/aws/hooks/test_step_function.py
new file mode 100644
index 0000000..679d2e4
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_step_function.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+
+try:
+    from moto import mock_stepfunctions
+except ImportError:
+    mock_stepfunctions = None
+
+
[email protected](mock_stepfunctions is None, 'moto package not present')
+class TestStepFunctionHook(unittest.TestCase):
+
+    @mock_stepfunctions
+    def test_get_conn_returns_a_boto3_connection(self):
+        hook = StepFunctionHook(aws_conn_id='aws_default')
+        self.assertEqual('stepfunctions', 
hook.get_conn().meta.service_model.service_name)
+
+    @mock_stepfunctions
+    def test_start_execution(self):
+        hook = StepFunctionHook(aws_conn_id='aws_default', 
region_name='us-east-1')
+        state_machine = hook.get_conn().create_state_machine(
+            name='pseudo-state-machine', definition='{}', 
roleArn='arn:aws:iam::000000000000:role/Role')
+
+        state_machine_arn = state_machine.get('stateMachineArn', None)
+
+        execution_arn = hook.start_execution(
+            state_machine_arn=state_machine_arn, name=None, 
state_machine_input={})
+
+        assert execution_arn is not None
+
+    @mock_stepfunctions
+    def test_describe_execution(self):
+        hook = StepFunctionHook(aws_conn_id='aws_default', 
region_name='us-east-1')
+        state_machine = hook.get_conn().create_state_machine(
+            name='pseudo-state-machine', definition='{}', 
roleArn='arn:aws:iam::000000000000:role/Role')
+
+        state_machine_arn = state_machine.get('stateMachineArn', None)
+
+        execution_arn = hook.start_execution(
+            state_machine_arn=state_machine_arn, name=None, 
state_machine_input={})
+        response = hook.describe_execution(execution_arn)
+
+        assert 'input' in response
diff --git 
a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
 
b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
new file mode 100644
index 0000000..8997df9
--- /dev/null
+++ 
b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from airflow.providers.amazon.aws.operators.step_function_get_execution_output 
import (
+    StepFunctionGetExecutionOutputOperator,
+)
+
+TASK_ID = 'step_function_get_execution_output'
+EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\
+                'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+AWS_CONN_ID = 'aws_non_default'
+REGION_NAME = 'us-west-2'
+
+
+class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase):
+
+    def setUp(self):
+        self.mock_context = MagicMock()
+
+    def test_init(self):
+        # Given / When
+        operator = StepFunctionGetExecutionOutputOperator(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        # Then
+        self.assertEqual(TASK_ID, operator.task_id)
+        self.assertEqual(EXECUTION_ARN, operator.execution_arn)
+        self.assertEqual(AWS_CONN_ID, operator.aws_conn_id)
+        self.assertEqual(REGION_NAME, operator.region_name)
+
+    
@mock.patch('airflow.providers.amazon.aws.operators.step_function_get_execution_output.StepFunctionHook')
+    def test_execute(self, mock_hook):
+        # Given
+        hook_response = {
+            'output': '{}'
+        }
+
+        hook_instance = mock_hook.return_value
+        hook_instance.describe_execution.return_value = hook_response
+
+        operator = StepFunctionGetExecutionOutputOperator(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        # When
+        result = operator.execute(self.mock_context)
+
+        # Then
+        self.assertEqual({}, result)
diff --git 
a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py 
b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
new file mode 100644
index 0000000..5f6c336
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from airflow.providers.amazon.aws.operators.step_function_start_execution 
import (
+    StepFunctionStartExecutionOperator,
+)
+
+TASK_ID = 'step_function_start_execution_task'
+STATE_MACHINE_ARN = 
'arn:aws:states:us-east-1:000000000000:stateMachine:pseudo-state-machine'
+NAME = 'NAME'
+INPUT = '{}'
+AWS_CONN_ID = 'aws_non_default'
+REGION_NAME = 'us-west-2'
+
+
+class TestStepFunctionStartExecutionOperator(unittest.TestCase):
+
+    def setUp(self):
+        self.mock_context = MagicMock()
+
+    def test_init(self):
+        # Given / When
+        operator = StepFunctionStartExecutionOperator(
+            task_id=TASK_ID,
+            state_machine_arn=STATE_MACHINE_ARN,
+            name=NAME,
+            state_machine_input=INPUT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        # Then
+        self.assertEqual(TASK_ID, operator.task_id)
+        self.assertEqual(STATE_MACHINE_ARN, operator.state_machine_arn)
+        self.assertEqual(NAME, operator.name)
+        self.assertEqual(INPUT, operator.input)
+        self.assertEqual(AWS_CONN_ID, operator.aws_conn_id)
+        self.assertEqual(REGION_NAME, operator.region_name)
+
+    
@mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook')
+    def test_execute(self, mock_hook):
+        # Given
+        hook_response = 'arn:aws:states:us-east-1:123456789012:execution:'\
+                        
'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+
+        hook_instance = mock_hook.return_value
+        hook_instance.start_execution.return_value = hook_response
+
+        operator = StepFunctionStartExecutionOperator(
+            task_id=TASK_ID,
+            state_machine_arn=STATE_MACHINE_ARN,
+            name=NAME,
+            state_machine_input=INPUT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        # When
+        result = operator.execute(self.mock_context)
+
+        # Then
+        self.assertEqual(hook_response, result)
diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py 
b/tests/providers/amazon/aws/sensors/test_step_function_execution.py
new file mode 100644
index 0000000..237f8ef
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py
@@ -0,0 +1,107 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from parameterized import parameterized
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.sensors.step_function_execution import 
StepFunctionExecutionSensor
+
+TASK_ID = 'step_function_execution_sensor'
+EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\
+                'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+AWS_CONN_ID = 'aws_non_default'
+REGION_NAME = 'us-west-2'
+
+
+class TestStepFunctionExecutionSensor(unittest.TestCase):
+
+    def setUp(self):
+        self.mock_context = MagicMock()
+
+    def test_init(self):
+        sensor = StepFunctionExecutionSensor(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        self.assertEqual(TASK_ID, sensor.task_id)
+        self.assertEqual(EXECUTION_ARN, sensor.execution_arn)
+        self.assertEqual(AWS_CONN_ID, sensor.aws_conn_id)
+        self.assertEqual(REGION_NAME, sensor.region_name)
+
+    @parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)])
+    
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
+    def test_exceptions(self, mock_status, mock_hook):
+        hook_response = {
+            'status': mock_status
+        }
+
+        hook_instance = mock_hook.return_value
+        hook_instance.describe_execution.return_value = hook_response
+
+        sensor = StepFunctionExecutionSensor(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        with self.assertRaises(AirflowException):
+            sensor.poke(self.mock_context)
+
+    
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
+    def test_running(self, mock_hook):
+        hook_response = {
+            'status': 'RUNNING'
+        }
+
+        hook_instance = mock_hook.return_value
+        hook_instance.describe_execution.return_value = hook_response
+
+        sensor = StepFunctionExecutionSensor(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        self.assertFalse(sensor.poke(self.mock_context))
+
+    
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
+    def test_succeeded(self, mock_hook):
+        hook_response = {
+            'status': 'SUCCEEDED'
+        }
+
+        hook_instance = mock_hook.return_value
+        hook_instance.describe_execution.return_value = hook_response
+
+        sensor = StepFunctionExecutionSensor(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME
+        )
+
+        self.assertTrue(sensor.poke(self.mock_context))

Reply via email to