This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 564192c Add AWS StepFunctions integrations to the aws provider (#8749)
564192c is described below
commit 564192c1625a552456cebb3751978c08eebdb2a1
Author: chamcca <[email protected]>
AuthorDate: Wed Jul 8 03:25:16 2020 -0600
Add AWS StepFunctions integrations to the aws provider (#8749)
---
.../providers/amazon/aws/hooks/step_function.py | 79 +++++++++++++++
.../step_function_get_execution_output.py | 58 +++++++++++
.../aws/operators/step_function_start_execution.py | 72 ++++++++++++++
.../amazon/aws/sensors/step_function_execution.py | 77 +++++++++++++++
docs/operators-and-hooks-ref.rst | 7 ++
.../amazon/aws/hooks/test_step_function.py | 63 ++++++++++++
.../test_step_function_get_execution_output.py | 76 +++++++++++++++
.../test_step_function_start_execution.py | 82 ++++++++++++++++
.../aws/sensors/test_step_function_execution.py | 107 +++++++++++++++++++++
9 files changed, 621 insertions(+)
diff --git a/airflow/providers/amazon/aws/hooks/step_function.py
b/airflow/providers/amazon/aws/hooks/step_function.py
new file mode 100644
index 0000000..f0e1040
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/step_function.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from typing import Optional, Union
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class StepFunctionHook(AwsBaseHook):
+    """
+    Interact with an AWS Step Functions State Machine.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+
+    :param region_name: AWS region name; forwarded to AwsBaseHook together
+        with ``client_type='stepfunctions'``.
+    :type region_name: Optional[str]
+    """
+
+    def __init__(self, region_name: Optional[str] = None, *args, **kwargs):
+        # ``region_name`` is captured by this signature, so it must be handed to
+        # AwsBaseHook explicitly; otherwise it would be silently discarded.
+        super().__init__(*args, client_type='stepfunctions', region_name=region_name, **kwargs)
+
+    def start_execution(self, state_machine_arn: str, name: Optional[str] = None,
+                        state_machine_input: Union[dict, str, None] = None) -> Optional[str]:
+        """
+        Start Execution of the State Machine.
+
+        .. seealso::
+            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution
+
+        :param state_machine_arn: AWS Step Function State Machine ARN
+        :type state_machine_arn: str
+        :param name: The name of the execution. Omitted from the request when ``None``.
+        :type name: Optional[str]
+        :param state_machine_input: JSON data input to pass to the State Machine.
+            A ``str`` is sent unchanged (it is expected to already be JSON text);
+            any other non-``None`` JSON-serializable value (such as a ``dict``)
+            is serialized with :func:`json.dumps`. Omitted from the request
+            when ``None``.
+        :type state_machine_input: Union[Dict[str, Any], str, None]
+        :return: Execution ARN, or ``None`` if the response has no ``executionArn``
+        :rtype: Optional[str]
+        :raises TypeError: if ``state_machine_input`` is not JSON serializable
+        """
+        execution_args = {'stateMachineArn': state_machine_arn}
+        if name is not None:
+            execution_args['name'] = name
+        if state_machine_input is not None:
+            if isinstance(state_machine_input, str):
+                execution_args['input'] = state_machine_input
+            else:
+                # Serialize every non-string value instead of silently sending
+                # no input at all for types other than ``dict``.
+                execution_args['input'] = json.dumps(state_machine_input)
+
+        self.log.info('Executing Step Function State Machine: %s', state_machine_arn)
+
+        response = self.get_conn().start_execution(**execution_args)
+        return response.get('executionArn')
+
+    def describe_execution(self, execution_arn: str) -> dict:
+        """
+        Describes a State Machine Execution.
+
+        .. seealso::
+            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution
+
+        :param execution_arn: ARN of the State Machine Execution
+        :type execution_arn: str
+        :return: Dict with Execution details, as returned by the boto3 client
+        :rtype: dict
+        """
+        return self.get_conn().describe_execution(executionArn=execution_arn)
diff --git
a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
new file mode 100644
index 0000000..2ef531c
--- /dev/null
+++
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionGetExecutionOutputOperator(BaseOperator):
+    """
+    An Operator that returns the output of a Step Function State Machine Execution.
+
+    The execution is looked up through ``describe_execution``. Its ``output``
+    field (JSON text) is decoded with :func:`json.loads` and returned. ``None``
+    is returned when the ``describe_execution`` response has no ``output`` key.
+
+    Additional arguments may be specified and are passed down to the underlying BaseOperator.
+
+    .. seealso::
+        :class:`~airflow.models.BaseOperator`
+
+    :param execution_arn: ARN of the Step Function State Machine Execution
+    :type execution_arn: str
+    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :type aws_conn_id: str
+    :param region_name: AWS region passed to the StepFunctionHook
+    :type region_name: Optional[str]
+    """
+    template_fields = ['execution_arn']
+    template_ext = ()
+    ui_color = '#f9c915'
+
+    @apply_defaults
+    def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.execution_arn = execution_arn
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def execute(self, context):
+        """Fetch the execution description and return its decoded ``output`` (or ``None``)."""
+        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+        execution_status = hook.describe_execution(self.execution_arn)
+        execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None
+
+        self.log.info('Got State Machine Execution output for %s', self.execution_arn)
+
+        return execution_output
diff --git
a/airflow/providers/amazon/aws/operators/step_function_start_execution.py
b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
new file mode 100644
index 0000000..f5ea75c
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionStartExecutionOperator(BaseOperator):
+    """
+    An Operator that begins execution of a Step Function State Machine.
+
+    Additional arguments may be specified and are passed down to the underlying BaseOperator.
+
+    .. seealso::
+        :class:`~airflow.models.BaseOperator`
+
+    :param state_machine_arn: ARN of the Step Function State Machine
+    :type state_machine_arn: str
+    :param name: The name of the execution.
+    :type name: Optional[str]
+    :param state_machine_input: JSON data input to pass to the State Machine
+    :type state_machine_input: Union[Dict[str, Any], str, None]
+    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :type aws_conn_id: str
+    :param region_name: AWS region passed to the StepFunctionHook
+    :type region_name: Optional[str]
+    :param do_xcom_push: if True, the execution ARN returned by ``execute`` is
+        pushed to XCom as the task's return value.
+    :type do_xcom_push: bool
+    """
+    # ``state_machine_input`` is stored on ``self.input``, hence the 'input' entry.
+    template_fields = ['state_machine_arn', 'name', 'input']
+    template_ext = ()
+    ui_color = '#f9c915'
+
+    @apply_defaults
+    def __init__(self, state_machine_arn: str, name: Optional[str] = None,
+                 state_machine_input: Union[dict, str, None] = None,
+                 aws_conn_id='aws_default', region_name=None,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.state_machine_arn = state_machine_arn
+        self.name = name
+        self.input = state_machine_input
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def execute(self, context):
+        """
+        Start the State Machine execution and return its ARN.
+
+        :raises AirflowException: if the hook returns no execution ARN
+        """
+        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+        execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
+
+        if execution_arn is None:
+            raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')
+
+        self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn)
+
+        return execution_arn
diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py
b/airflow/providers/amazon/aws/sensors/step_function_execution.py
new file mode 100644
index 0000000..0cc3caf
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class StepFunctionExecutionSensor(BaseSensorOperator):
+    """
+    Polls the state of a Step Function State Machine Execution until it
+    reaches a failure state or a success state.
+
+    * A status in ``FAILURE_STATES`` raises :class:`AirflowException`, failing the task.
+    * A status in ``SUCCESS_STATES`` makes the sensor do an XCom Push of the
+      State Machine's decoded output to key ``output`` and complete.
+    * Any other status (``RUNNING`` or one this sensor does not recognize)
+      keeps the sensor waiting.
+
+    :param execution_arn: execution_arn to check the state of
+    :type execution_arn: str
+    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :type aws_conn_id: str
+    :param region_name: AWS region passed to the StepFunctionHook
+    :type region_name: Optional[str]
+    """
+
+    INTERMEDIATE_STATES = ('RUNNING',)
+    FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED',)
+    SUCCESS_STATES = ('SUCCEEDED',)
+
+    template_fields = ['execution_arn']
+    template_ext = ()
+    ui_color = '#66c3ff'
+
+    @apply_defaults
+    def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.execution_arn = execution_arn
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.hook = None
+
+    def poke(self, context):
+        """
+        Check the execution status once.
+
+        :return: True once the execution is in ``SUCCESS_STATES``, False otherwise
+        :raises AirflowException: if the execution is in ``FAILURE_STATES``
+        """
+        execution_status = self.get_hook().describe_execution(self.execution_arn)
+        state = execution_status['status']
+        output = json.loads(execution_status['output']) if 'output' in execution_status else None
+
+        if state in self.FAILURE_STATES:
+            raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}')
+
+        if state not in self.SUCCESS_STATES:
+            # Only an explicit success state may complete the sensor; an
+            # unrecognized status must not be reported as success for an
+            # execution that may not have finished.
+            if state not in self.INTERMEDIATE_STATES:
+                self.log.warning('Unrecognized State Machine execution status %s; continuing to wait', state)
+            return False
+
+        self.log.info('Doing xcom_push of output')
+        self.xcom_push(context, 'output', output)
+        return True
+
+    def get_hook(self):
+        """Create (once) and return a StepFunctionHook"""
+        if not self.hook:
+            self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        return self.hook
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 9c791d3..5c25882 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -495,6 +495,13 @@ These integrations allow you to perform various operations
within the Amazon Web
- :mod:`airflow.providers.amazon.aws.sensors.s3_key`,
:mod:`airflow.providers.amazon.aws.sensors.s3_prefix`
+   * - `AWS Step Functions <https://aws.amazon.com/step-functions/>`__
+     -
+     - :mod:`airflow.providers.amazon.aws.hooks.step_function`
+     - :mod:`airflow.providers.amazon.aws.operators.step_function_start_execution`,
+       :mod:`airflow.providers.amazon.aws.operators.step_function_get_execution_output`
+     - :mod:`airflow.providers.amazon.aws.sensors.step_function_execution`
+
Transfer operators and hooks
''''''''''''''''''''''''''''
diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py
b/tests/providers/amazon/aws/hooks/test_step_function.py
new file mode 100644
index 0000000..679d2e4
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_step_function.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+
+try:
+ from moto import mock_stepfunctions
+except ImportError:
+ mock_stepfunctions = None
+
+
+@unittest.skipIf(mock_stepfunctions is None, 'moto package not present')
+class TestStepFunctionHook(unittest.TestCase):
+    """Tests for StepFunctionHook backed by moto's mocked Step Functions service."""
+
+    @staticmethod
+    def _start_pseudo_execution(hook):
+        """Create a throw-away state machine through ``hook`` and start one execution of it."""
+        state_machine = hook.get_conn().create_state_machine(
+            name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role')
+
+        state_machine_arn = state_machine.get('stateMachineArn', None)
+
+        return hook.start_execution(
+            state_machine_arn=state_machine_arn, name=None, state_machine_input={})
+
+    @mock_stepfunctions
+    def test_get_conn_returns_a_boto3_connection(self):
+        hook = StepFunctionHook(aws_conn_id='aws_default')
+        self.assertEqual('stepfunctions', hook.get_conn().meta.service_model.service_name)
+
+    @mock_stepfunctions
+    def test_start_execution(self):
+        hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1')
+
+        execution_arn = self._start_pseudo_execution(hook)
+
+        assert execution_arn is not None
+
+    @mock_stepfunctions
+    def test_describe_execution(self):
+        hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1')
+
+        execution_arn = self._start_pseudo_execution(hook)
+        response = hook.describe_execution(execution_arn)
+
+        assert 'input' in response
diff --git
a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
new file mode 100644
index 0000000..8997df9
--- /dev/null
+++
b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from airflow.providers.amazon.aws.operators.step_function_get_execution_output
import (
+ StepFunctionGetExecutionOutputOperator,
+)
+
+TASK_ID = 'step_function_get_execution_output'
+# Fake execution ARN; split across two literals to keep the line short.
+EXECUTION_ARN = (
+    'arn:aws:states:us-east-1:123456789012:execution:'
+    'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+)
+AWS_CONN_ID = 'aws_non_default'
+REGION_NAME = 'us-west-2'
+
+
+class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase):
+    """Unit tests for StepFunctionGetExecutionOutputOperator with the hook mocked out."""
+
+    def setUp(self):
+        self.mock_context = MagicMock()
+
+    @staticmethod
+    def _build_operator():
+        """Construct the operator under test with the module-level fixtures."""
+        return StepFunctionGetExecutionOutputOperator(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME,
+        )
+
+    def test_init(self):
+        operator = self._build_operator()
+
+        expectations = (
+            ('task_id', TASK_ID),
+            ('execution_arn', EXECUTION_ARN),
+            ('aws_conn_id', AWS_CONN_ID),
+            ('region_name', REGION_NAME),
+        )
+        for attribute, expected in expectations:
+            self.assertEqual(expected, getattr(operator, attribute))
+
+    @mock.patch('airflow.providers.amazon.aws.operators.step_function_get_execution_output.StepFunctionHook')
+    def test_execute(self, mock_hook):
+        # The mocked hook reports an execution whose output is an empty JSON object.
+        mock_hook.return_value.describe_execution.return_value = {'output': '{}'}
+
+        result = self._build_operator().execute(self.mock_context)
+
+        self.assertEqual({}, result)
diff --git
a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
new file mode 100644
index 0000000..5f6c336
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from airflow.providers.amazon.aws.operators.step_function_start_execution
import (
+ StepFunctionStartExecutionOperator,
+)
+
+TASK_ID = 'step_function_start_execution_task'
+# Fake state machine ARN; split across two literals to keep the line short.
+STATE_MACHINE_ARN = (
+    'arn:aws:states:us-east-1:000000000000:'
+    'stateMachine:pseudo-state-machine'
+)
+NAME = 'NAME'
+INPUT = '{}'
+AWS_CONN_ID = 'aws_non_default'
+REGION_NAME = 'us-west-2'
+
+
+class TestStepFunctionStartExecutionOperator(unittest.TestCase):
+    """Unit tests for StepFunctionStartExecutionOperator with the hook mocked out."""
+
+    def setUp(self):
+        self.mock_context = MagicMock()
+
+    @staticmethod
+    def _build_operator():
+        """Construct the operator under test with the module-level fixtures."""
+        return StepFunctionStartExecutionOperator(
+            task_id=TASK_ID,
+            state_machine_arn=STATE_MACHINE_ARN,
+            name=NAME,
+            state_machine_input=INPUT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME,
+        )
+
+    def test_init(self):
+        operator = self._build_operator()
+
+        expectations = (
+            ('task_id', TASK_ID),
+            ('state_machine_arn', STATE_MACHINE_ARN),
+            ('name', NAME),
+            ('input', INPUT),
+            ('aws_conn_id', AWS_CONN_ID),
+            ('region_name', REGION_NAME),
+        )
+        for attribute, expected in expectations:
+            self.assertEqual(expected, getattr(operator, attribute))
+
+    @mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook')
+    def test_execute(self, mock_hook):
+        started_arn = (
+            'arn:aws:states:us-east-1:123456789012:execution:'
+            'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+        )
+        mock_hook.return_value.start_execution.return_value = started_arn
+
+        result = self._build_operator().execute(self.mock_context)
+
+        self.assertEqual(started_arn, result)
diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py
b/tests/providers/amazon/aws/sensors/test_step_function_execution.py
new file mode 100644
index 0000000..237f8ef
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py
@@ -0,0 +1,107 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from parameterized import parameterized
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.sensors.step_function_execution import
StepFunctionExecutionSensor
+
+TASK_ID = 'step_function_execution_sensor'
+# Fake execution ARN; split across two literals to keep the line short.
+EXECUTION_ARN = (
+    'arn:aws:states:us-east-1:123456789012:execution:'
+    'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+)
+AWS_CONN_ID = 'aws_non_default'
+REGION_NAME = 'us-west-2'
+
+
+class TestStepFunctionExecutionSensor(unittest.TestCase):
+    """Unit tests for StepFunctionExecutionSensor.poke with the hook mocked out."""
+
+    def setUp(self):
+        self.mock_context = MagicMock()
+
+    @staticmethod
+    def _build_sensor():
+        """Construct the sensor under test with the module-level fixtures."""
+        return StepFunctionExecutionSensor(
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME,
+        )
+
+    @classmethod
+    def _sensor_reporting(cls, mock_hook, status):
+        """Make the mocked hook describe an execution in ``status`` and build a sensor."""
+        mock_hook.return_value.describe_execution.return_value = {'status': status}
+        return cls._build_sensor()
+
+    def test_init(self):
+        sensor = self._build_sensor()
+
+        expectations = (
+            ('task_id', TASK_ID),
+            ('execution_arn', EXECUTION_ARN),
+            ('aws_conn_id', AWS_CONN_ID),
+            ('region_name', REGION_NAME),
+        )
+        for attribute, expected in expectations:
+            self.assertEqual(expected, getattr(sensor, attribute))
+
+    @parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)])
+    @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
+    def test_exceptions(self, mock_status, mock_hook):
+        sensor = self._sensor_reporting(mock_hook, mock_status)
+
+        with self.assertRaises(AirflowException):
+            sensor.poke(self.mock_context)
+
+    @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
+    def test_running(self, mock_hook):
+        sensor = self._sensor_reporting(mock_hook, 'RUNNING')
+
+        self.assertFalse(sensor.poke(self.mock_context))
+
+    @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
+    def test_succeeded(self, mock_hook):
+        sensor = self._sensor_reporting(mock_hook, 'SUCCEEDED')
+
+        self.assertTrue(sensor.poke(self.mock_context))