This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9be22e4cc0 Add deferrable option to EmrTerminateJobFlowOperator 
(#31646)
9be22e4cc0 is described below

commit 9be22e4cc09faba5db6432ccac8d2193114d95ee
Author: Syed Hussaain <[email protected]>
AuthorDate: Thu Jun 15 09:21:36 2023 -0700

    Add deferrable option to EmrTerminateJobFlowOperator (#31646)
    
    * Add job_flow_terminated to list of custom waiters
    
    * Fix test_service_waiters so that order doesn't matter
    
    * Add documentation to explain availability of deferrable mode for 
EmrTerminateJobFlowOperator
    Add EmrTerminateJobFlowTrigger to provider.yaml
---
 airflow/providers/amazon/aws/operators/emr.py      |  48 ++++++-
 airflow/providers/amazon/aws/triggers/emr.py       |  73 ++++++++++
 airflow/providers/amazon/aws/waiters/emr.json      |  19 +++
 .../operators/emr/emr.rst                          |   4 +
 tests/providers/amazon/aws/hooks/test_emr.py       |   4 +-
 .../aws/operators/test_emr_terminate_job_flow.py   |  23 +++
 tests/providers/amazon/aws/triggers/test_emr.py    | 154 ++++++++++++++++++++-
 7 files changed, 319 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/emr.py 
b/airflow/providers/amazon/aws/operators/emr.py
index 15ef379aad..b8ca53226e 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -28,7 +28,11 @@ from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarni
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, 
EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, 
EmrLogsLink, get_log_uri
-from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger, 
EmrCreateJobFlowTrigger
+from airflow.providers.amazon.aws.triggers.emr import (
+    EmrAddStepsTrigger,
+    EmrCreateJobFlowTrigger,
+    EmrTerminateJobFlowTrigger,
+)
 from airflow.providers.amazon.aws.utils.waiter import waiter
 from airflow.utils.helpers import exactly_one, prune_dict
 from airflow.utils.types import NOTSET, ArgNotSet
@@ -842,6 +846,11 @@ class EmrTerminateJobFlowOperator(BaseOperator):
 
     :param job_flow_id: id of the JobFlow to terminate. (templated)
     :param aws_conn_id: aws connection to use
+    :param waiter_delay: Time (in seconds) to wait between two consecutive 
calls to check JobFlow status
+    :param waiter_max_attempts: The maximum number of times to poll for 
JobFlow status.
+    :param deferrable: If True, the operator will wait asynchronously for the 
JobFlow termination to complete.
+        This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
+        (default: False)
     """
 
     template_fields: Sequence[str] = ("job_flow_id",)
@@ -852,10 +861,22 @@ class EmrTerminateJobFlowOperator(BaseOperator):
         EmrLogsLink(),
     )
 
-    def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", 
**kwargs):
+    def __init__(
+        self,
+        *,
+        job_flow_id: str,
+        aws_conn_id: str = "aws_default",
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = False,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.job_flow_id = job_flow_id
         self.aws_conn_id = aws_conn_id
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> None:
         emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
@@ -883,7 +904,28 @@ class EmrTerminateJobFlowOperator(BaseOperator):
         if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
             raise AirflowException(f"JobFlow termination failed: {response}")
         else:
-            self.log.info("JobFlow with id %s terminated", self.job_flow_id)
+            self.log.info("Terminating JobFlow with id %s", self.job_flow_id)
+
+        if self.deferrable:
+            self.defer(
+                trigger=EmrTerminateJobFlowTrigger(
+                    job_flow_id=self.job_flow_id,
+                    poll_interval=self.waiter_delay,
+                    max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout 
does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully 
(i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay + 60),
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error terminating JobFlow: {event}")
+        else:
+            self.log.info("Jobflow terminated successfully.")
+        return
 
 
 class EmrServerlessCreateApplicationOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py 
b/airflow/providers/amazon/aws/triggers/emr.py
index 76ee47bc8b..1c3c8bb833 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -173,3 +173,76 @@ class EmrCreateJobFlowTrigger(BaseTrigger):
                     "job_flow_id": self.job_flow_id,
                 }
             )
+
+
+class EmrTerminateJobFlowTrigger(BaseTrigger):
+    """
+    Trigger that waits for an EMR Job Flow to finish terminating.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    JobFlow to reach a terminal state.
+
+    :param job_flow_id: ID of the EMR Job Flow to terminate
+    :param poll_interval: The amount of time in seconds to wait between 
attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        job_flow_id: str,
+        poll_interval: int,
+        max_attempts: int,
+        aws_conn_id: str,
+    ):
+        self.job_flow_id = job_flow_id
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "job_flow_id": self.job_flow_id,
+                "poll_interval": str(self.poll_interval),
+                "max_attempts": str(self.max_attempts),
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    async def run(self):
+        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
+        async with self.hook.async_conn as client:
+            attempt = 0
+            waiter = self.hook.get_waiter("job_flow_terminated", 
deferrable=True, client=client)
+            while attempt < int(self.max_attempts):
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        ClusterId=self.job_flow_id,
+                        WaiterConfig=prune_dict(
+                            {
+                                "Delay": self.poll_interval,
+                                "MaxAttempts": 1,
+                            }
+                        ),
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        raise AirflowException(f"JobFlow termination failed: 
{error}")
+                    self.log.info(
+                        "Status of jobflow is %s - %s",
+                        error.last_response["Cluster"]["Status"]["State"],
+                        
error.last_response["Cluster"]["Status"]["StateChangeReason"],
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+        if attempt >= int(self.max_attempts):
+            raise AirflowException(f"JobFlow termination failed - max attempts 
reached: {self.max_attempts}")
+        else:
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "message": "JobFlow terminated successfully",
+                }
+            )
diff --git a/airflow/providers/amazon/aws/waiters/emr.json 
b/airflow/providers/amazon/aws/waiters/emr.json
index 13bc5857e3..d27cd08e00 100644
--- a/airflow/providers/amazon/aws/waiters/emr.json
+++ b/airflow/providers/amazon/aws/waiters/emr.json
@@ -75,6 +75,25 @@
                     "state": "failure"
                 }
             ]
+        },
+        "job_flow_terminated": {
+            "operation": "DescribeCluster",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "Cluster.Status.State",
+                    "expected": "TERMINATED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "Cluster.Status.State",
+                    "expected": "TERMINATED_WITH_ERRORS",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst 
b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
index d26a427a64..8a2255ddbf 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
@@ -111,6 +111,10 @@ Terminate an EMR job flow
 
 To terminate an EMR Job Flow you can use
 
:class:`~airflow.providers.amazon.aws.operators.emr.EmrTerminateJobFlowOperator`.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as 
a parameter.
+Using ``deferrable`` mode will release worker slots and lead to more efficient
+utilization of resources within the Airflow cluster. However, this mode requires
+the Airflow triggerer to be available in your deployment.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
     :language: python
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py 
b/tests/providers/amazon/aws/hooks/test_emr.py
index 918710e8ce..8ff9bc5a36 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -32,9 +32,9 @@ class TestEmrHook:
     def test_service_waiters(self):
         hook = EmrHook(aws_conn_id=None)
         official_waiters = hook.conn.waiter_names
-        custom_waiters = ["job_flow_waiting", "notebook_running", 
"notebook_stopped"]
+        custom_waiters = ["job_flow_waiting", "job_flow_terminated", 
"notebook_running", "notebook_stopped"]
 
-        assert hook.list_waiters() == [*official_waiters, *custom_waiters]
+        assert sorted(hook.list_waiters()) == sorted([*official_waiters, 
*custom_waiters])
 
     @mock_emr
     def test_get_conn_returns_a_boto3_connection(self):
diff --git 
a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py 
b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
index 9443402e4e..509dcfee0c 100644
--- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
@@ -19,8 +19,12 @@ from __future__ import annotations
 
 from unittest.mock import MagicMock, patch
 
+import pytest
+
+from airflow.exceptions import TaskDeferred
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.operators.emr import 
EmrTerminateJobFlowOperator
+from airflow.providers.amazon.aws.triggers.emr import 
EmrTerminateJobFlowTrigger
 
 TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}}
 
@@ -48,3 +52,22 @@ class TestEmrTerminateJobFlowOperator:
             )
 
             operator.execute(MagicMock())
+
+    @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
+    def test_create_job_flow_deferrable(self, _):
+        with patch("boto3.session.Session", self.boto3_session_mock), patch(
+            "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
+        ) as mock_isinstance:
+            mock_isinstance.return_value = True
+            operator = EmrTerminateJobFlowOperator(
+                task_id="test_task",
+                job_flow_id="j-8989898989",
+                aws_conn_id="aws_default",
+                deferrable=True,
+            )
+            with pytest.raises(TaskDeferred) as exc:
+                operator.execute(MagicMock())
+
+        assert isinstance(
+            exc.value.trigger, EmrTerminateJobFlowTrigger
+        ), "Trigger is not a EmrTerminateJobFlowTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py 
b/tests/providers/amazon/aws/triggers/test_emr.py
index c749c4ee9a..5d599801a2 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -24,7 +24,7 @@ from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrHook
-from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
+from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger, 
EmrTerminateJobFlowTrigger
 from airflow.triggers.base import TriggerEvent
 
 TEST_JOB_FLOW_ID = "test-job-flow-id"
@@ -198,3 +198,155 @@ class TestEmrCreateJobFlowTrigger:
 
         assert str(exc.value) == f"JobFlow creation failed: {error_failed}"
         assert mock_get_waiter().wait.call_count == 3
+
+
+class TestEmrTerminateJobFlowTrigger:
+    def test_emr_terminate_job_flow_trigger_serialize(self):
+        emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPTS,
+        )
+        class_path, args = emr_terminate_job_flow_trigger.serialize()
+        assert class_path == 
"airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger"
+        assert args["job_flow_id"] == TEST_JOB_FLOW_ID
+        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
+        assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
+
+    @pytest.mark.asyncio
+    @mock.patch.object(EmrHook, "get_waiter")
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_terminate_job_flow_trigger_run(self, mock_async_conn, 
mock_get_waiter):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPTS,
+        )
+
+        generator = emr_terminate_job_flow_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent(
+            {
+                "status": "success",
+                "message": "JobFlow terminated successfully",
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EmrHook, "get_waiter")
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_terminate_job_flow_trigger_run_multiple_attempts(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={
+                "Cluster": {"Status": {"State": "TERMINATING", 
"StateChangeReason": "test-reason"}}
+            },
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True])
+        mock_sleep.return_value = True
+
+        emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPTS,
+        )
+
+        generator = emr_terminate_job_flow_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent(
+            {
+                "status": "success",
+                "message": "JobFlow terminated successfully",
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EmrHook, "get_waiter")
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_terminate_job_flow_trigger_run_attempts_exceeded(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={
+                "Cluster": {"Status": {"State": "TERMINATING", 
"StateChangeReason": "test-reason"}}
+            },
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True])
+        mock_sleep.return_value = True
+
+        emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=2,
+        )
+        with pytest.raises(AirflowException) as exc:
+            generator = emr_terminate_job_flow_trigger.run()
+            await generator.asend(None)
+
+        assert str(exc.value) == "JobFlow termination failed - max attempts 
reached: 2"
+        assert mock_get_waiter().wait.call_count == 2
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EmrHook, "get_waiter")
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_terminate_job_flow_trigger_run_attempts_failed(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error_starting = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={
+                "Cluster": {"Status": {"State": "TERMINATING", 
"StateChangeReason": "test-reason"}}
+            },
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state:",
+            last_response={
+                "Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS", 
"StateChangeReason": "test-reason"}}
+            },
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(
+            side_effect=[error_starting, error_starting, error_failed]
+        )
+        mock_sleep.return_value = True
+
+        emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPTS,
+        )
+        with pytest.raises(AirflowException) as exc:
+            generator = emr_terminate_job_flow_trigger.run()
+            await generator.asend(None)
+
+        assert str(exc.value) == f"JobFlow termination failed: {error_failed}"
+        assert mock_get_waiter().wait.call_count == 3

Reply via email to