Lee-W commented on code in PR #32355:
URL: https://github.com/apache/airflow/pull/32355#discussion_r1262179951
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -17,11 +17,174 @@
from __future__ import annotations
import warnings
+from typing import Any
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import TriggerEvent
+
+
+class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger for EksCreateClusterOperator.
+
+ The trigger will asynchronously wait for the cluster to be created.
+
+ :param cluster_name: The name of the EKS cluster
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: Which AWS region the connection should use.
+ If this is None or empty then the default boto3 behaviour is used.
+ """
+
+ def __init__(
+ self,
+ cluster_name: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ region_name: str | None,
+ ):
+ super().__init__(
+ serialized_fields={"cluster_name": cluster_name, "region_name": region_name},
+ waiter_name="cluster_active",
+ waiter_args={"name": cluster_name},
+ failure_message="Error checking Eks cluster",
+ status_message="Eks cluster status is",
+ status_queries=["cluster.status"],
+ return_value=None,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger for EksDeleteClusterOperator.
+
+ The trigger will asynchronously wait for the cluster to be deleted. If there are
+ any nodegroups or fargate profiles associated with the cluster, they will be deleted
+ before the cluster is deleted.
+
+ :param cluster_name: The name of the EKS cluster
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: Which AWS region the connection should use.
+ If this is None or empty then the default boto3 behaviour is used.
+ :param force_delete_compute: If True, any nodegroups or fargate profiles associated
+ with the cluster will be deleted before the cluster is deleted.
+ """
+
+ def __init__(
+ self,
+ cluster_name,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ region_name: str | None,
+ force_delete_compute: bool,
+ ):
+ self.cluster_name = cluster_name
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.force_delete_compute = force_delete_compute
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "cluster_name": self.cluster_name,
+ "waiter_delay": str(self.waiter_delay),
+ "waiter_max_attempts": str(self.waiter_max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "force_delete_compute": self.force_delete_compute,
+ },
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+ async def run(self):
+ async with self.hook.async_conn as client:
+ waiter = client.get_waiter("cluster_deleted")
+ if self.force_delete_compute:
+ await self.delete_any_nodegroups(client=client)
+ await self.delete_any_fargate_profiles(client=client)
+ await client.delete_cluster(name=self.cluster_name)
+ await async_wait(
+ waiter=waiter,
+ waiter_delay=int(self.waiter_delay),
+ waiter_max_attempts=int(self.waiter_max_attempts),
+ args={"name": self.cluster_name},
+ failure_message="Error deleting cluster",
+ status_message="Status of cluster is",
+ status_args=["cluster.status"],
+ )
+
+ yield TriggerEvent({"status": "deleted"})
+
+ async def delete_any_nodegroups(self, client):
Review Comment:
nitpick
```suggestion
async def delete_any_nodegroups(self, client) -> None:
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]