vandonr-amz commented on code in PR #31657:
URL: https://github.com/apache/airflow/pull/31657#discussion_r1221944662
##########
airflow/providers/amazon/aws/operators/eks.py:
##########
@@ -401,13 +417,31 @@ def execute(self, context: Context):
selectors=self.selectors,
**self.create_fargate_profile_kwargs,
)
-
- if self.wait_for_completion:
+ if self.deferrable:
+ self.defer(
+ trigger=EksCreateFargateProfileTrigger(
+ cluster_name=self.cluster_name,
+ fargate_profile_name=self.fargate_profile_name,
+ aws_conn_id=self.aws_conn_id,
+ poll_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=(self.waiter_max_attempts *
self.waiter_delay + 60)),
Review Comment:
Why 60? Maybe create a named constant?
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import Any
+
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class EksCreateFargateProfileTrigger(BaseTrigger):
+ """
+ Trigger for EksCreateFargateProfileOperator.
+ The trigger will asynchronously wait for the fargate profile to be created.
+
+ :param cluster_name: The name of the EKS cluster
+ :param fargate_profile_name: The name of the fargate profile
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ cluster_name: str,
+ fargate_profile_name: str,
+ poll_interval: int,
+ max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.cluster_name = cluster_name
+ self.fargate_profile_name = fargate_profile_name
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "cluster_name": self.cluster_name,
+ "fargate_profile_name": self.fargate_profile_name,
+ "poll_interval": str(self.poll_interval),
+ "max_attempts": str(self.max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> EksHook:
+ return EksHook(aws_conn_id=self.aws_conn_id)
Review Comment:
I mentioned it earlier, but I think caching this is useless.
##########
airflow/providers/amazon/aws/operators/eks.py:
##########
@@ -622,12 +667,31 @@ def execute(self, context: Context):
eks_hook.delete_fargate_profile(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name
)
- if self.wait_for_completion:
Review Comment:
The same comments as above apply to this code.
##########
airflow/providers/amazon/aws/operators/eks.py:
##########
@@ -401,13 +417,31 @@ def execute(self, context: Context):
selectors=self.selectors,
**self.create_fargate_profile_kwargs,
)
-
- if self.wait_for_completion:
+ if self.deferrable:
+ self.defer(
+ trigger=EksCreateFargateProfileTrigger(
+ cluster_name=self.cluster_name,
+ fargate_profile_name=self.fargate_profile_name,
+ aws_conn_id=self.aws_conn_id,
+ poll_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=(self.waiter_max_attempts *
self.waiter_delay + 60)),
+ )
+ elif self.wait_for_completion:
self.log.info("Waiting for Fargate profile to provision. This
will take some time.")
eks_hook.conn.get_waiter("fargate_profile_active").wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name
)
Review Comment:
Shall we pass the delay and max attempts to this waiter as well while we're at
it?
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import Any
+
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class EksCreateFargateProfileTrigger(BaseTrigger):
+ """
+ Trigger for EksCreateFargateProfileOperator.
+ The trigger will asynchronously wait for the fargate profile to be created.
+
+ :param cluster_name: The name of the EKS cluster
+ :param fargate_profile_name: The name of the fargate profile
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ cluster_name: str,
+ fargate_profile_name: str,
+ poll_interval: int,
+ max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.cluster_name = cluster_name
+ self.fargate_profile_name = fargate_profile_name
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "cluster_name": self.cluster_name,
+ "fargate_profile_name": self.fargate_profile_name,
+ "poll_interval": str(self.poll_interval),
+ "max_attempts": str(self.max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> EksHook:
+ return EksHook(aws_conn_id=self.aws_conn_id)
+
+ async def run(self):
+ async with self.hook.async_conn as client:
+ attempt = 0
+ waiter = client.get_waiter("fargate_profile_active")
+ while attempt < int(self.max_attempts):
+ attempt += 1
+ try:
+ await waiter.wait(
+ clusterName=self.cluster_name,
+ fargateProfileName=self.fargate_profile_name,
+ WaiterConfig={"Delay": int(self.poll_interval),
"MaxAttempts": 1},
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ yield TriggerEvent(
+ {"status": "failure", "message": f"Create Fargate
Profile failed: {error}"}
+ )
Review Comment:
You should raise directly here instead of returning the error to the operator.
See https://apache-airflow.slack.com/archives/CCPRP7943/p1685663968275419
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import Any
+
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class EksCreateFargateProfileTrigger(BaseTrigger):
+ """
+ Trigger for EksCreateFargateProfileOperator.
+ The trigger will asynchronously wait for the fargate profile to be created.
+
+ :param cluster_name: The name of the EKS cluster
+ :param fargate_profile_name: The name of the fargate profile
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ cluster_name: str,
+ fargate_profile_name: str,
+ poll_interval: int,
+ max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.cluster_name = cluster_name
+ self.fargate_profile_name = fargate_profile_name
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "cluster_name": self.cluster_name,
+ "fargate_profile_name": self.fargate_profile_name,
+ "poll_interval": str(self.poll_interval),
+ "max_attempts": str(self.max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> EksHook:
+ return EksHook(aws_conn_id=self.aws_conn_id)
+
+ async def run(self):
+ async with self.hook.async_conn as client:
+ attempt = 0
+ waiter = client.get_waiter("fargate_profile_active")
+ while attempt < int(self.max_attempts):
+ attempt += 1
+ try:
+ await waiter.wait(
+ clusterName=self.cluster_name,
+ fargateProfileName=self.fargate_profile_name,
+ WaiterConfig={"Delay": int(self.poll_interval),
"MaxAttempts": 1},
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ yield TriggerEvent(
+ {"status": "failure", "message": f"Create Fargate
Profile failed: {error}"}
+ )
+ break
+ self.log.info(
+ "Status of fargate profile is %s",
error.last_response["fargateProfile"]["status"]
+ )
+ await asyncio.sleep(int(self.poll_interval))
+ if attempt >= int(self.max_attempts):
+ yield TriggerEvent(
+ {
+ "status": "failure",
+ "message": "Create Fargate profile Failed - max attempts
reached.",
+ }
+ )
+ else:
+ yield TriggerEvent({"status": "success", "message": "Fargate
Profile Created"})
+
+
+class EksDeleteFargateProfileTrigger(BaseTrigger):
Review Comment:
The same comments as above apply here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]