This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 239440df23 System Test for EMR (AIP-47) (#27286)
239440df23 is described below

commit 239440df23210355dc02f71b6c5aea7734651055
Author: Syed Hussaain <[email protected]>
AuthorDate: Thu Nov 17 09:30:20 2022 -0800

    System Test for EMR (AIP-47) (#27286)
---
 airflow/providers/amazon/aws/hooks/emr.py          | 29 ++++++++
 airflow/providers/amazon/aws/operators/emr.py      | 13 ++--
 airflow/providers/amazon/aws/sensors/emr.py        |  3 +-
 .../operators/emr.rst                              | 21 ++----
 tests/always/test_project_structure.py             |  2 +
 tests/providers/amazon/aws/hooks/test_emr.py       | 56 +++++++++++++++
 .../system/providers/amazon/aws}/example_emr.py    | 84 +++++++++++++---------
 7 files changed, 148 insertions(+), 60 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/emr.py 
b/airflow/providers/amazon/aws/hooks/emr.py
index 1365c1705d..5423dd1af8 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -124,6 +124,46 @@ class EmrHook(AwsBaseHook):
 
         return response
 
+    def add_job_flow_steps(
+        self,
+        job_flow_id: str,
+        steps: list[dict] | str | None = None,
+        wait_for_completion: bool = False,
+        waiter_delay: int = 5,
+        waiter_max_attempts: int = 100,
+    ) -> list[str]:
+        """
+        Add new steps to a running cluster.
+
+        :param job_flow_id: The id of the job flow to which the steps are being added
+        :param steps: A list of the steps to be executed by the job flow
+        :param wait_for_completion: If True, wait for the steps to be completed. Default is False
+        :param waiter_delay: Seconds between polls of a step's state while waiting. Default is 5
+        :param waiter_max_attempts: Maximum number of polls per step while waiting. Default is 100
+        :return: The ids of the steps added to the job flow
+        :raises AirflowException: If the request to add the steps does not return HTTP 200
+        """
+        response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps)
+
+        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+            raise AirflowException(f"Adding steps failed: {response}")
+
+        self.log.info("Steps %s added to JobFlow", response["StepIds"])
+        if wait_for_completion:
+            waiter = self.get_conn().get_waiter("step_complete")
+            # Steps are awaited one after another; each step gets its own polling budget
+            # of roughly waiter_delay * waiter_max_attempts seconds.
+            for step_id in response["StepIds"]:
+                waiter.wait(
+                    ClusterId=job_flow_id,
+                    StepId=step_id,
+                    WaiterConfig={
+                        "Delay": waiter_delay,
+                        "MaxAttempts": waiter_max_attempts,
+                    },
+                )
+        return response["StepIds"]
+
     def test_connection(self):
         """
         Return failed state for test Amazon Elastic MapReduce Connection (untestable).
diff --git a/airflow/providers/amazon/aws/operators/emr.py 
b/airflow/providers/amazon/aws/operators/emr.py
index 0935ed2422..63659e8188 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -50,6 +50,7 @@ class EmrAddStepsOperator(BaseOperator):
     :param aws_conn_id: aws connection to uses
     :param steps: boto3 style steps or reference to a steps file (must be 
'.json') to
         be added to the jobflow. (templated)
+    :param wait_for_completion: If True, the operator will wait for all the 
steps to be completed.
     :param do_xcom_push: if True, job_flow_id is pushed to XCom with key 
job_flow_id.
     """
 
@@ -67,6 +68,7 @@ class EmrAddStepsOperator(BaseOperator):
         cluster_states: list[str] | None = None,
         aws_conn_id: str = "aws_default",
         steps: list[dict] | str | None = None,
+        wait_for_completion: bool = False,
         **kwargs,
     ):
         if not (job_flow_id is None) ^ (job_flow_name is None):
@@ -79,12 +81,11 @@ class EmrAddStepsOperator(BaseOperator):
         self.job_flow_name = job_flow_name
         self.cluster_states = cluster_states
         self.steps = steps
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: Context) -> list[str]:
         emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
 
-        emr = emr_hook.get_conn()
-
         job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(
             str(self.job_flow_name), self.cluster_states
         )
@@ -111,13 +112,10 @@ class EmrAddStepsOperator(BaseOperator):
         if isinstance(steps, str):
             steps = ast.literal_eval(steps)
 
-        response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps)
-
-        if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
-            raise AirflowException(f"Adding steps failed: {response}")
-        else:
-            self.log.info("Steps %s added to JobFlow", response["StepIds"])
-            return response["StepIds"]
+        # Honour the operator's wait_for_completion flag instead of always blocking on the steps.
+        return emr_hook.add_job_flow_steps(
+            job_flow_id=job_flow_id, steps=steps, wait_for_completion=self.wait_for_completion
+        )
 
 
 class EmrEksCreateClusterOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/sensors/emr.py 
b/airflow/providers/amazon/aws/sensors/emr.py
index 0659554762..a3684fa249 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, 
EmrServerlessHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.sensors.base import BaseSensorOperator, poke_mode_only
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -378,6 +378,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
         return None
 
 
+@poke_mode_only
 class EmrStepSensor(EmrBaseSensor):
     """
     Asks for the state of the step until it reaches any of the target states.
diff --git a/docs/apache-airflow-providers-amazon/operators/emr.rst 
b/docs/apache-airflow-providers-amazon/operators/emr.rst
index 6bb16bd494..a9c434777d 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr.rst
@@ -53,7 +53,7 @@ JobFlow configuration
 
 To create a job flow on EMR, you need to specify the configuration for the EMR 
cluster:
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_emr.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
     :start-after: [START howto_operator_emr_steps_config]
     :end-before: [END howto_operator_emr_steps_config]
@@ -76,7 +76,7 @@ Create the Job Flow
 
 In the following code we are creating a new job flow using the configuration 
as explained above.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_emr.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_emr_create_job_flow]
@@ -90,7 +90,7 @@ Add Steps to an EMR job flow
 To add steps to an existing EMR Job flow you can use
 :class:`~airflow.providers.amazon.aws.operators.emr.EmrAddStepsOperator`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_emr.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_emr_add_steps]
@@ -104,7 +104,7 @@ Terminate an EMR job flow
 To terminate an EMR Job Flow you can use
 
:class:`~airflow.providers.amazon.aws.operators.emr.EmrTerminateJobFlowOperator`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_emr.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_emr_terminate_job_flow]
@@ -118,7 +118,7 @@ Modify Amazon EMR container
 To modify an existing EMR container you can use
 :class:`~airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_emr.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_emr_modify_cluster]
@@ -135,7 +135,7 @@ Wait on an Amazon EMR job flow state
 To monitor the state of an EMR job flow you can use
 :class:`~airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor`.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_emr.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_emr_job_flow]
@@ -146,15 +146,9 @@ To monitor the state of an EMR job flow you can use
 Wait on an Amazon EMR step state
 ================================
 
 To monitor the state of a step running an existing EMR Job flow you can use
 :class:`~airflow.providers.amazon.aws.sensors.emr.EmrStepSensor`.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr.py
-    :language: python
-    :dedent: 4
-    :start-after: [START howto_sensor_emr_step]
-    :end-before: [END howto_sensor_emr_step]
-
 Reference
 ---------
 
diff --git a/tests/always/test_project_structure.py 
b/tests/always/test_project_structure.py
index 2f215d2a6f..438ff2f215 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -400,6 +400,8 @@ class 
TestAmazonProviderProjectStructure(ExampleCoverageTest):
         
"airflow.providers.amazon.aws.transfers.exasol_to_s3.ExasolToS3Operator",
         # Glue Catalog sensor difficult to test
         
"airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor",
+        # EMR Step sensor difficult to test, see: 
https://github.com/apache/airflow/pull/27286
+        "airflow.providers.amazon.aws.sensors.emr.EmrStepSensor",
     }
 
 
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py 
b/tests/providers/amazon/aws/hooks/test_emr.py
index 8e91aba5e6..ea3cf73eeb 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import re
 from unittest import mock
 
 import boto3
@@ -43,6 +44,66 @@ class TestEmrHook:
 
         assert client.list_clusters()["Clusters"][0]["Id"] == cluster["JobFlowId"]
 
+    @mock_emr
+    @pytest.mark.parametrize("num_steps", [1, 2, 3, 4])
+    def test_add_job_flow_steps(self, num_steps):
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default", region_name="us-east-1")
+        cluster = hook.create_job_flow(
+            {"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": False}}
+        )
+        steps = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": f"step_{i}",
+            }
+            for i in range(num_steps)
+        ]
+        response = hook.add_job_flow_steps(job_flow_id=cluster["JobFlowId"], steps=steps)
+
+        assert len(response) == num_steps
+        for step_id in response:
+            assert re.fullmatch(r"s-[A-Z0-9]{13}", step_id)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn")
+    def test_add_job_flow_steps_wait_for_completion(self, mock_conn):
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default", region_name="us-east-1")
+        mock_conn.run_job_flow.return_value = {
+            "JobFlowId": "job_flow_id",
+            "ClusterArn": "cluster_arn",
+        }
+        mock_conn.add_job_flow_steps.return_value = {
+            "StepIds": [
+                "step_id",
+            ],
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+
+        hook.create_job_flow({"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": False}})
+
+        steps = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": "step_1",
+            }
+        ]
+
+        hook.add_job_flow_steps(job_flow_id="job_flow_id", steps=steps, wait_for_completion=True)
+
+        mock_conn.get_waiter.assert_called_once_with("step_complete")
+        mock_conn.get_waiter.return_value.wait.assert_called_once_with(
+            ClusterId="job_flow_id",
+            StepId="step_id",
+            WaiterConfig={"Delay": 5, "MaxAttempts": 100},
+        )
+
     @mock_emr
     def test_create_job_flow_extra_args(self):
         """
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr.py 
b/tests/system/providers/amazon/aws/example_emr.py
similarity index 63%
rename from airflow/providers/amazon/aws/example_dags/example_emr.py
rename to tests/system/providers/amazon/aws/example_emr.py
index 2f5e2d0d5e..a39ef92835 100644
--- a/airflow/providers/amazon/aws/example_dags/example_emr.py
+++ b/tests/system/providers/amazon/aws/example_emr.py
@@ -15,9 +15,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
-import os
 from datetime import datetime
 
 from airflow import DAG
@@ -28,10 +28,11 @@ from airflow.providers.amazon.aws.operators.emr import (
     EmrModifyClusterOperator,
     EmrTerminateJobFlowOperator,
 )
-from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor, 
EmrStepSensor
+from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, 
SystemTestContextBuilder
 
-JOB_FLOW_ROLE = os.getenv("EMR_JOB_FLOW_ROLE", "EMR_EC2_DefaultRole")
-SERVICE_ROLE = os.getenv("EMR_SERVICE_ROLE", "EMR_DefaultRole")
+DAG_ID = "example_emr"
 
 # [START howto_operator_emr_steps_config]
 SPARK_STEPS = [
@@ -63,64 +64,77 @@ JOB_FLOW_OVERRIDES = {
         "TerminationProtected": False,
     },
     "Steps": SPARK_STEPS,
-    "JobFlowRole": JOB_FLOW_ROLE,
-    "ServiceRole": SERVICE_ROLE,
+    "JobFlowRole": "EMR_EC2_DefaultRole",
+    "ServiceRole": "EMR_DefaultRole",
 }
 # [END howto_operator_emr_steps_config]
 
+sys_test_context_task = SystemTestContextBuilder().build()
+
 with DAG(
-    dag_id="example_emr",
+    dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
-    tags=["example"],
+    schedule="@once",
     catchup=False,
+    tags=["example"],
 ) as dag:
+    test_context = sys_test_context_task()
+    env_id = test_context[ENV_ID_KEY]
+
     # [START howto_operator_emr_create_job_flow]
-    job_flow_creator = EmrCreateJobFlowOperator(
+    create_job_flow = EmrCreateJobFlowOperator(
         task_id="create_job_flow",
         job_flow_overrides=JOB_FLOW_OVERRIDES,
     )
     # [END howto_operator_emr_create_job_flow]
 
-    job_flow_id = job_flow_creator.output
-
-    # [START howto_sensor_emr_job_flow]
-    job_sensor = EmrJobFlowSensor(task_id="check_job_flow", 
job_flow_id=job_flow_id)
-    # [END howto_sensor_emr_job_flow]
-
     # [START howto_operator_emr_modify_cluster]
-    cluster_modifier = EmrModifyClusterOperator(
-        task_id="modify_cluster", cluster_id=job_flow_id, 
step_concurrency_level=1
+    modify_cluster = EmrModifyClusterOperator(
+        task_id="modify_cluster", cluster_id=create_job_flow.output, 
step_concurrency_level=1
     )
     # [END howto_operator_emr_modify_cluster]
 
     # [START howto_operator_emr_add_steps]
-    step_adder = EmrAddStepsOperator(
+    add_steps = EmrAddStepsOperator(
         task_id="add_steps",
-        job_flow_id=job_flow_id,
+        job_flow_id=create_job_flow.output,
         steps=SPARK_STEPS,
+        wait_for_completion=True,
     )
     # [END howto_operator_emr_add_steps]
 
-    # [START howto_sensor_emr_step]
-    step_checker = EmrStepSensor(
-        task_id="watch_step",
-        job_flow_id=job_flow_id,
-        step_id="{{ task_instance.xcom_pull(task_ids='add_steps', 
key='return_value')[0] }}",
-    )
-    # [END howto_sensor_emr_step]
-
     # [START howto_operator_emr_terminate_job_flow]
-    cluster_remover = EmrTerminateJobFlowOperator(
+    remove_cluster = EmrTerminateJobFlowOperator(
         task_id="remove_cluster",
-        job_flow_id=job_flow_id,
+        job_flow_id=create_job_flow.output,
     )
     # [END howto_operator_emr_terminate_job_flow]
+    remove_cluster.trigger_rule = TriggerRule.ALL_DONE
+
+    # [START howto_sensor_emr_job_flow]
+    check_job_flow = EmrJobFlowSensor(task_id="check_job_flow", 
job_flow_id=create_job_flow.output)
+    # [END howto_sensor_emr_job_flow]
 
     chain(
-        job_flow_creator,
-        job_sensor,
-        cluster_modifier,
-        step_adder,
-        step_checker,
-        cluster_remover,
+        # TEST SETUP
+        test_context,
+        # TEST BODY
+        create_job_flow,
+        modify_cluster,
+        add_steps,
+        # TEST TEARDOWN
+        remove_cluster,
+        check_job_flow,
     )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to