ferruzzi commented on code in PR #32355:
URL: https://github.com/apache/airflow/pull/32355#discussion_r1253508551
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -27,6 +28,185 @@
from airflow.triggers.base import BaseTrigger, TriggerEvent
+class EksCreateClusterTrigger(BaseTrigger):
+ """
+ Trigger for EksCreateClusterOperator.
+ The trigger will asynchronously wait for the cluster to be created.
Review Comment:
```suggestion
Trigger for EksCreateClusterOperator.

The trigger will asynchronously wait for the cluster to be created.
```
I'm working on enabling D205 style checks, which require a blank line between the first "summary" line and the rest of the docstring, if you don't mind.
##########
airflow/providers/amazon/aws/operators/eks.py:
##########
@@ -273,12 +280,28 @@ def execute(self, context: Context):
# Short circuit early if we don't need to wait to attach compute
# and the caller hasn't requested to wait for the cluster either.
- if not self.compute and not self.wait_for_completion:
+ # if not self.compute and not self.wait_for_completion and not
self.deferrable:
+ # return None
+ if not any([self.compute, self.wait_for_completion, self.deferrable]):
return None
self.log.info("Waiting for EKS Cluster to provision. This will take
some time.")
client = self.eks_hook.conn
+ if self.deferrable:
+ self.defer(
+ trigger=EksCreateClusterTrigger(
+ waiter_name="cluster_active",
+ cluster_name=self.cluster_name,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
Review Comment:
Non-blocking nit/suggestion. Here and in the other Operators below:
One way you could reduce repetition is setting
```
self.waiter_config = {
    "cluster_name": self.cluster_name,
    "aws_conn_id": self.aws_conn_id,
    "region": self.region,
    "waiter_delay": self.waiter_delay,
    "waiter_max_attempts": self.waiter_max_attempts,
}
```
up in the class `__init__` and unpacking it in each of these method calls (here and on lines 342, 360, 375, 388, etc.), like:
```
trigger=EksCreateClusterTrigger(waiter_name="cluster_active", **self.waiter_config)
```
##########
airflow/providers/amazon/aws/operators/eks.py:
##########
@@ -273,12 +280,28 @@ def execute(self, context: Context):
# Short circuit early if we don't need to wait to attach compute
# and the caller hasn't requested to wait for the cluster either.
- if not self.compute and not self.wait_for_completion:
+ # if not self.compute and not self.wait_for_completion and not
self.deferrable:
+ # return None
Review Comment:
Guessing this is meant to be removed?
##########
tests/providers/amazon/aws/triggers/test_eks.py:
##########
@@ -453,3 +455,403 @@ async def
test_eks_nodegroup_trigger_run_attempts_failed(self, mock_async_conn,
assert "Error checking nodegroup" in str(exc.value)
assert a_mock.get_waiter().wait.call_count == 3
+
+
+class TestEksCreateClusterTrigger:
+ def test_eks_create_cluster_trigger_serialize(self):
+ eks_create_cluster_trigger = EksCreateClusterTrigger(
+ waiter_name="test_waiter_name",
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ region=TEST_REGION,
+ )
+
+ class_path, args = eks_create_cluster_trigger.serialize()
+ assert class_path ==
"airflow.providers.amazon.aws.triggers.eks.EksCreateClusterTrigger"
+ assert args["waiter_name"] == "test_waiter_name"
+ assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER
+ assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+ assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+ assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
+ assert args["region"] == TEST_REGION
+
+ @pytest.mark.asyncio
+ @mock.patch.object(EksHook, "async_conn")
+ async def test_eks_create_cluster_trigger_run(self, mock_async_conn):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ a_mock.get_waiter().wait = AsyncMock()
+
+ eks_create_cluster_trigger = EksCreateClusterTrigger(
+ waiter_name="test_waiter_name",
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ region=TEST_REGION,
+ )
+
+ generator = eks_create_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent(
+ {
+ "status": "success",
+ }
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(EksHook, "async_conn")
+ async def test_eks_create_cluster_trigger_run_multiple_attempts(self,
mock_async_conn, mock_sleep):
+ mock_sleep.return_value = True
Review Comment:
Is there a reason not to set the return value directly in the patch decorator, like
```
@mock.patch("asyncio.sleep", return_value=True)
```
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -27,6 +28,185 @@
from airflow.triggers.base import BaseTrigger, TriggerEvent
+class EksCreateClusterTrigger(BaseTrigger):
+ """
+ Trigger for EksCreateClusterOperator.
+ The trigger will asynchronously wait for the cluster to be created.
+
+ :param waiter_name: The name of the waiter to use.
+ :param cluster_name: The name of the EKS cluster
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region: Which AWS region the connection should use.
+ If this is None or empty then the default boto3 behaviour is used.
+ """
+
+ def __init__(
+ self,
+ waiter_name: str,
+ cluster_name: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ region: str | None,
+ ):
+ self.waiter_name = waiter_name
+ self.cluster_name = cluster_name
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "waiter_name": self.waiter_name,
+ "cluster_name": self.cluster_name,
+ "waiter_delay": str(self.waiter_delay),
+ "waiter_max_attempts": str(self.waiter_max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ "region": self.region,
+ },
+ )
+
+ async def run(self):
+ failure_message = "Error checking Eks cluster"
+ self.hook = EksHook(aws_conn_id=self.aws_conn_id,
region_name=self.region)
+ async with self.hook.async_conn as client:
+ waiter = client.get_waiter(self.waiter_name)
+ try:
+ await async_wait(
+ waiter=waiter,
+ waiter_max_attempts=int(self.waiter_max_attempts),
+ waiter_delay=int(self.waiter_delay),
+ args={"name": self.cluster_name},
+ failure_message=failure_message,
+ status_message="Eks cluster status is",
+ status_args=["cluster.status"],
+ )
+ except AirflowException as exc:
+ if failure_message in str(exc):
+ yield TriggerEvent({"status": "failed", "exception": exc})
+ raise
+ yield TriggerEvent({"status": "success"})
+
+
+class EksDeleteClusterTrigger(BaseTrigger):
+ """
+ Trigger for EksDeleteClusterOperator.
+ The trigger will asynchronously wait for the cluster to be deleted. If
there are
Review Comment:
```suggestion
Trigger for EksDeleteClusterOperator.

The trigger will asynchronously wait for the cluster to be deleted. If
there are
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]