This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c519920661 Make EMR Container Trigger max attempts retries match the
Operator (#41008)
c519920661 is described below
commit c519920661133a06e917a781e73caeac111b26f5
Author: Niko Oliveira <[email protected]>
AuthorDate: Fri Jul 26 13:52:37 2024 -0700
Make EMR Container Trigger max attempts retries match the Operator (#41008)
The EMR Container Operator waits indefinitely by default (on the
wait-for-completion path). However, when it is deferred, the Trigger has a
default limit of 600 attempts (about 5 hours at the default 30-second
waiter delay), which does not match users' expectations when using the
operator.
Update the Trigger to default to an effectively unlimited number of attempts
(sys.maxsize) to match the Operator's behaviour.
---
airflow/providers/amazon/aws/triggers/emr.py | 4 +++-
tests/providers/amazon/aws/triggers/test_emr.py | 24 ++++++++++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/triggers/emr.py
b/airflow/providers/amazon/aws/triggers/emr.py
index 8b64d84f63..9abfe120d2 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import sys
import warnings
from typing import TYPE_CHECKING
@@ -174,6 +175,7 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
:param job_id: job_id to check the state
:param aws_conn_id: Reference to AWS connection id
:param waiter_delay: polling period in seconds to check for the status
+ :param waiter_max_attempts: The maximum number of attempts to be made.
Defaults to an infinite wait.
"""
def __init__(
@@ -183,7 +185,7 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
aws_conn_id: str | None = "aws_default",
poll_interval: int | None = None, # deprecated
waiter_delay: int = 30,
- waiter_max_attempts: int = 600,
+ waiter_max_attempts: int = sys.maxsize,
):
if poll_interval is not None:
warnings.warn(
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py
b/tests/providers/amazon/aws/triggers/test_emr.py
index 92fd08857d..3469ee4c13 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+import sys
+
from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
EmrContainerTrigger,
@@ -209,6 +211,28 @@ class TestEmrContainerTrigger:
"aws_conn_id": "aws_default",
}
+ def test_serialization_default_max_attempts(self):
+ virtual_cluster_id = "test_virtual_cluster_id"
+ job_id = "test_job_id"
+ waiter_delay = 30
+ aws_conn_id = "aws_default"
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id=virtual_cluster_id,
+ job_id=job_id,
+ waiter_delay=waiter_delay,
+ aws_conn_id=aws_conn_id,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs == {
+ "virtual_cluster_id": "test_virtual_cluster_id",
+ "job_id": "test_job_id",
+ "waiter_delay": 30,
+ "waiter_max_attempts": sys.maxsize,
+ "aws_conn_id": "aws_default",
+ }
+
class TestEmrStepSensorTrigger:
def test_serialization(self):