This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9be22e4cc0 Add deferrable option to EmrTerminateJobFlowOperator
(#31646)
9be22e4cc0 is described below
commit 9be22e4cc09faba5db6432ccac8d2193114d95ee
Author: Syed Hussaain <[email protected]>
AuthorDate: Thu Jun 15 09:21:36 2023 -0700
Add deferrable option to EmrTerminateJobFlowOperator (#31646)
* Add job_flow_terminated to list of custom waiters
* Fix test_service_waiters so that order doesn't matter
* Add documentation to explain availability of deferrable mode for
EmrTerminateJobFlowOperator
Add EmrTerminateJobFlowTrigger to provider.yaml
---
airflow/providers/amazon/aws/operators/emr.py | 48 ++++++-
airflow/providers/amazon/aws/triggers/emr.py | 73 ++++++++++
airflow/providers/amazon/aws/waiters/emr.json | 19 +++
.../operators/emr/emr.rst | 4 +
tests/providers/amazon/aws/hooks/test_emr.py | 4 +-
.../aws/operators/test_emr_terminate_job_flow.py | 23 +++
tests/providers/amazon/aws/triggers/test_emr.py | 154 ++++++++++++++++++++-
7 files changed, 319 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/emr.py
b/airflow/providers/amazon/aws/operators/emr.py
index 15ef379aad..b8ca53226e 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -28,7 +28,11 @@ from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarni
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook,
EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink,
EmrLogsLink, get_log_uri
-from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger,
EmrCreateJobFlowTrigger
+from airflow.providers.amazon.aws.triggers.emr import (
+ EmrAddStepsTrigger,
+ EmrCreateJobFlowTrigger,
+ EmrTerminateJobFlowTrigger,
+)
from airflow.providers.amazon.aws.utils.waiter import waiter
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
@@ -842,6 +846,11 @@ class EmrTerminateJobFlowOperator(BaseOperator):
:param job_flow_id: id of the JobFlow to terminate. (templated)
:param aws_conn_id: aws connection to uses
+ :param waiter_delay: Time (in seconds) to wait between two consecutive
calls to check JobFlow status
+ :param waiter_max_attempts: The maximum number of times to poll for
JobFlow status.
+ :param deferrable: If True, the operator will wait asynchronously for the
cluster to terminate.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
"""
template_fields: Sequence[str] = ("job_flow_id",)
@@ -852,10 +861,22 @@ class EmrTerminateJobFlowOperator(BaseOperator):
EmrLogsLink(),
)
- def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default",
**kwargs):
+ def __init__(
+ self,
+ *,
+ job_flow_id: str,
+ aws_conn_id: str = "aws_default",
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 20,
+ deferrable: bool = False,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
def execute(self, context: Context) -> None:
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
@@ -883,7 +904,28 @@ class EmrTerminateJobFlowOperator(BaseOperator):
if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
raise AirflowException(f"JobFlow termination failed: {response}")
else:
- self.log.info("JobFlow with id %s terminated", self.job_flow_id)
+ self.log.info("Terminating JobFlow with id %s", self.job_flow_id)
+
+ if self.deferrable:
+ self.defer(
+ trigger=EmrTerminateJobFlowTrigger(
+ job_flow_id=self.job_flow_id,
+ poll_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the timeout
does not restart
+ # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay + 60),
+ )
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error terminating JobFlow: {event}")
+ else:
+ self.log.info("Jobflow terminated successfully.")
+ return
class EmrServerlessCreateApplicationOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py
b/airflow/providers/amazon/aws/triggers/emr.py
index 76ee47bc8b..1c3c8bb833 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -173,3 +173,76 @@ class EmrCreateJobFlowTrigger(BaseTrigger):
"job_flow_id": self.job_flow_id,
}
)
+
+
+class EmrTerminateJobFlowTrigger(BaseTrigger):
+ """
+ Trigger that terminates a running EMR Job Flow.
+ The trigger will asynchronously poll the boto3 API and wait for the
+ JobFlow to finish terminating.
+
+ :param job_flow_id: ID of the EMR Job Flow to terminate
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ job_flow_id: str,
+ poll_interval: int,
+ max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.job_flow_id = job_flow_id
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "job_flow_id": self.job_flow_id,
+ "poll_interval": str(self.poll_interval),
+ "max_attempts": str(self.max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ async def run(self):
+ self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
+ async with self.hook.async_conn as client:
+ attempt = 0
+ waiter = self.hook.get_waiter("job_flow_terminated",
deferrable=True, client=client)
+ while attempt < int(self.max_attempts):
+ attempt = attempt + 1
+ try:
+ await waiter.wait(
+ ClusterId=self.job_flow_id,
+ WaiterConfig=prune_dict(
+ {
+ "Delay": self.poll_interval,
+ "MaxAttempts": 1,
+ }
+ ),
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ raise AirflowException(f"JobFlow termination failed:
{error}")
+ self.log.info(
+ "Status of jobflow is %s - %s",
+ error.last_response["Cluster"]["Status"]["State"],
+
error.last_response["Cluster"]["Status"]["StateChangeReason"],
+ )
+ await asyncio.sleep(int(self.poll_interval))
+ if attempt >= int(self.max_attempts):
+ raise AirflowException(f"JobFlow termination failed - max attempts
reached: {self.max_attempts}")
+ else:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "JobFlow terminated successfully",
+ }
+ )
diff --git a/airflow/providers/amazon/aws/waiters/emr.json
b/airflow/providers/amazon/aws/waiters/emr.json
index 13bc5857e3..d27cd08e00 100644
--- a/airflow/providers/amazon/aws/waiters/emr.json
+++ b/airflow/providers/amazon/aws/waiters/emr.json
@@ -75,6 +75,25 @@
"state": "failure"
}
]
+ },
+ "job_flow_terminated": {
+ "operation": "DescribeCluster",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "Cluster.Status.State",
+ "expected": "TERMINATED",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "Cluster.Status.State",
+ "expected": "TERMINATED_WITH_ERRORS",
+ "state": "failure"
+ }
+ ]
}
}
}
diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
index d26a427a64..8a2255ddbf 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst
@@ -111,6 +111,10 @@ Terminate an EMR job flow
To terminate an EMR Job Flow you can use
:class:`~airflow.providers.amazon.aws.operators.emr.EmrTerminateJobFlowOperator`.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter.
+Using ``deferrable`` mode will release worker slots and lead to more efficient
utilization of
+resources within the Airflow cluster. However, this mode will need the Airflow
triggerer to be
+available in your deployment.
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
:language: python
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py
b/tests/providers/amazon/aws/hooks/test_emr.py
index 918710e8ce..8ff9bc5a36 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -32,9 +32,9 @@ class TestEmrHook:
def test_service_waiters(self):
hook = EmrHook(aws_conn_id=None)
official_waiters = hook.conn.waiter_names
- custom_waiters = ["job_flow_waiting", "notebook_running",
"notebook_stopped"]
+ custom_waiters = ["job_flow_waiting", "job_flow_terminated",
"notebook_running", "notebook_stopped"]
- assert hook.list_waiters() == [*official_waiters, *custom_waiters]
+ assert sorted(hook.list_waiters()) == sorted([*official_waiters,
*custom_waiters])
@mock_emr
def test_get_conn_returns_a_boto3_connection(self):
diff --git
a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
index 9443402e4e..509dcfee0c 100644
--- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
@@ -19,8 +19,12 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
+import pytest
+
+from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import
EmrTerminateJobFlowOperator
+from airflow.providers.amazon.aws.triggers.emr import
EmrTerminateJobFlowTrigger
TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}}
@@ -48,3 +52,22 @@ class TestEmrTerminateJobFlowOperator:
)
operator.execute(MagicMock())
+
+ @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
+ def test_create_job_flow_deferrable(self, _):
+ with patch("boto3.session.Session", self.boto3_session_mock), patch(
+ "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
+ ) as mock_isinstance:
+ mock_isinstance.return_value = True
+ operator = EmrTerminateJobFlowOperator(
+ task_id="test_task",
+ job_flow_id="j-8989898989",
+ aws_conn_id="aws_default",
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(MagicMock())
+
+ assert isinstance(
+ exc.value.trigger, EmrTerminateJobFlowTrigger
+ ), "Trigger is not a EmrTerminateJobFlowTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py
b/tests/providers/amazon/aws/triggers/test_emr.py
index c749c4ee9a..5d599801a2 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -24,7 +24,7 @@ from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
-from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
+from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger,
EmrTerminateJobFlowTrigger
from airflow.triggers.base import TriggerEvent
TEST_JOB_FLOW_ID = "test-job-flow-id"
@@ -198,3 +198,155 @@ class TestEmrCreateJobFlowTrigger:
assert str(exc.value) == f"JobFlow creation failed: {error_failed}"
assert mock_get_waiter().wait.call_count == 3
+
+
+class TestEmrTerminateJobFlowTrigger:
+ def test_emr_terminate_job_flow_trigger_serialize(self):
+ emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPTS,
+ )
+ class_path, args = emr_terminate_job_flow_trigger.serialize()
+ assert class_path ==
"airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger"
+ assert args["job_flow_id"] == TEST_JOB_FLOW_ID
+ assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+ assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
+ assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
+
+ @pytest.mark.asyncio
+ @mock.patch.object(EmrHook, "get_waiter")
+ @mock.patch.object(EmrHook, "async_conn")
+ async def test_emr_terminate_job_flow_trigger_run(self, mock_async_conn,
mock_get_waiter):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPTS,
+ )
+
+ generator = emr_terminate_job_flow_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent(
+ {
+ "status": "success",
+ "message": "JobFlow terminated successfully",
+ }
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(EmrHook, "get_waiter")
+ @mock.patch.object(EmrHook, "async_conn")
+ async def test_emr_terminate_job_flow_trigger_run_multiple_attempts(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={
+ "Cluster": {"Status": {"State": "TERMINATING",
"StateChangeReason": "test-reason"}}
+ },
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPTS,
+ )
+
+ generator = emr_terminate_job_flow_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent(
+ {
+ "status": "success",
+ "message": "JobFlow terminated successfully",
+ }
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(EmrHook, "get_waiter")
+ @mock.patch.object(EmrHook, "async_conn")
+ async def test_emr_terminate_job_flow_trigger_run_attempts_exceeded(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={
+ "Cluster": {"Status": {"State": "TERMINATING",
"StateChangeReason": "test-reason"}}
+ },
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=2,
+ )
+ with pytest.raises(AirflowException) as exc:
+ generator = emr_terminate_job_flow_trigger.run()
+ await generator.asend(None)
+
+ assert str(exc.value) == "JobFlow termination failed - max attempts
reached: 2"
+ assert mock_get_waiter().wait.call_count == 2
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(EmrHook, "get_waiter")
+ @mock.patch.object(EmrHook, "async_conn")
+ async def test_emr_terminate_job_flow_trigger_run_attempts_failed(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+ error_starting = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={
+ "Cluster": {"Status": {"State": "TERMINATING",
"StateChangeReason": "test-reason"}}
+ },
+ )
+ error_failed = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={
+ "Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS",
"StateChangeReason": "test-reason"}}
+ },
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(
+ side_effect=[error_starting, error_starting, error_failed]
+ )
+ mock_sleep.return_value = True
+
+ emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPTS,
+ )
+ with pytest.raises(AirflowException) as exc:
+ generator = emr_terminate_job_flow_trigger.run()
+ await generator.asend(None)
+
+ assert str(exc.value) == f"JobFlow termination failed: {error_failed}"
+ assert mock_get_waiter().wait.call_count == 3