This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5c887988b0 Refactor Eks Create Cluster Operator code (#31960)
5c887988b0 is described below
commit 5c887988b02b02e60f693c9341013592a291ee27
Author: Syed Hussaain <[email protected]>
AuthorDate: Fri Jun 23 14:18:13 2023 -0700
Refactor Eks Create Cluster Operator code (#31960)
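* Refactor EksCreateClusterOperator, EksCreateNodegroupOperator and
  EksCreateFargateProfileOperator to reuse shared compute-creation code
  (the new `_create_compute` helper) instead of duplicating it in each
  operator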
* Update the `_create_compute` helper function so the existing tests pass
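* Add `waiter_delay` and `waiter_max_attempts` parameters to
  EksCreateClusterOperator and EksCreateNodegroupOperator
* Rename the `poll_interval`/`max_attempts` parameters of
  EksCreateFargateProfileTrigger and EksDeleteFargateProfileTrigger to
  `waiter_delay`/`waiter_max_attempts` for consistency with the operators,
  and add an optional `region` parameter to both triggers
* Update unit tests for triggers and operators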
---
airflow/providers/amazon/aws/operators/eks.py | 249 +++++++++++++++--------
airflow/providers/amazon/aws/triggers/eks.py | 62 +++---
tests/providers/amazon/aws/operators/test_eks.py | 32 ++-
tests/providers/amazon/aws/triggers/test_eks.py | 52 ++---
4 files changed, 247 insertions(+), 148 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/eks.py
b/airflow/providers/amazon/aws/operators/eks.py
index 8131be4f65..e280da5e5a 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -17,10 +17,11 @@
"""This module contains Amazon EKS operators."""
from __future__ import annotations
+import logging
import warnings
from ast import literal_eval
from datetime import timedelta
-from typing import TYPE_CHECKING, Any, List, Sequence, cast
+from typing import TYPE_CHECKING, List, Sequence, cast
from botocore.exceptions import ClientError, WaiterError
@@ -31,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
EksCreateFargateProfileTrigger,
EksDeleteFargateProfileTrigger,
)
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
try:
from airflow.providers.cncf.kubernetes.operators.pod import
KubernetesPodOperator
@@ -59,6 +61,75 @@ NODEGROUP_FULL_NAME = "Amazon EKS managed node groups"
FARGATE_FULL_NAME = "AWS Fargate profiles"
+def _create_compute(
+ compute: str | None,
+ cluster_name: str,
+ aws_conn_id: str,
+ region: str | None,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ wait_for_completion: bool = False,
+ nodegroup_name: str | None = None,
+ nodegroup_role_arn: str | None = None,
+ create_nodegroup_kwargs: dict | None = None,
+ fargate_profile_name: str | None = None,
+ fargate_pod_execution_role_arn: str | None = None,
+ fargate_selectors: list | None = None,
+ create_fargate_profile_kwargs: dict | None = None,
+ subnets: list[str] | None = None,
+):
+ log = logging.getLogger(__name__)
+ eks_hook = EksHook(aws_conn_id=aws_conn_id, region_name=region)
+ if compute == "nodegroup" and nodegroup_name:
+
+ # this is to satisfy mypy
+ subnets = subnets or []
+ create_nodegroup_kwargs = create_nodegroup_kwargs or {}
+
+ eks_hook.create_nodegroup(
+ clusterName=cluster_name,
+ nodegroupName=nodegroup_name,
+ subnets=subnets,
+ nodeRole=nodegroup_role_arn,
+ **create_nodegroup_kwargs,
+ )
+ if wait_for_completion:
+ log.info("Waiting for nodegroup to provision. This will take some
time.")
+ wait(
+ waiter=eks_hook.conn.get_waiter("nodegroup_active"),
+ waiter_delay=waiter_delay,
+ max_attempts=waiter_max_attempts,
+ args={"clusterName": cluster_name, "nodegroupName":
nodegroup_name},
+ failure_message="Nodegroup creation failed",
+ status_message="Nodegroup status is",
+ status_args=["nodegroup.status"],
+ )
+ elif compute == "fargate" and fargate_profile_name:
+
+ # this is to satisfy mypy
+ create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
+ fargate_selectors = fargate_selectors or []
+
+ eks_hook.create_fargate_profile(
+ clusterName=cluster_name,
+ fargateProfileName=fargate_profile_name,
+ podExecutionRoleArn=fargate_pod_execution_role_arn,
+ selectors=fargate_selectors,
+ **create_fargate_profile_kwargs,
+ )
+ if wait_for_completion:
+ log.info("Waiting for Fargate profile to provision. This will
take some time.")
+ wait(
+ waiter=eks_hook.conn.get_waiter("fargate_profile_active"),
+ waiter_delay=waiter_delay,
+ max_attempts=waiter_max_attempts,
+ args={"clusterName": cluster_name, "fargateProfileName":
fargate_profile_name},
+ failure_message="Fargate profile creation failed",
+ status_message="Fargate profile status is",
+ status_args=["fargateProfile.status"],
+ )
+
+
class EksCreateClusterOperator(BaseOperator):
"""
Creates an Amazon EKS Cluster control plane.
@@ -112,6 +183,8 @@ class EksCreateClusterOperator(BaseOperator):
:param fargate_selectors: The selectors to match for pods to use this AWS
Fargate profile. (templated)
:param create_fargate_profile_kwargs: Optional parameters to pass to the
CreateFargateProfile API
(templated)
+ :param waiter_delay: Time (in seconds) to wait between two consecutive
calls to check cluster status
+ :param waiter_max_attempts: The maximum number of attempts to check the
status of the cluster.
"""
@@ -137,7 +210,7 @@ class EksCreateClusterOperator(BaseOperator):
self,
cluster_name: str,
cluster_role_arn: str,
- resources_vpc_config: dict[str, Any],
+ resources_vpc_config: dict,
compute: str | None = DEFAULT_COMPUTE_TYPE,
create_cluster_kwargs: dict | None = None,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
@@ -150,6 +223,8 @@ class EksCreateClusterOperator(BaseOperator):
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 40,
**kwargs,
) -> None:
self.compute = compute
@@ -157,17 +232,21 @@ class EksCreateClusterOperator(BaseOperator):
self.cluster_role_arn = cluster_role_arn
self.resources_vpc_config = resources_vpc_config
self.create_cluster_kwargs = create_cluster_kwargs or {}
- self.nodegroup_name = nodegroup_name
self.nodegroup_role_arn = nodegroup_role_arn
- self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
- self.fargate_profile_name = fargate_profile_name
self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
- self.fargate_selectors = fargate_selectors or [{"namespace":
DEFAULT_NAMESPACE_NAME}]
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or
{}
self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
self.region = region
- super().__init__(**kwargs)
+ self.nodegroup_name = nodegroup_name
+ self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
+ self.fargate_selectors = fargate_selectors or [{"namespace":
DEFAULT_NAMESPACE_NAME}]
+ self.fargate_profile_name = fargate_profile_name
+ super().__init__(
+ **kwargs,
+ )
def execute(self, context: Context):
if self.compute:
@@ -183,13 +262,8 @@ class EksCreateClusterOperator(BaseOperator):
compute=FARGATE_FULL_NAME,
requirement="fargate_pod_execution_role_arn"
)
)
-
- eks_hook = EksHook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region,
- )
-
- eks_hook.create_cluster(
+ self.eks_hook = EksHook(aws_conn_id=self.aws_conn_id,
region_name=self.region)
+ self.eks_hook.create_cluster(
name=self.cluster_name,
roleArn=self.cluster_role_arn,
resourcesVpcConfig=self.resources_vpc_config,
@@ -202,44 +276,38 @@ class EksCreateClusterOperator(BaseOperator):
return None
self.log.info("Waiting for EKS Cluster to provision. This will take
some time.")
- client = eks_hook.conn
+ client = self.eks_hook.conn
try:
- client.get_waiter("cluster_active").wait(name=self.cluster_name)
+ client.get_waiter("cluster_active").wait(
+ name=self.cluster_name,
+ WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts":
self.waiter_max_attempts},
+ )
except (ClientError, WaiterError) as e:
self.log.error("Cluster failed to start and will be torn down.\n
%s", e)
- eks_hook.delete_cluster(name=self.cluster_name)
- client.get_waiter("cluster_deleted").wait(name=self.cluster_name)
- raise
-
- if self.compute == "nodegroup":
- eks_hook.create_nodegroup(
- clusterName=self.cluster_name,
- nodegroupName=self.nodegroup_name,
- subnets=cast(List[str],
self.resources_vpc_config.get("subnetIds")),
- nodeRole=self.nodegroup_role_arn,
- **self.create_nodegroup_kwargs,
- )
- if self.wait_for_completion:
- self.log.info("Waiting for nodegroup to provision. This will
take some time.")
- client.get_waiter("nodegroup_active").wait(
- clusterName=self.cluster_name,
- nodegroupName=self.nodegroup_name,
- )
- elif self.compute == "fargate":
- eks_hook.create_fargate_profile(
- clusterName=self.cluster_name,
- fargateProfileName=self.fargate_profile_name,
- podExecutionRoleArn=self.fargate_pod_execution_role_arn,
- selectors=self.fargate_selectors,
- **self.create_fargate_profile_kwargs,
+ self.eks_hook.delete_cluster(name=self.cluster_name)
+ client.get_waiter("cluster_deleted").wait(
+ name=self.cluster_name,
+ WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts":
self.waiter_max_attempts},
)
- if self.wait_for_completion:
- self.log.info("Waiting for Fargate profile to provision. This
will take some time.")
- client.get_waiter("fargate_profile_active").wait(
- clusterName=self.cluster_name,
- fargateProfileName=self.fargate_profile_name,
- )
+ raise
+ _create_compute(
+ compute=self.compute,
+ cluster_name=self.cluster_name,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ wait_for_completion=self.wait_for_completion,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ nodegroup_name=self.nodegroup_name,
+ nodegroup_role_arn=self.nodegroup_role_arn,
+ create_nodegroup_kwargs=self.create_nodegroup_kwargs,
+ fargate_profile_name=self.fargate_profile_name,
+ fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
+ fargate_selectors=self.fargate_selectors,
+ create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
+ subnets=cast(List[str],
self.resources_vpc_config.get("subnetIds")),
+ )
class EksCreateNodegroupOperator(BaseOperator):
@@ -265,6 +333,8 @@ class EksCreateNodegroupOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
+ :param waiter_delay: Time (in seconds) to wait between two consecutive
calls to check nodegroup status
+ :param waiter_max_attempts: The maximum number of attempts to check the
status of the nodegroup.
"""
@@ -289,8 +359,12 @@ class EksCreateNodegroupOperator(BaseOperator):
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 80,
**kwargs,
) -> None:
+ self.nodegroup_subnets = nodegroup_subnets
+ self.compute = "nodegroup"
self.cluster_name = cluster_name
self.nodegroup_role_arn = nodegroup_role_arn
self.nodegroup_name = nodegroup_name
@@ -298,10 +372,15 @@ class EksCreateNodegroupOperator(BaseOperator):
self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
- self.nodegroup_subnets = nodegroup_subnets
- super().__init__(**kwargs)
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+
+ super().__init__(
+ **kwargs,
+ )
def execute(self, context: Context):
+ self.log.info(self.task_id)
if isinstance(self.nodegroup_subnets, str):
nodegroup_subnets_list: list[str] = []
if self.nodegroup_subnets != "":
@@ -314,25 +393,20 @@ class EksCreateNodegroupOperator(BaseOperator):
self.nodegroup_subnets,
)
self.nodegroup_subnets = nodegroup_subnets_list
-
- eks_hook = EksHook(
+ _create_compute(
+ compute=self.compute,
+ cluster_name=self.cluster_name,
aws_conn_id=self.aws_conn_id,
- region_name=self.region,
- )
- eks_hook.create_nodegroup(
- clusterName=self.cluster_name,
- nodegroupName=self.nodegroup_name,
+ region=self.region,
+ wait_for_completion=self.wait_for_completion,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ nodegroup_name=self.nodegroup_name,
+ nodegroup_role_arn=self.nodegroup_role_arn,
+ create_nodegroup_kwargs=self.create_nodegroup_kwargs,
subnets=self.nodegroup_subnets,
- nodeRole=self.nodegroup_role_arn,
- **self.create_nodegroup_kwargs,
)
- if self.wait_for_completion:
- self.log.info("Waiting for nodegroup to provision. This will take
some time.")
- eks_hook.conn.get_waiter("nodegroup_active").wait(
- clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name
- )
-
class EksCreateFargateProfileOperator(BaseOperator):
"""
@@ -392,30 +466,34 @@ class EksCreateFargateProfileOperator(BaseOperator):
**kwargs,
) -> None:
self.cluster_name = cluster_name
- self.pod_execution_role_arn = pod_execution_role_arn
self.selectors = selectors
+ self.pod_execution_role_arn = pod_execution_role_arn
self.fargate_profile_name = fargate_profile_name
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or
{}
- self.wait_for_completion = wait_for_completion
+ self.wait_for_completion = False if deferrable else wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
- super().__init__(**kwargs)
+ self.compute = "fargate"
+ super().__init__(
+ **kwargs,
+ )
def execute(self, context: Context):
- eks_hook = EksHook(
+ _create_compute(
+ compute=self.compute,
+ cluster_name=self.cluster_name,
aws_conn_id=self.aws_conn_id,
- region_name=self.region,
- )
-
- eks_hook.create_fargate_profile(
- clusterName=self.cluster_name,
- fargateProfileName=self.fargate_profile_name,
- podExecutionRoleArn=self.pod_execution_role_arn,
- selectors=self.selectors,
- **self.create_fargate_profile_kwargs,
+ region=self.region,
+ wait_for_completion=self.wait_for_completion,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ fargate_profile_name=self.fargate_profile_name,
+ fargate_pod_execution_role_arn=self.pod_execution_role_arn,
+ fargate_selectors=self.selectors,
+ create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
)
if self.deferrable:
self.defer(
@@ -423,21 +501,15 @@ class EksCreateFargateProfileOperator(BaseOperator):
cluster_name=self.cluster_name,
fargate_profile_name=self.fargate_profile_name,
aws_conn_id=self.aws_conn_id,
- poll_interval=self.waiter_delay,
- max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ region=self.region,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout
does not restart
# 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
timeout=timedelta(seconds=(self.waiter_max_attempts *
self.waiter_delay + 60)),
)
- elif self.wait_for_completion:
- self.log.info("Waiting for Fargate profile to provision. This
will take some time.")
- eks_hook.conn.get_waiter("fargate_profile_active").wait(
- clusterName=self.cluster_name,
- fargateProfileName=self.fargate_profile_name,
- WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts":
self.waiter_max_attempts},
- )
def execute_complete(self, context, event=None):
if event["status"] != "success":
@@ -677,8 +749,9 @@ class EksDeleteFargateProfileOperator(BaseOperator):
cluster_name=self.cluster_name,
fargate_profile_name=self.fargate_profile_name,
aws_conn_id=self.aws_conn_id,
- poll_interval=self.waiter_delay,
- max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ region=self.region,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout
does not restart
diff --git a/airflow/providers/amazon/aws/triggers/eks.py
b/airflow/providers/amazon/aws/triggers/eks.py
index dddab74b30..8ccd88167c 100644
--- a/airflow/providers/amazon/aws/triggers/eks.py
+++ b/airflow/providers/amazon/aws/triggers/eks.py
@@ -33,8 +33,8 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
:param cluster_name: The name of the EKS cluster
:param fargate_profile_name: The name of the fargate profile
- :param poll_interval: The amount of time in seconds to wait between
attempts.
- :param max_attempts: The maximum number of attempts to be made.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
@@ -42,15 +42,17 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
self,
cluster_name: str,
fargate_profile_name: str,
- poll_interval: int,
- max_attempts: int,
+ waiter_delay: int,
+ waiter_max_attempts: int,
aws_conn_id: str,
+ region: str | None = None,
):
self.cluster_name = cluster_name
self.fargate_profile_name = fargate_profile_name
- self.poll_interval = poll_interval
- self.max_attempts = max_attempts
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
+ self.region = region
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -58,24 +60,25 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
{
"cluster_name": self.cluster_name,
"fargate_profile_name": self.fargate_profile_name,
- "poll_interval": str(self.poll_interval),
- "max_attempts": str(self.max_attempts),
+ "waiter_delay": str(self.waiter_delay),
+ "waiter_max_attempts": str(self.waiter_max_attempts),
"aws_conn_id": self.aws_conn_id,
+ "region": self.region,
},
)
async def run(self):
- self.hook = EksHook(aws_conn_id=self.aws_conn_id)
+ self.hook = EksHook(aws_conn_id=self.aws_conn_id,
region_name=self.region)
async with self.hook.async_conn as client:
attempt = 0
waiter = client.get_waiter("fargate_profile_active")
- while attempt < int(self.max_attempts):
+ while attempt < int(self.waiter_max_attempts):
attempt += 1
try:
await waiter.wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
- WaiterConfig={"Delay": int(self.poll_interval),
"MaxAttempts": 1},
+ WaiterConfig={"Delay": int(self.waiter_delay),
"MaxAttempts": 1},
)
break
except WaiterError as error:
@@ -84,10 +87,10 @@ class EksCreateFargateProfileTrigger(BaseTrigger):
self.log.info(
"Status of fargate profile is %s",
error.last_response["fargateProfile"]["status"]
)
- await asyncio.sleep(int(self.poll_interval))
- if attempt >= int(self.max_attempts):
+ await asyncio.sleep(int(self.waiter_delay))
+ if attempt >= int(self.waiter_max_attempts):
raise AirflowException(
- f"Create Fargate Profile failed - max attempts reached:
{self.max_attempts}"
+ f"Create Fargate Profile failed - max attempts reached:
{self.waiter_max_attempts}"
)
else:
yield TriggerEvent({"status": "success", "message": "Fargate
Profile Created"})
@@ -100,8 +103,8 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
:param cluster_name: The name of the EKS cluster
:param fargate_profile_name: The name of the fargate profile
- :param poll_interval: The amount of time in seconds to wait between
attempts.
- :param max_attempts: The maximum number of attempts to be made.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
@@ -109,15 +112,17 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
self,
cluster_name: str,
fargate_profile_name: str,
- poll_interval: int,
- max_attempts: int,
+ waiter_delay: int,
+ waiter_max_attempts: int,
aws_conn_id: str,
+ region: str | None = None,
):
self.cluster_name = cluster_name
self.fargate_profile_name = fargate_profile_name
- self.poll_interval = poll_interval
- self.max_attempts = max_attempts
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
+ self.region = region
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -125,24 +130,25 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
{
"cluster_name": self.cluster_name,
"fargate_profile_name": self.fargate_profile_name,
- "poll_interval": str(self.poll_interval),
- "max_attempts": str(self.max_attempts),
+ "waiter_delay": str(self.waiter_delay),
+ "waiter_max_attempts": str(self.waiter_max_attempts),
"aws_conn_id": self.aws_conn_id,
+ "region": self.region,
},
)
async def run(self):
- self.hook = EksHook(aws_conn_id=self.aws_conn_id)
+ self.hook = EksHook(aws_conn_id=self.aws_conn_id,
region_name=self.region)
async with self.hook.async_conn as client:
attempt = 0
waiter = client.get_waiter("fargate_profile_deleted")
- while attempt < int(self.max_attempts):
+ while attempt < int(self.waiter_max_attempts):
attempt += 1
try:
await waiter.wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
- WaiterConfig={"Delay": int(self.poll_interval),
"MaxAttempts": 1},
+ WaiterConfig={"Delay": int(self.waiter_delay),
"MaxAttempts": 1},
)
break
except WaiterError as error:
@@ -151,10 +157,10 @@ class EksDeleteFargateProfileTrigger(BaseTrigger):
self.log.info(
"Status of fargate profile is %s",
error.last_response["fargateProfile"]["status"]
)
- await asyncio.sleep(int(self.poll_interval))
- if attempt >= int(self.max_attempts):
+ await asyncio.sleep(int(self.waiter_delay))
+ if attempt >= int(self.waiter_max_attempts):
raise AirflowException(
- f"Delete Fargate Profile failed - max attempts reached:
{self.max_attempts}"
+ f"Delete Fargate Profile failed - max attempts reached:
{self.waiter_max_attempts}"
)
else:
yield TriggerEvent({"status": "success", "message": "Fargate
Profile Deleted"})
diff --git a/tests/providers/amazon/aws/operators/test_eks.py
b/tests/providers/amazon/aws/operators/test_eks.py
index 089aef1704..311aad972d 100644
--- a/tests/providers/amazon/aws/operators/test_eks.py
+++ b/tests/providers/amazon/aws/operators/test_eks.py
@@ -200,7 +200,11 @@ class TestEksCreateClusterOperator:
operator.execute({})
mock_create_cluster.assert_called_with(**convert_keys(parameters))
mock_create_nodegroup.assert_not_called()
- mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+ mock_waiter.assert_called_once_with(
+ mock.ANY,
+ name=CLUSTER_NAME,
+ WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY},
+ )
assert_expected_waiter_type(mock_waiter, "ClusterActive")
@mock.patch.object(Waiter, "wait")
@@ -216,7 +220,11 @@ class TestEksCreateClusterOperator:
mock_create_cluster.assert_called_once_with(**convert_keys(self.create_cluster_params))
mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params))
- mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+ mock_waiter.assert_called_once_with(
+ mock.ANY,
+ name=CLUSTER_NAME,
+ WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY},
+ )
assert_expected_waiter_type(mock_waiter, "ClusterActive")
@mock.patch.object(Waiter, "wait")
@@ -235,7 +243,12 @@ class TestEksCreateClusterOperator:
mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params))
# Calls waiter once for the cluster and once for the nodegroup.
assert mock_waiter.call_count == 2
- mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME,
nodegroupName=NODEGROUP_NAME)
+ mock_waiter.assert_called_with(
+ mock.ANY,
+ clusterName=CLUSTER_NAME,
+ nodegroupName=NODEGROUP_NAME,
+ WaiterConfig={"MaxAttempts": mock.ANY},
+ )
assert_expected_waiter_type(mock_waiter, "NodegroupActive")
@mock.patch.object(Waiter, "wait")
@@ -253,7 +266,11 @@ class TestEksCreateClusterOperator:
mock_create_fargate_profile.assert_called_once_with(
**convert_keys(self.create_fargate_profile_params)
)
- mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+ mock_waiter.assert_called_once_with(
+ mock.ANY,
+ name=CLUSTER_NAME,
+ WaiterConfig={"Delay": mock.ANY, "MaxAttempts": mock.ANY},
+ )
assert_expected_waiter_type(mock_waiter, "ClusterActive")
@mock.patch.object(Waiter, "wait")
@@ -275,7 +292,10 @@ class TestEksCreateClusterOperator:
# Calls waiter once for the cluster and once for the nodegroup.
assert mock_waiter.call_count == 2
mock_waiter.assert_called_with(
- mock.ANY, clusterName=CLUSTER_NAME,
fargateProfileName=FARGATE_PROFILE_NAME
+ mock.ANY,
+ clusterName=CLUSTER_NAME,
+ fargateProfileName=FARGATE_PROFILE_NAME,
+ WaiterConfig={"MaxAttempts": mock.ANY},
)
assert_expected_waiter_type(mock_waiter, "FargateProfileActive")
@@ -377,7 +397,7 @@ class TestEksCreateFargateProfileOperator:
mock.ANY,
clusterName=CLUSTER_NAME,
fargateProfileName=FARGATE_PROFILE_NAME,
- WaiterConfig={"Delay": 10, "MaxAttempts": 60},
+ WaiterConfig={"MaxAttempts": mock.ANY},
)
assert_expected_waiter_type(mock_waiter, "FargateProfileActive")
diff --git a/tests/providers/amazon/aws/triggers/test_eks.py
b/tests/providers/amazon/aws/triggers/test_eks.py
index abab121d24..dbc71e7296 100644
--- a/tests/providers/amazon/aws/triggers/test_eks.py
+++ b/tests/providers/amazon/aws/triggers/test_eks.py
@@ -32,8 +32,8 @@ from airflow.triggers.base import TriggerEvent
TEST_CLUSTER_IDENTIFIER = "test-cluster"
TEST_FARGATE_PROFILE_NAME = "test-fargate-profile"
-TEST_POLL_INTERVAL = 10
-TEST_MAX_ATTEMPTS = 10
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
TEST_AWS_CONN_ID = "test-aws-id"
@@ -43,8 +43,8 @@ class TestEksCreateFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
class_path, args = eks_create_fargate_profile_trigger.serialize()
@@ -52,8 +52,8 @@ class TestEksCreateFargateProfileTrigger:
assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER
assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME
assert args["aws_conn_id"] == TEST_AWS_CONN_ID
- assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
- assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
+ assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+ assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
@pytest.mark.asyncio
@mock.patch.object(EksHook, "async_conn")
@@ -67,8 +67,8 @@ class TestEksCreateFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
generator = eks_create_fargate_profile_trigger.run()
@@ -96,8 +96,8 @@ class TestEksCreateFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
generator = eks_create_fargate_profile_trigger.run()
@@ -126,8 +126,8 @@ class TestEksCreateFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=2,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=2,
)
with pytest.raises(AirflowException) as exc:
generator = eks_create_fargate_profile_trigger.run()
@@ -158,8 +158,8 @@ class TestEksCreateFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
with pytest.raises(AirflowException) as exc:
@@ -175,8 +175,8 @@ class TestEksDeleteFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
class_path, args = eks_delete_fargate_profile_trigger.serialize()
@@ -184,8 +184,8 @@ class TestEksDeleteFargateProfileTrigger:
assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER
assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME
assert args["aws_conn_id"] == TEST_AWS_CONN_ID
- assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
- assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
+ assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+ assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
@pytest.mark.asyncio
@mock.patch.object(EksHook, "async_conn")
@@ -199,8 +199,8 @@ class TestEksDeleteFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
generator = eks_delete_fargate_profile_trigger.run()
@@ -228,8 +228,8 @@ class TestEksDeleteFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
generator = eks_delete_fargate_profile_trigger.run()
@@ -257,8 +257,8 @@ class TestEksDeleteFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=2,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=2,
)
with pytest.raises(AirflowException) as exc:
generator = eks_delete_fargate_profile_trigger.run()
@@ -289,8 +289,8 @@ class TestEksDeleteFargateProfileTrigger:
cluster_name=TEST_CLUSTER_IDENTIFIER,
fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
)
with pytest.raises(AirflowException) as exc:
generator = eks_delete_fargate_profile_trigger.run()