This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f91c93ccfe Add retry configuration in `EmrContainerOperator` (#37426)
f91c93ccfe is described below

commit f91c93ccfeedc27e54493e95d3088e8478cdf08c
Author: Vincent <[email protected]>
AuthorDate: Wed Feb 14 14:23:18 2024 -0800

    Add retry configuration in `EmrContainerOperator` (#37426)
---
 airflow/providers/amazon/aws/hooks/emr.py                   | 6 ++++++
 airflow/providers/amazon/aws/operators/emr.py               | 5 +++++
 tests/providers/amazon/aws/operators/test_emr_containers.py | 2 +-
 tests/system/providers/amazon/aws/example_emr_eks.py        | 1 +
 4 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index c6dc88e4e8..e2fb960355 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -383,6 +383,7 @@ class EmrContainerHook(AwsBaseHook):
         configuration_overrides: dict | None = None,
         client_request_token: str | None = None,
         tags: dict | None = None,
+        retry_max_attempts: int | None = None,
     ) -> str:
         """
         Submit a job to the EMR Containers API and return the job ID.
@@ -402,6 +403,7 @@ class EmrContainerHook(AwsBaseHook):
         :param client_request_token: The client idempotency token of the job run request.
             Use this if you want to specify a unique ID to prevent two jobs from getting started.
         :param tags: The tags assigned to job runs.
+        :param retry_max_attempts: The maximum number of attempts on the job's driver.
         :return: The ID of the job run request.
         """
         params = {
@@ -415,6 +417,10 @@ class EmrContainerHook(AwsBaseHook):
         }
         if client_request_token:
             params["clientToken"] = client_request_token
+        if retry_max_attempts:
+            params["retryPolicyConfiguration"] = {
+                "maxAttempts": retry_max_attempts,
+            }
 
         response = self.conn.start_job_run(**params)
 
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 628490b342..68e1c90296 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -503,6 +503,8 @@ class EmrContainerOperator(BaseOperator):
     :param max_tries: Deprecated - use max_polling_attempts instead.
     :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
         Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
+    :param job_retry_max_attempts: Maximum number of times to retry when the EMR job fails.
+        Defaults to None, which disable the retry.
     :param tags: The tags assigned to job runs.
         Defaults to None
     :param deferrable: Run operator in the deferrable mode.
@@ -534,6 +536,7 @@ class EmrContainerOperator(BaseOperator):
         max_tries: int | None = None,
         tags: dict | None = None,
         max_polling_attempts: int | None = None,
+        job_retry_max_attempts: int | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ) -> None:
@@ -549,6 +552,7 @@ class EmrContainerOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.poll_interval = poll_interval
         self.max_polling_attempts = max_polling_attempts
+        self.job_retry_max_attempts = job_retry_max_attempts
         self.tags = tags
         self.job_id: str | None = None
         self.deferrable = deferrable
@@ -583,6 +587,7 @@ class EmrContainerOperator(BaseOperator):
             self.configuration_overrides,
             self.client_request_token,
             self.tags,
+            self.job_retry_max_attempts,
         )
         if self.deferrable:
             query_status = self.hook.check_query_status(job_id=self.job_id)
diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py
index 368f85d395..8e94e744d9 100644
--- a/tests/providers/amazon/aws/operators/test_emr_containers.py
+++ b/tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -71,7 +71,7 @@ class TestEmrContainerOperator:
         self.emr_container.execute(None)
 
         mock_submit_job.assert_called_once_with(
-            "test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}
+            "test_emr_job", "arn:aws:somerole", "6.3.0-latest", {}, {}, GENERATED_UUID, {}, None
         )
         mock_check_query_status.assert_called_once_with("jobid_123456")
         assert self.emr_container.release_label == "6.3.0-latest"
diff --git a/tests/system/providers/amazon/aws/example_emr_eks.py b/tests/system/providers/amazon/aws/example_emr_eks.py
index 28dc7ac3c2..428c16f450 100644
--- a/tests/system/providers/amazon/aws/example_emr_eks.py
+++ b/tests/system/providers/amazon/aws/example_emr_eks.py
@@ -282,6 +282,7 @@ with DAG(
     )
     # [END howto_operator_emr_container]
     job_starter.wait_for_completion = False
+    job_starter.job_retry_max_attempts = 5
 
     # [START howto_sensor_emr_container]
     job_waiter = EmrContainerSensor(

Reply via email to