vandonr-amz commented on code in PR #32274:
URL: https://github.com/apache/airflow/pull/32274#discussion_r1253446469
##########
airflow/providers/amazon/aws/triggers/athena.py:
##########
@@ -16,61 +16,43 @@
# under the License.
from __future__ import annotations
-from typing import Any
-
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.triggers.base_trigger import
AwsBaseWaiterTrigger
-class AthenaTrigger(BaseTrigger):
+class AthenaTrigger(AwsBaseWaiterTrigger):
"""
Trigger for RedshiftCreateClusterOperator.
The trigger will asynchronously poll the boto3 API and wait for the
Redshift cluster to be in the `available` state.
:param query_execution_id: ID of the Athena query execution to watch
- :param poll_interval: The amount of time in seconds to wait between
attempts.
- :param max_attempt: The maximum number of attempts to be made.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
def __init__(
self,
query_execution_id: str,
- poll_interval: int,
- max_attempt: int,
+ waiter_delay: int,
+ waiter_max_attempts: int,
Review Comment:
This trigger was added in #32186, merged on June 27th; the last provider
release was on June 20th, so this breaking change is OK.
##########
airflow/providers/amazon/aws/triggers/ecs.py:
##########
@@ -22,68 +22,89 @@
from botocore.exceptions import ClientError, WaiterError
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.base_trigger import
AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
-class ClusterWaiterTrigger(BaseTrigger):
+class ClusterActiveTrigger(AwsBaseWaiterTrigger):
Review Comment:
This trigger was added in #31881, merged on June 23rd; the last provider
release was on June 20th, so this breaking change is OK.
##########
airflow/providers/amazon/aws/triggers/batch.py:
##########
@@ -189,56 +193,78 @@ async def run(self):
)
-class BatchCreateComputeEnvironmentTrigger(BaseTrigger):
+class BatchJobTrigger(AwsBaseWaiterTrigger):
+ """
+ Checks for the status of a submitted job_id to AWS Batch until it reaches
a failure or a success state.
+
+ :param job_id: the job ID, to poll for job completion or not
+ :param region_name: AWS region name to use
+ Override the region_name in connection (if provided)
+ :param aws_conn_id: connection id of AWS credentials / region name. If
None,
+ credential boto3 strategy will be used
+ :param waiter_delay: polling period in seconds to check for the status of
the job
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ """
+
+ def __init__(
+ self,
+ job_id: str | None,
+ region_name: str | None,
+ aws_conn_id: str | None = "aws_default",
+ waiter_delay: int = 5,
+ waiter_max_attempts: int = 720,
+ ):
+ super().__init__(
+ serialized_fields={"job_id": job_id},
+ waiter_name="batch_job_complete",
+ waiter_args={"jobs": [job_id]},
+ failure_message=f"Failure while running batch job {job_id}",
+ status_message=f"Batch job {job_id} not ready yet",
+ status_queries=["jobs[].status",
"computeEnvironments[].statusReason"],
+ return_key="job_id",
+ return_value=job_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return BatchClientHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+
+
+class BatchCreateComputeEnvironmentTrigger(AwsBaseWaiterTrigger):
"""
Asynchronously poll the boto3 API and wait for the compute environment to
be ready.
- :param job_id: A unique identifier for the cluster.
- :param max_retries: The maximum number of attempts to be made.
+ :param compute_env_arn: The ARN of the compute env.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: region name to use in AWS Hook
- :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
"""
def __init__(
self,
- compute_env_arn: str | None = None,
- poll_interval: int = 30,
- max_retries: int = 10,
+ compute_env_arn: str,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 10,
Review Comment:
This trigger was added in #32036, merged on June 27th; the last provider
release was on June 20th, so this breaking change is OK.
##########
tests/providers/amazon/aws/triggers/test_athena.py:
##########
@@ -16,38 +16,20 @@
# under the License.
from __future__ import annotations
-from unittest import mock
-from unittest.mock import AsyncMock
-
-import pytest
-from botocore.exceptions import WaiterError
-
-from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
class TestAthenaTrigger:
- @pytest.mark.asyncio
- @mock.patch.object(AthenaHook, "get_waiter")
- @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI
fails without this
- async def test_run_with_error(self, conn_mock, waiter_mock):
- waiter_mock.side_effect = WaiterError("name", "reason", {})
-
- trigger = AthenaTrigger("query_id", 0, 5, None)
+ def test_serialize_recreate(self):
Review Comment:
This is probably best viewed as a side-by-side diff. I removed the existing
tests because the individual triggers no longer contain any logic.
Instead, I'm testing the only thing that can still break, which is
serialization/deserialization.
To do that, I run a serialize-deserialize-reserialize cycle and compare
the serialized data. It would probably be better to compare the instances,
but comparing the serialized output can at least be done with a simple `==`.
I copy-pasted the same test for every trigger inheriting from the base,
because I think it's better to keep it in each trigger's own test file. It
could also be a parametrized test with many cases in test_base_trigger to
avoid the code duplication; I'm open to opinions on this.
##########
airflow/providers/amazon/aws/triggers/eks.py:
##########
@@ -115,59 +90,37 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str,
region: str | None = None,
+ region_name: str | None = None,
):
- self.cluster_name = cluster_name
- self.fargate_profile_name = fargate_profile_name
- self.waiter_delay = waiter_delay
- self.waiter_max_attempts = waiter_max_attempts
- self.aws_conn_id = aws_conn_id
- self.region = region
-
- def serialize(self) -> tuple[str, dict[str, Any]]:
- return (
- self.__class__.__module__ + "." + self.__class__.__qualname__,
- {
- "cluster_name": self.cluster_name,
- "fargate_profile_name": self.fargate_profile_name,
- "waiter_delay": str(self.waiter_delay),
- "waiter_max_attempts": str(self.waiter_max_attempts),
- "aws_conn_id": self.aws_conn_id,
- "region": self.region,
- },
+ if region is not None:
+ warnings.warn(
+ "please use region_name param instead of region",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ region_name = region
+
+ super().__init__(
+ serialized_fields={"cluster_name": cluster_name,
"fargate_profile_name": fargate_profile_name},
+ waiter_name="fargate_profile_deleted",
+ waiter_args={"clusterName": cluster_name, "fargateProfileName":
fargate_profile_name},
+ failure_message="Failure while deleting Fargate profile",
+ status_message="Fargate profile not deleted yet",
+ status_queries=["fargateProfile.status"],
+ return_value=None,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
)
- async def run(self):
- self.hook = EksHook(aws_conn_id=self.aws_conn_id,
region_name=self.region)
- async with self.hook.async_conn as client:
- attempt = 0
- waiter = client.get_waiter("fargate_profile_deleted")
- while attempt < int(self.waiter_max_attempts):
- attempt += 1
- try:
- await waiter.wait(
- clusterName=self.cluster_name,
- fargateProfileName=self.fargate_profile_name,
- WaiterConfig={"Delay": int(self.waiter_delay),
"MaxAttempts": 1},
- )
- break
- except WaiterError as error:
- if "terminal failure" in str(error):
- raise AirflowException(f"Delete Fargate Profile
failed: {error}")
- self.log.info(
- "Status of fargate profile is %s",
error.last_response["fargateProfile"]["status"]
- )
- await asyncio.sleep(int(self.waiter_delay))
- if attempt >= int(self.waiter_max_attempts):
- raise AirflowException(
- f"Delete Fargate Profile failed - max attempts reached:
{self.waiter_max_attempts}"
- )
- else:
- yield TriggerEvent({"status": "success", "message": "Fargate
Profile Deleted"})
+ def hook(self) -> AwsGenericHook:
+ return EksHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
-class EksNodegroupTrigger(BaseTrigger):
+class EksCreateNodegroupTrigger(AwsBaseWaiterTrigger):
Review Comment:
this trigger was added in #32165 merged on June 26th, last provider release
was on June 20th, so this breaking change is OK.
##########
airflow/providers/amazon/aws/hooks/athena.py:
##########
@@ -253,7 +253,7 @@ def poll_query_status(
try:
wait(
waiter=self.get_waiter("query_complete"),
- waiter_delay=sleep_time or self.sleep_time,
+ waiter_delay=self.sleep_time if sleep_time is None else
sleep_time,
Review Comment:
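This is a somewhat unrelated fix that allows specifying a sleep time of 0 in
unit tests: the previous `sleep_time or self.sleep_time` treated 0 as falsy
and silently fell back to the default. Without this, the Athena unit tests
took 30s each.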
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]