SameerMesiah97 commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3329167919
##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/dms.py:
##########
@@ -219,3 +222,79 @@ def hook(self) -> AwsGenericHook:
verify=self.verify,
config=self.botocore_config,
)
+
+
+class DmsTaskModifyCompleteTrigger(BaseTrigger):
+ """
+ Trigger that polls until a DMS classic replication task exits the
``modifying`` state.
+
+ The boto3 ``replication_task_stopped`` waiter treats ``modifying`` as a
terminal failure,
+ so a custom polling loop is required here.
+
+ :param replication_task_arn: The ARN of the replication task.
+ :param waiter_delay: Seconds between polls.
+ :param waiter_max_attempts: Maximum number of poll attempts before giving
up.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param verify: Whether or not to verify SSL certificates.
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client.
+ """
+
+ def __init__(
+ self,
+ replication_task_arn: str,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
+ aws_conn_id: str | None = "aws_default",
+ verify: bool | str | None = None,
+ botocore_config: dict | None = None,
+ ) -> None:
+ super().__init__()
+ self.replication_task_arn = replication_task_arn
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
+ self.verify = verify
+ self.botocore_config = botocore_config
+
+ def serialize(self) -> tuple[str, dict]:
+ return (
+
"airflow.providers.amazon.aws.triggers.dms.DmsTaskModifyCompleteTrigger",
+ {
+ "replication_task_arn": self.replication_task_arn,
+ "waiter_delay": self.waiter_delay,
+ "waiter_max_attempts": self.waiter_max_attempts,
+ "aws_conn_id": self.aws_conn_id,
+ "verify": self.verify,
+ "botocore_config": self.botocore_config,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ hook = DmsHook(aws_conn_id=self.aws_conn_id, verify=self.verify,
config=self.botocore_config)
+ try:
+ for _ in range(self.waiter_max_attempts):
+ status = await
hook.get_task_status_async(self.replication_task_arn)
+ if status != "modifying":
Review Comment:
I would use the enum for the task state instead of raw strings, to be
consistent with the operator. If you follow my suggestion in the other comment,
this line will become:
`if status != DmsTaskWaiterStatus.MODIFYING`
Besides this, there is another, more significant issue. It appears you
are collapsing all non-`MODIFYING` statuses and yielding a success event for
every one of them. I think we should define `SUCCESS` and `FAILURE` states and
handle TriggerEvent emission explicitly for both. Unknown/unexpected states and
exhausting the maximum waiter attempts can each yield their own error event. And
errors unrelated to the state machine can yield error events with the error
message (you do this already).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]