This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1706f05858 EMR serverless Create/Start/Stop/Delete Application
deferrable mode (#32513)
1706f05858 is described below
commit 1706f058582a0668555eee874bcf4ccdc248acbb
Author: Syed Hussain <[email protected]>
AuthorDate: Mon Jul 24 10:30:37 2023 -0700
EMR serverless Create/Start/Stop/Delete Application deferrable mode (#32513)
* Minor code refactoring
* Add type annotations
* update doc string about default value of deferrable
---
airflow/providers/amazon/aws/hooks/emr.py | 26 ++--
airflow/providers/amazon/aws/operators/emr.py | 172 +++++++++++++++++++--
airflow/providers/amazon/aws/triggers/emr.py | 144 ++++++++++++++++-
.../operators/emr/emr_serverless.rst | 6 +
.../amazon/aws/operators/test_emr_serverless.py | 73 ++++++++-
.../amazon/aws/triggers/test_emr_serverless.py | 29 ++++
6 files changed, 417 insertions(+), 33 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/emr.py
b/airflow/providers/amazon/aws/hooks/emr.py
index 2fb35a34a6..dea61c0858 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -256,9 +256,14 @@ class EmrServerlessHook(AwsBaseHook):
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)
- def cancel_running_jobs(self, application_id: str, waiter_config: dict =
{}):
+ def cancel_running_jobs(
+ self, application_id: str, waiter_config: dict | None = None,
wait_for_completion: bool = True
+ ) -> int:
"""
- List all jobs in an intermediate state, cancel them, then wait for
those jobs to reach terminal state.
+ Cancel jobs in an intermediate state, and return the number of
cancelled jobs.
+
+ If wait_for_completion is True, then the method will wait until all
jobs are
+ cancelled before returning.
Note: if new jobs are triggered while this operation is ongoing,
it's going to time out and return an error.
@@ -284,13 +289,16 @@ class EmrServerlessHook(AwsBaseHook):
)
for job_id in job_ids:
self.conn.cancel_job_run(applicationId=application_id,
jobRunId=job_id)
- if count > 0:
- self.log.info("now waiting for the %s cancelled job(s) to
terminate", count)
- self.get_waiter("no_job_running").wait(
- applicationId=application_id,
-
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
- WaiterConfig=waiter_config,
- )
+ if wait_for_completion:
+ if count > 0:
+ self.log.info("now waiting for the %s cancelled job(s) to
terminate", count)
+ self.get_waiter("no_job_running").wait(
+ applicationId=application_id,
+
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
+ WaiterConfig=waiter_config or {},
+ )
+
+ return count
class EmrContainerHook(AwsBaseHook):
diff --git a/airflow/providers/amazon/aws/operators/emr.py
b/airflow/providers/amazon/aws/operators/emr.py
index c01dbbbe91..f75b12327f 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -33,8 +33,12 @@ from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
+ EmrServerlessCancelJobsTrigger,
+ EmrServerlessCreateApplicationTrigger,
+ EmrServerlessDeleteApplicationTrigger,
EmrServerlessStartApplicationTrigger,
EmrServerlessStartJobTrigger,
+ EmrServerlessStopApplicationTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils.waiter import waiter
@@ -974,7 +978,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
:param release_label: The EMR release version associated with the
application.
:param job_type: The type of application you want to start, such as Spark
or Hive.
:param wait_for_completion: If true, wait for the Application to start
before returning. Default to True.
- If set to False, ``waiter_countdown`` and
``waiter_check_interval_seconds`` will only be applied when
+ If set to False, ``waiter_max_attempts`` and ``waiter_delay`` will
only be applied when
waiting for the application to be in the ``CREATED`` state.
:param client_request_token: The client idempotency token of the
application to create.
Its value must be unique for each request.
@@ -987,6 +991,9 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
:waiter_max_attempts: Number of times the waiter should poll the
application to check the state.
If not set, the waiter will use its default value.
:param waiter_delay: Number of seconds between polling the state of the
application.
+ :param deferrable: If True, the operator will wait asynchronously for
application to be created.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False, but can be overridden in config file by setting
default_deferrable to True)
"""
def __init__(
@@ -1001,6 +1008,7 @@ class
EmrServerlessCreateApplicationOperator(BaseOperator):
waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
@@ -1032,6 +1040,7 @@ class
EmrServerlessCreateApplicationOperator(BaseOperator):
self.config = config or {}
self.waiter_max_attempts = int(waiter_max_attempts) # type:
ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
+ self.deferrable = deferrable
super().__init__(**kwargs)
self.client_request_token = client_request_token or str(uuid4())
@@ -1054,8 +1063,19 @@ class
EmrServerlessCreateApplicationOperator(BaseOperator):
raise AirflowException(f"Application Creation failed: {response}")
self.log.info("EMR serverless application created: %s", application_id)
- waiter = self.hook.get_waiter("serverless_app_created")
+ if self.deferrable:
+ self.defer(
+ trigger=EmrServerlessCreateApplicationTrigger(
+ application_id=application_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ ),
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
+ method_name="start_application_deferred",
+ )
+ waiter = self.hook.get_waiter("serverless_app_created")
wait(
waiter=waiter,
waiter_delay=self.waiter_delay,
@@ -1081,6 +1101,32 @@ class
EmrServerlessCreateApplicationOperator(BaseOperator):
)
return application_id
+    def start_application_deferred(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        """
+        Resume once the application is ``CREATED``: start it, then defer until it is started.
+
+        :param context: Airflow task context (unused here).
+        :param event: Payload from ``EmrServerlessCreateApplicationTrigger``; read for its
+            ``status`` and ``application_id`` keys.
+        :raises AirflowException: if no event was received or the trigger did not report success.
+        """
+        if event is None:
+            self.log.error("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        if event["status"] != "success":
+            # ``.get`` so a failure payload without the ID still raises this error, not a KeyError.
+            raise AirflowException(f"Application {event.get('application_id')} failed to create")
+
+        application_id = event["application_id"]
+        self.log.info("Starting application %s", application_id)
+        self.hook.conn.start_application(applicationId=application_id)
+        self.defer(
+            trigger=EmrServerlessStartApplicationTrigger(
+                application_id=application_id,
+                aws_conn_id=self.aws_conn_id,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
+            ),
+            timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        """
+        Resume once the application is started and return its ID.
+
+        Mirrors the non-deferrable ``execute`` path, which also returns the application ID.
+
+        :param context: Airflow task context (unused here).
+        :param event: Payload from ``EmrServerlessStartApplicationTrigger``.
+        :return: the ID of the started application.
+        :raises AirflowException: if no event was received or the trigger did not report success.
+        """
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Trigger error: Application failed to start, event is {event}")
+
+        self.log.info("Application %s started", event["application_id"])
+        return event["application_id"]
+
class EmrServerlessStartJobOperator(BaseOperator):
"""
@@ -1312,14 +1358,21 @@ class
EmrServerlessStopApplicationOperator(BaseOperator):
:param application_id: ID of the EMR Serverless application to stop.
:param wait_for_completion: If true, wait for the Application to stop
before returning. Default to True
:param aws_conn_id: AWS connection to use
- :param waiter_countdown: Total amount of time, in seconds, the operator
will wait for
+ :param waiter_countdown: (deprecated) Total amount of time, in seconds,
the operator will wait for
the application be stopped. Defaults to 5 minutes.
- :param waiter_check_interval_seconds: Number of seconds between polling
the state of the application.
- Defaults to 30 seconds.
+ :param waiter_check_interval_seconds: (deprecated) Number of seconds
between polling the state of the
+ application. Defaults to 60 seconds.
:param force_stop: If set to True, any job for that app that is not in a
terminal state will be cancelled.
Otherwise, trying to stop an app with running jobs will return an
error.
If you want to wait for the jobs to finish gracefully, use
:class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`
+ :param waiter_max_attempts: Number of times the waiter should poll the
application to check the state.
+ Default is 25.
+ :param waiter_delay: Number of seconds between polling the state of the
application.
+ Default is 60 seconds.
+ :param deferrable: If True, the operator will wait asynchronously for the
application to stop.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False, but can be overridden in config file by setting
default_deferrable to True)
"""
template_fields: Sequence[str] = ("application_id",)
@@ -1334,6 +1387,7 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
@@ -1359,10 +1413,11 @@ class
EmrServerlessStopApplicationOperator(BaseOperator):
)
self.aws_conn_id = aws_conn_id
self.application_id = application_id
- self.wait_for_completion = wait_for_completion
+ self.wait_for_completion = False if deferrable else wait_for_completion
self.waiter_max_attempts = int(waiter_max_attempts) # type:
ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.force_stop = force_stop
+ self.deferrable = deferrable
super().__init__(**kwargs)
@cached_property
@@ -1374,16 +1429,46 @@ class
EmrServerlessStopApplicationOperator(BaseOperator):
self.log.info("Stopping application: %s", self.application_id)
if self.force_stop:
- self.hook.cancel_running_jobs(
- self.application_id,
- waiter_config={
- "Delay": self.waiter_delay,
- "MaxAttempts": self.waiter_max_attempts,
- },
+ count = self.hook.cancel_running_jobs(
+ application_id=self.application_id,
+ wait_for_completion=False,
)
+ if count > 0:
+ self.log.info("now waiting for the %s cancelled job(s) to
terminate", count)
+ if self.deferrable:
+ self.defer(
+ trigger=EmrServerlessCancelJobsTrigger(
+ application_id=self.application_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ ),
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
+ method_name="stop_application",
+ )
+ self.hook.get_waiter("no_job_running").wait(
+ applicationId=self.application_id,
+
states=list(self.hook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
+ WaiterConfig={
+ "Delay": self.waiter_delay,
+ "MaxAttempts": self.waiter_max_attempts,
+ },
+ )
+ else:
+ self.log.info("no running jobs found with application ID %s",
self.application_id)
self.hook.conn.stop_application(applicationId=self.application_id)
-
+ if self.deferrable:
+ self.defer(
+ trigger=EmrServerlessStopApplicationTrigger(
+ application_id=self.application_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ ),
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
+ method_name="execute_complete",
+ )
if self.wait_for_completion:
waiter = self.hook.get_waiter("serverless_app_stopped")
wait(
@@ -1397,6 +1482,30 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
)
self.log.info("EMR serverless application %s stopped
successfully", self.application_id)
+    def stop_application(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        """
+        Resume once the cancelled jobs have terminated: stop the application, then defer until stopped.
+
+        :param context: Airflow task context (unused here).
+        :param event: Payload from ``EmrServerlessCancelJobsTrigger``; read for its ``status`` key.
+        :raises AirflowException: if no event was received or the trigger did not report success.
+        """
+        if event is None:
+            self.log.error("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        if event["status"] != "success":
+            # A non-success event must fail the task; otherwise it would finish "successfully"
+            # without the application ever being stopped.
+            raise AirflowException(
+                f"Error while waiting for the jobs of application {self.application_id} to be cancelled: "
+                f"{event}"
+            )
+
+        self.hook.conn.stop_application(applicationId=self.application_id)
+        self.defer(
+            trigger=EmrServerlessStopApplicationTrigger(
+                application_id=self.application_id,
+                aws_conn_id=self.aws_conn_id,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
+            ),
+            timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        """
+        Resume once the application has stopped and log the outcome.
+
+        :param context: Airflow task context (unused here).
+        :param event: Payload from ``EmrServerlessStopApplicationTrigger``; read for its ``status`` key.
+        :raises AirflowException: if no event was received or the trigger did not report success.
+        """
+        if event is None:
+            self.log.error("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        if event["status"] != "success":
+            raise AirflowException(f"Application {self.application_id} failed to stop: {event}")
+        self.log.info("EMR serverless application %s stopped successfully", self.application_id)
+
class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
"""
@@ -1410,10 +1519,17 @@ class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
:param wait_for_completion: If true, wait for the Application to be
deleted before returning.
Defaults to True. Note that this operator will always wait for the
application to be STOPPED first.
:param aws_conn_id: AWS connection to use
- :param waiter_countdown: Total amount of time, in seconds, the operator
will wait for each step of first,
- the application to be stopped, and then deleted. Defaults to 25
minutes.
- :param waiter_check_interval_seconds: Number of seconds between polling
the state of the application.
+ :param waiter_countdown: (deprecated) Total amount of time, in seconds,
the operator will wait for each
+ step: first for the application to be stopped, and then for it to be deleted.
Defaults to 25 minutes.
+ :param waiter_check_interval_seconds: (deprecated) Number of seconds
between polling the state
+ of the application. Defaults to 60 seconds.
+ :param waiter_max_attempts: Number of times the waiter should poll the
application to check the state.
+ Defaults to 25.
+ :param waiter_delay: Number of seconds between polling the state of the
application.
Defaults to 60 seconds.
+ :param deferrable: If True, the operator will wait asynchronously for
application to be deleted.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False, but can be overridden in config file by setting
default_deferrable to True)
:param force_stop: If set to True, any job for that app that is not in a
terminal state will be cancelled.
Otherwise, trying to delete an app with running jobs will return an
error.
If you want to wait for the jobs to finish gracefully, use
@@ -1432,6 +1548,7 @@ class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
@@ -1467,6 +1584,8 @@ class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
force_stop=force_stop,
**kwargs,
)
+ self.deferrable = deferrable
+ self.wait_for_delete_completion = False if deferrable else
wait_for_completion
def execute(self, context: Context) -> None:
# super stops the app (or makes sure it's already stopped)
@@ -1478,7 +1597,19 @@ class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Application deletion failed: {response}")
- if self.wait_for_delete_completion:
+ if self.deferrable:
+ self.defer(
+ trigger=EmrServerlessDeleteApplicationTrigger(
+ application_id=self.application_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ ),
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
+ method_name="execute_complete",
+ )
+
+ elif self.wait_for_delete_completion:
waiter = self.hook.get_waiter("serverless_app_terminated")
wait(
@@ -1492,3 +1623,10 @@ class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
)
self.log.info("EMR serverless application deleted")
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        """
+        Resume after a trigger reports completion and log that the application was deleted.
+
+        :param context: Airflow task context (unused here).
+        :param event: Trigger payload; read for its ``status`` key.
+        :raises AirflowException: if no event was received or the trigger did not report success.
+
+        NOTE(review): in deferrable mode the inherited stop phase appears to defer with
+        ``method_name="execute_complete"``, which resolves to this override. This callback may
+        therefore run right after the application is *stopped*, before ``delete_application``
+        is ever called. Confirm that deferrable deletion actually deletes the application.
+        """
+        if event is None:
+            self.log.error("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        if event["status"] != "success":
+            raise AirflowException(f"Application {self.application_id} failed to be deleted: {event}")
+        self.log.info("EMR serverless application %s deleted successfully", self.application_id)
diff --git a/airflow/providers/amazon/aws/triggers/emr.py
b/airflow/providers/amazon/aws/triggers/emr.py
index 471a8a747a..d7d9844af3 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -285,6 +285,41 @@ class EmrStepSensorTrigger(AwsBaseWaiterTrigger):
return EmrHook(self.aws_conn_id)
+class EmrServerlessCreateApplicationTrigger(AwsBaseWaiterTrigger):
+ """
+ Poll an Emr Serverless application and wait for it to be created.
+
+ Waits on the ``serverless_app_created`` waiter. The completion event carries the
+ application ID under the ``application_id`` key, which the create operator's
+ resume callback reads.
+
+ :param application_id: The ID of the application being polled.
+ :param waiter_delay: polling period in seconds to check for the status
+ :param waiter_max_attempts: The maximum number of attempts to be made
+ :param aws_conn_id: Reference to AWS connection id
+ """
+
+ def __init__(
+ self,
+ application_id: str,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
+ aws_conn_id: str = "aws_default",
+ ) -> None:
+ super().__init__(
+ serialized_fields={"application_id": application_id},
+ waiter_name="serverless_app_created",
+ waiter_args={"applicationId": application_id},
+ failure_message="Application creation failed",
+ status_message="Application status is",
+ status_queries=["application.state", "application.stateDetails"],
+ return_key="application_id",
+ return_value=application_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ # Build the hook on demand from the (serialized) connection id.
+ return EmrServerlessHook(self.aws_conn_id)
+
+
class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
"""
Poll an Emr Serverless application and wait for it to be started.
@@ -301,7 +336,7 @@ class
EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
aws_conn_id: str = "aws_default",
- ):
+ ) -> None:
super().__init__(
serialized_fields={"application_id": application_id},
waiter_name="serverless_app_started",
@@ -320,6 +355,41 @@ class
EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
return EmrServerlessHook(self.aws_conn_id)
+class EmrServerlessStopApplicationTrigger(AwsBaseWaiterTrigger):
+    """
+    Poll an Emr Serverless application and wait for it to be stopped.
+
+    Waits on the ``serverless_app_stopped`` waiter. The completion event carries the
+    application ID under the ``application_id`` key.
+
+    :param application_id: The ID of the application being polled.
+    :param waiter_delay: polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made
+    :param aws_conn_id: Reference to AWS connection id.
+    """
+
+    def __init__(
+        self,
+        application_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"application_id": application_id},
+            waiter_name="serverless_app_stopped",
+            waiter_args={"applicationId": application_id},
+            # This trigger waits for the application to *stop*, so the failure text must say so.
+            failure_message="Application failed to stop",
+            status_message="Application status is",
+            status_queries=["application.state", "application.stateDetails"],
+            return_key="application_id",
+            return_value=application_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return EmrServerlessHook(self.aws_conn_id)
+
+
class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
"""
Poll an Emr Serverless job run and wait for it to be completed.
@@ -355,3 +425,75 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
def hook(self) -> AwsGenericHook:
return EmrServerlessHook(self.aws_conn_id)
+
+
+class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
+    """
+    Poll an Emr Serverless application and wait for it to be deleted.
+
+    Waits on the ``serverless_app_terminated`` waiter. The completion event carries the
+    application ID under the ``application_id`` key.
+
+    :param application_id: The ID of the application being polled.
+    :param waiter_delay: polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made
+    :param aws_conn_id: Reference to AWS connection id
+    """
+
+    def __init__(
+        self,
+        application_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"application_id": application_id},
+            waiter_name="serverless_app_terminated",
+            waiter_args={"applicationId": application_id},
+            # This trigger waits for deletion, so the failure text must describe a deletion failure.
+            failure_message="Application failed to be deleted",
+            status_message="Application status is",
+            status_queries=["application.state", "application.stateDetails"],
+            return_key="application_id",
+            return_value=application_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return EmrServerlessHook(self.aws_conn_id)
+
+
+class EmrServerlessCancelJobsTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger for canceling a list of jobs in an EMR Serverless application.
+
+    Waits on the ``no_job_running`` waiter until no job of the application is left in an
+    intermediate state or in ``CANCELLING``.
+
+    :param application_id: EMR Serverless application ID
+    :param aws_conn_id: Reference to AWS connection id
+    :param waiter_delay: Delay in seconds between each attempt to check the status
+    :param waiter_max_attempts: Maximum number of attempts to check the status
+    """
+
+    def __init__(
+        self,
+        application_id: str,
+        aws_conn_id: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+    ) -> None:
+        # The state set is a class-level constant of the hook, so no hook needs to be instantiated
+        # here. Sorting keeps the waiter arguments deterministic across processes.
+        states = sorted(EmrServerlessHook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"}))
+        super().__init__(
+            serialized_fields={"application_id": application_id},
+            waiter_name="no_job_running",
+            waiter_args={"applicationId": application_id, "states": states},
+            failure_message="Error while waiting for jobs to cancel",
+            status_message="Currently running jobs",
+            status_queries=["jobRuns[*].applicationId", "jobRuns[*].state"],
+            return_key="application_id",
+            return_value=application_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        # Built on demand, like the other EMR Serverless triggers.
+        return EmrServerlessHook(self.aws_conn_id)
diff --git
a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
index 76638815b8..bcd5995e5c 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst
@@ -40,6 +40,8 @@ Create an EMR Serverless Application
You can use
:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessCreateApplicationOperator`
to
create a new EMR Serverless Application.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter. This requires
+the aiobotocore module to be installed.
.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_emr_serverless.py
:language: python
@@ -70,6 +72,8 @@ Stop an EMR Serverless Application
You can use
:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStopApplicationOperator`
to
stop an EMR Serverless Application.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter. This requires
+the aiobotocore module to be installed.
.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_emr_serverless.py
:language: python
@@ -84,6 +88,8 @@ Delete an EMR Serverless Application
You can use
:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessDeleteApplicationOperator`
to
delete an EMR Serverless Application.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter. This requires
+the aiobotocore module to be installed.
.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_emr_serverless.py
:language: python
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py
b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 2d9f37830f..5e16313d9b 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import MagicMock, PropertyMock
+from unittest.mock import MagicMock
from uuid import UUID
import pytest
@@ -333,6 +333,24 @@ class TestEmrServerlessCreateApplicationOperator:
assert operator.waiter_delay == expected[0]
assert operator.waiter_max_attempts == expected[1]
+    @mock.patch.object(EmrServerlessHook, "conn")
+    def test_create_application_deferrable(self, mock_conn):
+        # deferrable=True: after create_application the operator must defer while the app reaches
+        # CREATED, and resume in start_application_deferred.
+        mock_conn.create_application.return_value = {
+            "applicationId": application_id,
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        operator = EmrServerlessCreateApplicationOperator(
+            task_id=task_id,
+            release_label=release_label,
+            job_type=job_type,
+            client_request_token=client_request_token,
+            config=config,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(None)
+
+        mock_conn.create_application.assert_called_once()
+        assert exc.value.method_name == "start_application_deferred"
+
class TestEmrServerlessStartJobOperator:
@mock.patch.object(EmrServerlessHook, "get_waiter")
@@ -851,6 +869,18 @@ class TestEmrServerlessDeleteOperator:
assert operator.waiter_delay == expected[0]
assert operator.waiter_max_attempts == expected[1]
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_delete_application_deferrable(self, mock_conn):
+ # Checks only that execute() defers when deferrable=True.
+ # NOTE(review): execute() first runs the inherited stop phase, which defers on the
+ # stop-application trigger before delete_application is reached. This deferral therefore
+ # likely comes from the stop step, not the delete step. Confirm, and consider asserting
+ # exc.value.method_name / the trigger type.
+ mock_conn.delete_application.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ operator = EmrServerlessDeleteApplicationOperator(
+ task_id=task_id,
+ application_id=application_id,
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred):
+ operator.execute(None)
+
class TestEmrServerlessStopOperator:
@mock.patch.object(EmrServerlessHook, "get_waiter")
@@ -876,14 +906,45 @@ class TestEmrServerlessStopOperator:
mock_get_waiter().wait.assert_not_called()
mock_conn.stop_application.assert_called_once()
- @mock.patch.object(EmrServerlessStopApplicationOperator, "hook",
new_callable=PropertyMock)
- def test_force_stop(self, mock_hook: MagicMock):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ @mock.patch.object(EmrServerlessHook, "cancel_running_jobs")
+ def test_force_stop(self, mock_cancel_running_jobs, mock_conn,
mock_get_waiter):
+ # Non-deferrable force_stop where cancel_running_jobs reports 0 cancelled jobs: the
+ # no_job_running wait is skipped. The test then expects exactly one waiter .wait() call,
+ # presumably the serverless_app_stopped wait done through the `wait` helper.
+ mock_cancel_running_jobs.return_value = 0
+ mock_conn.stop_application.return_value = {}
+ mock_get_waiter().wait.return_value = True
+
operator = EmrServerlessStopApplicationOperator(
task_id=task_id, application_id="test", force_stop=True
)
operator.execute(None)
- mock_hook().cancel_running_jobs.assert_called_once()
- mock_hook().conn.stop_application.assert_called_once()
- mock_hook().get_waiter().wait.assert_called_once()
+ mock_cancel_running_jobs.assert_called_once()
+ mock_conn.stop_application.assert_called_once()
+ mock_get_waiter().wait.assert_called_once()
+
+    @mock.patch.object(EmrServerlessHook, "cancel_running_jobs")
+    def test_stop_application_deferrable_with_force_stop(self, mock_cancel_running_jobs, caplog):
+        # Two jobs were cancelled, so the operator must first defer on the cancel-jobs trigger
+        # and resume in stop_application, before the application itself is stopped.
+        mock_cancel_running_jobs.return_value = 2
+        operator = EmrServerlessStopApplicationOperator(
+            task_id=task_id, application_id="test", deferrable=True, force_stop=True
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(None)
+        assert "now waiting for the 2 cancelled job(s) to terminate" in caplog.messages
+        mock_cancel_running_jobs.assert_called_once_with(application_id="test", wait_for_completion=False)
+        assert exc.value.method_name == "stop_application"
+
+    @mock.patch.object(EmrServerlessHook, "conn")
+    @mock.patch.object(EmrServerlessHook, "cancel_running_jobs")
+    def test_stop_application_deferrable_without_force_stop(
+        self, mock_cancel_running_jobs, mock_conn, caplog
+    ):
+        # NOTE: despite its name, this test runs with force_stop=True. It covers the case where
+        # force_stop finds no running jobs: the application is stopped right away and the
+        # operator defers on the stop-application trigger, resuming in execute_complete.
+        mock_conn.stop_application.return_value = {}
+        mock_cancel_running_jobs.return_value = 0
+        operator = EmrServerlessStopApplicationOperator(
+            task_id=task_id, application_id="test", deferrable=True, force_stop=True
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(None)
+
+        assert "no running jobs found with application ID test" in caplog.messages
+        mock_conn.stop_application.assert_called_once_with(applicationId="test")
+        assert exc.value.method_name == "execute_complete"
diff --git a/tests/providers/amazon/aws/triggers/test_emr_serverless.py
b/tests/providers/amazon/aws/triggers/test_emr_serverless.py
index 029fc4ccbf..39f1a9a0ab 100644
--- a/tests/providers/amazon/aws/triggers/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/triggers/test_emr_serverless.py
@@ -14,13 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
import pytest
from airflow.providers.amazon.aws.triggers.emr import (
+ EmrServerlessCancelJobsTrigger,
+ EmrServerlessCreateApplicationTrigger,
+ EmrServerlessDeleteApplicationTrigger,
EmrServerlessStartApplicationTrigger,
EmrServerlessStartJobTrigger,
+ EmrServerlessStopApplicationTrigger,
)
TEST_APPLICATION_ID = "test-application-id"
@@ -35,12 +40,36 @@ class TestEmrTriggers:
@pytest.mark.parametrize(
"trigger",
[
+ EmrServerlessCreateApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ ),
EmrServerlessStartApplicationTrigger(
application_id=TEST_APPLICATION_ID,
aws_conn_id=TEST_AWS_CONN_ID,
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
),
+ EmrServerlessStopApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ ),
+ EmrServerlessDeleteApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ ),
+ EmrServerlessCancelJobsTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ ),
EmrServerlessStartJobTrigger(
application_id=TEST_APPLICATION_ID,
job_id=TEST_JOB_ID,