ferruzzi commented on code in PR #32274:
URL: https://github.com/apache/airflow/pull/32274#discussion_r1253590001
##########
airflow/providers/amazon/aws/triggers/ecs.py:
##########
@@ -22,68 +22,89 @@
from botocore.exceptions import ClientError, WaiterError
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.base_trigger import
AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
-class ClusterWaiterTrigger(BaseTrigger):
+class ClusterActiveTrigger(AwsBaseWaiterTrigger):
"""
- Polls the status of a cluster using a given waiter. Can be used to poll
for an active or inactive cluster.
+ Polls the status of a cluster until it's ready.
- :param waiter_name: Name of the waiter to use, for instance
'cluster_active' or 'cluster_inactive'
:param cluster_arn: ARN of the cluster to watch.
:param waiter_delay: The amount of time in seconds to wait between
attempts.
:param waiter_max_attempts: The number of times to ping for status.
Will fail after that many unsuccessful attempts.
:param aws_conn_id: The Airflow connection used for AWS credentials.
- :param region: The AWS region where the cluster is located.
+ :param region_name: The AWS region where the cluster is located.
"""
def __init__(
self,
- waiter_name: str,
cluster_arn: str,
- waiter_delay: int | None,
- waiter_max_attempts: int | None,
+ waiter_delay: int,
+ waiter_max_attempts: int,
aws_conn_id: str | None,
- region: str | None,
+ region_name: str | None,
):
- self.cluster_arn = cluster_arn
- self.waiter_name = waiter_name
- self.waiter_delay = waiter_delay if waiter_delay is not None else 15
# written like this to allow 0
- self.attempts = waiter_max_attempts or 999999999
- self.aws_conn_id = aws_conn_id
- self.region = region
+ super().__init__(
+ serialized_fields={"cluster_arn": cluster_arn},
+ waiter_name="cluster_active",
+ waiter_args={"clusters": [cluster_arn]},
+ failure_message="Failure while waiting for cluster to be
available",
+ status_message="Cluster is not ready yet",
+ status_queries=["clusters[].status", "failures"],
+ return_key="arn",
+ return_value=cluster_arn,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
+ )
- def serialize(self) -> tuple[str, dict[str, Any]]:
- return (
- self.__class__.__module__ + "." + self.__class__.__qualname__,
- {
- "waiter_name": self.waiter_name,
- "cluster_arn": self.cluster_arn,
- "waiter_delay": self.waiter_delay,
- "waiter_max_attempts": self.attempts,
- "aws_conn_id": self.aws_conn_id,
- "region": self.region,
- },
+ def hook(self) -> AwsGenericHook:
+ return EcsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+
+
+class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
Review Comment:
I may be over-complicating this, but it looks like the only differences
between `ClusterActiveTrigger` and `ClusterInactiveTrigger` are the
`waiter_name` value and the two messages (`failure_message` and
`status_message`). If so, why not have both inherit from a shared
`ClusterStatusTrigger` that accepts those three values, and drop all
the repetition?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]