This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f91c93ccfe Add retry configuration in `EmrContainerOperator` (#37426)
f91c93ccfe is described below
commit f91c93ccfeedc27e54493e95d3088e8478cdf08c
Author: Vincent <[email protected]>
AuthorDate: Wed Feb 14 14:23:18 2024 -0800
Add retry configuration in `EmrContainerOperator` (#37426)
---
airflow/providers/amazon/aws/hooks/emr.py | 6 ++++++
airflow/providers/amazon/aws/operators/emr.py | 5 +++++
tests/providers/amazon/aws/operators/test_emr_containers.py | 2 +-
tests/system/providers/amazon/aws/example_emr_eks.py | 1 +
4 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index c6dc88e4e8..e2fb960355 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -383,6 +383,7 @@ class EmrContainerHook(AwsBaseHook):
configuration_overrides: dict | None = None,
client_request_token: str | None = None,
tags: dict | None = None,
+ retry_max_attempts: int | None = None,
) -> str:
"""
Submit a job to the EMR Containers API and return the job ID.
@@ -402,6 +403,7 @@ class EmrContainerHook(AwsBaseHook):
:param client_request_token: The client idempotency token of the job
run request.
Use this if you want to specify a unique ID to prevent two jobs
from getting started.
:param tags: The tags assigned to job runs.
+ :param retry_max_attempts: The maximum number of attempts on the job's driver.
:return: The ID of the job run request.
"""
params = {
@@ -415,6 +417,10 @@ class EmrContainerHook(AwsBaseHook):
}
if client_request_token:
params["clientToken"] = client_request_token
+ if retry_max_attempts:
+ params["retryPolicyConfiguration"] = {
+ "maxAttempts": retry_max_attempts,
+ }
response = self.conn.start_job_run(**params)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 628490b342..68e1c90296 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -503,6 +503,8 @@ class EmrContainerOperator(BaseOperator):
:param max_tries: Deprecated - use max_polling_attempts instead.
:param max_polling_attempts: Maximum number of times to wait for the job
run to finish.
Defaults to None, which will poll until the job is *not* in a pending,
submitted, or running state.
+ :param job_retry_max_attempts: The maximum number of attempts on the job's driver, used to retry the EMR job when it fails.
+ Defaults to None, which disables the retry.
:param tags: The tags assigned to job runs.
Defaults to None
:param deferrable: Run operator in the deferrable mode.
@@ -534,6 +536,7 @@ class EmrContainerOperator(BaseOperator):
max_tries: int | None = None,
tags: dict | None = None,
max_polling_attempts: int | None = None,
+ job_retry_max_attempts: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs: Any,
) -> None:
@@ -549,6 +552,7 @@ class EmrContainerOperator(BaseOperator):
self.wait_for_completion = wait_for_completion
self.poll_interval = poll_interval
self.max_polling_attempts = max_polling_attempts
+ self.job_retry_max_attempts = job_retry_max_attempts
self.tags = tags
self.job_id: str | None = None
self.deferrable = deferrable
@@ -583,6 +587,7 @@ class EmrContainerOperator(BaseOperator):
self.configuration_overrides,
self.client_request_token,
self.tags,
+ self.job_retry_max_attempts,
)
if self.deferrable:
query_status = self.hook.check_query_status(job_id=self.job_id)
diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py
index 368f85d395..8e94e744d9 100644
--- a/tests/providers/amazon/aws/operators/test_emr_containers.py
+++ b/tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -71,7 +71,7 @@ class TestEmrContainerOperator:
self.emr_container.execute(None)
mock_submit_job.assert_called_once_with(
- "test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {},
GENERATED_UUID, {}
+ "test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {},
GENERATED_UUID, {}, None
)
mock_check_query_status.assert_called_once_with("jobid_123456")
assert self.emr_container.release_label == "6.3.0-latest"
diff --git a/tests/system/providers/amazon/aws/example_emr_eks.py b/tests/system/providers/amazon/aws/example_emr_eks.py
index 28dc7ac3c2..428c16f450 100644
--- a/tests/system/providers/amazon/aws/example_emr_eks.py
+++ b/tests/system/providers/amazon/aws/example_emr_eks.py
@@ -282,6 +282,7 @@ with DAG(
)
# [END howto_operator_emr_container]
job_starter.wait_for_completion = False
+ job_starter.job_retry_max_attempts = 5
# [START howto_sensor_emr_container]
job_waiter = EmrContainerSensor(