This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c1526a2888 Convert Step Functions Example DAG to System Test (AIP-47) 
(#24643)
c1526a2888 is described below

commit c1526a28889d73d2fe33752904524bd133067a75
Author: Niko <[email protected]>
AuthorDate: Thu Jul 7 13:56:49 2022 -0700

    Convert Step Functions Example DAG to System Test (AIP-47) (#24643)
    
    * Add Amazon System Test Context Builder
    
    This factory class builds a task that will generate the environment id
    and fetch any variables at runtime and store them in xcom
    
    * Convert step functions example dag to system test
    
    Moved the example_step_functions example DAG to the AWS system test dir.
    Convert to AIP-47 standard.
---
 .../aws/example_dags/example_step_functions.py     |  56 ----------
 .../operators/step_functions.rst                   |   6 +-
 .../providers/amazon/aws/example_step_functions.py | 119 +++++++++++++++++++++
 .../system/providers/amazon/aws/utils/__init__.py  |  55 +++++++++-
 4 files changed, 172 insertions(+), 64 deletions(-)

diff --git 
a/airflow/providers/amazon/aws/example_dags/example_step_functions.py 
b/airflow/providers/amazon/aws/example_dags/example_step_functions.py
deleted file mode 100644
index 02763e3ea1..0000000000
--- a/airflow/providers/amazon/aws/example_dags/example_step_functions.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import environ
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.step_function import (
-    StepFunctionGetExecutionOutputOperator,
-    StepFunctionStartExecutionOperator,
-)
-from airflow.providers.amazon.aws.sensors.step_function import 
StepFunctionExecutionSensor
-
-STEP_FUNCTIONS_STATE_MACHINE_ARN = 
environ.get('STEP_FUNCTIONS_STATE_MACHINE_ARN', 'state_machine_arn')
-
-with DAG(
-    dag_id='example_step_functions',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # [START howto_operator_step_function_start_execution]
-    start_execution = StepFunctionStartExecutionOperator(
-        task_id='start_execution', 
state_machine_arn=STEP_FUNCTIONS_STATE_MACHINE_ARN
-    )
-    # [END howto_operator_step_function_start_execution]
-
-    # [START howto_sensor_step_function_execution]
-    wait_for_execution = StepFunctionExecutionSensor(
-        task_id='wait_for_execution', execution_arn=start_execution.output
-    )
-    # [END howto_sensor_step_function_execution]
-
-    # [START howto_operator_step_function_get_execution_output]
-    get_execution_output = StepFunctionGetExecutionOutputOperator(
-        task_id='get_execution_output', execution_arn=start_execution.output
-    )
-    # [END howto_operator_step_function_get_execution_output]
-
-    chain(start_execution, wait_for_execution, get_execution_output)
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst 
b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index e29d3194f6..91984dc598 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -39,7 +39,7 @@ Start an AWS Step Functions state machine execution
 To start a new AWS Step Functions state machine execution you can use
 
:class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_step_function_start_execution]
@@ -53,7 +53,7 @@ Get an AWS Step Functions execution output
 To fetch the output from an AWS Step Function state machine execution you can
 use 
:class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionGetExecutionOutputOperator`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_step_function_get_execution_output]
@@ -70,7 +70,7 @@ Wait on an AWS Step Functions state machine execution state
 To wait on the state of an AWS Step Function state machine execution until it 
reaches a terminal state you can
 use 
:class:`~airflow.providers.amazon.aws.sensors.step_function.StepFunctionExecutionSensor`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_step_function_execution]
diff --git a/tests/system/providers/amazon/aws/example_step_functions.py 
b/tests/system/providers/amazon/aws/example_step_functions.py
new file mode 100644
index 0000000000..98d6b7a743
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_step_functions.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.operators.step_function import (
+    StepFunctionGetExecutionOutputOperator,
+    StepFunctionStartExecutionOperator,
+)
+from airflow.providers.amazon.aws.sensors.step_function import 
StepFunctionExecutionSensor
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, 
SystemTestContextBuilder
+
+DAG_ID = 'example_step_functions'
+
+# Externally fetched variables:
+ROLE_ARN_KEY = 'ROLE_ARN'
+
+sys_test_context_task = 
SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+STATE_MACHINE_DEFINITION = {
+    "StartAt": "Wait",
+    "States": {"Wait": {"Type": "Wait", "Seconds": 7, "Next": "Success"}, 
"Success": {"Type": "Succeed"}},
+}
+
+
+@task
+def create_state_machine(env_id, role_arn):
+    # Create a Step Functions State Machine and return the ARN for use by
+    # downstream tasks.
+    return (
+        StepFunctionHook()
+        .get_conn()
+        .create_state_machine(
+            name=f'{DAG_ID}_{env_id}',
+            definition=json.dumps(STATE_MACHINE_DEFINITION),
+            roleArn=role_arn,
+        )['stateMachineArn']
+    )
+
+
+@task
+def delete_state_machine(state_machine_arn):
+    
StepFunctionHook().get_conn().delete_state_machine(stateMachineArn=state_machine_arn)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule_interval='@once',
+    start_date=datetime(2021, 1, 1),
+    tags=['example'],
+    catchup=False,
+) as dag:
+
+    # This context contains the ENV_ID and any env variables requested when the
+    # task was built above. Access the info as you would any other TaskFlow 
task.
+    test_context = sys_test_context_task()
+    env_id = test_context[ENV_ID_KEY]
+    role_arn = test_context[ROLE_ARN_KEY]
+
+    state_machine_arn = create_state_machine(env_id, role_arn)
+
+    # [START howto_operator_step_function_start_execution]
+    start_execution = StepFunctionStartExecutionOperator(
+        task_id='start_execution', state_machine_arn=state_machine_arn
+    )
+    # [END howto_operator_step_function_start_execution]
+
+    # [START howto_sensor_step_function_execution]
+    wait_for_execution = StepFunctionExecutionSensor(
+        task_id='wait_for_execution', execution_arn=start_execution.output
+    )
+    # [END howto_sensor_step_function_execution]
+
+    # [START howto_operator_step_function_get_execution_output]
+    get_execution_output = StepFunctionGetExecutionOutputOperator(
+        task_id='get_execution_output', execution_arn=start_execution.output
+    )
+    # [END howto_operator_step_function_get_execution_output]
+
+    chain(
+        # TEST SETUP
+        test_context,
+        state_machine_arn,
+        # TEST BODY
+        start_execution,
+        wait_for_execution,
+        get_execution_output,
+        # TEST TEARDOWN
+        delete_state_machine(state_machine_arn),
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py 
b/tests/system/providers/amazon/aws/utils/__init__.py
index 09b7fff0a1..38f606d0fb 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -26,7 +26,10 @@ import boto3
 from botocore.client import BaseClient
 from botocore.exceptions import NoCredentialsError
 
+from airflow.decorators import task
+
 ENV_ID_ENVIRON_KEY: str = 'SYSTEM_TESTS_ENV_ID'
+ENV_ID_KEY: str = 'ENV_ID'
 DEFAULT_ENV_ID_PREFIX: str = 'env'
 DEFAULT_ENV_ID_LEN: int = 8
 DEFAULT_ENV_ID: str = 
f'{DEFAULT_ENV_ID_PREFIX}{str(uuid4())[:DEFAULT_ENV_ID_LEN]}'
@@ -76,19 +79,19 @@ def _validate_env_id(env_id: str) -> str:
     return env_id.lower()
 
 
-def _fetch_from_ssm(key: str) -> str:
+def _fetch_from_ssm(key: str, test_name: Optional[str] = None) -> str:
     """
     Test values are stored in the SSM Value as a JSON-encoded dict of 
key/value pairs.
 
     :param key: The key to search for within the returned Parameter Value.
     :return: The value of the provided key from SSM
     """
-    test_name: str = _get_test_name()
+    _test_name: str = test_name if test_name else _get_test_name()
     ssm_client: BaseClient = boto3.client('ssm')
     value: str = ''
 
     try:
-        value = 
json.loads(ssm_client.get_parameter(Name=test_name)['Parameter']['Value'])[key]
+        value = 
json.loads(ssm_client.get_parameter(Name=_test_name)['Parameter']['Value'])[key]
     # Since a default value after the SSM check is allowed, these exceptions 
should not stop execution.
     except NoCredentialsError:
         # No boto credentials found.
@@ -102,7 +105,49 @@ def _fetch_from_ssm(key: str) -> str:
     return value
 
 
-def fetch_variable(key: str, default_value: Optional[str] = None) -> str:
+class SystemTestContextBuilder:
+    """This builder class ultimately constructs a TaskFlow task which is run at
+    runtime (task execution time). This task generates and stores the test 
ENV_ID as well
+    as any external resources requested (e.g. IAM Roles, VPC, etc.)"""
+
+    def __init__(self):
+        self.variables = []
+        self.variable_defaults = {}
+        self.test_name = _get_test_name()
+        self.env_id = set_env_id()
+
+    def add_variable(self, variable_name: str, **kwargs):
+        """Register a variable to fetch from environment or cloud parameter 
store"""
+        self.variables.append(variable_name)
+        # default_value is accepted via kwargs so that it is completely 
optional and no
+        # default value needs to be provided in the method stub (otherwise we 
wouldn't
+        # be able to tell the difference between our default value and one 
provided by
+        # the caller)
+        if 'default_value' in kwargs:
+            self.variable_defaults[variable_name] = kwargs['default_value']
+
+        return self  # Builder recipe; returning self allows chaining
+
+    def build(self):
+        """Build and return a TaskFlow task which will create an env_id and
+        fetch requested variables. Storing everything in xcom for downstream
+        tasks to use."""
+
+        @task
+        def variable_fetcher(**kwargs):
+            ti = kwargs['ti']
+            for variable in self.variables:
+                default_value = self.variable_defaults.get(variable, None)
+                value = fetch_variable(variable, default_value, 
test_name=self.test_name)
+                ti.xcom_push(variable, value)
+
+            # Fetch/generate ENV_ID and store it in XCOM
+            ti.xcom_push(ENV_ID_KEY, self.env_id)
+
+        return variable_fetcher
+
+
+def fetch_variable(key: str, default_value: Optional[str] = None, test_name: 
Optional[str] = None) -> str:
     """
     Given a Parameter name: first check for an existing Environment Variable,
     then check SSM for a value. If neither are available, fall back on the
@@ -113,7 +158,7 @@ def fetch_variable(key: str, default_value: Optional[str] = 
None) -> str:
     :return: The value of the parameter.
     """
 
-    value: Optional[str] = os.getenv(key, _fetch_from_ssm(key)) or 
default_value
+    value: Optional[str] = os.getenv(key, _fetch_from_ssm(key, 
test_name=test_name)) or default_value
     if not value:
         raise ValueError(NO_VALUE_MSG.format(key=key))
     return value

Reply via email to