SameerMesiah97 commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3329167919


##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/dms.py:
##########
@@ -219,3 +222,79 @@ def hook(self) -> AwsGenericHook:
             verify=self.verify,
             config=self.botocore_config,
         )
+
+
+class DmsTaskModifyCompleteTrigger(BaseTrigger):
+    """
+    Trigger that polls until a DMS classic replication task exits the 
``modifying`` state.
+
+    The boto3 ``replication_task_stopped`` waiter treats ``modifying`` as a 
terminal failure,
+    so a custom polling loop is required here.
+
+    :param replication_task_arn: The ARN of the replication task.
+    :param waiter_delay: Seconds between polls.
+    :param waiter_max_attempts: Maximum number of poll attempts before giving 
up.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param verify: Whether or not to verify SSL certificates.
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client.
+    """
+
+    def __init__(
+        self,
+        replication_task_arn: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
+    ) -> None:
+        super().__init__()
+        self.replication_task_arn = replication_task_arn
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.aws_conn_id = aws_conn_id
+        self.verify = verify
+        self.botocore_config = botocore_config
+
+    def serialize(self) -> tuple[str, dict]:
+        return (
+            
"airflow.providers.amazon.aws.triggers.dms.DmsTaskModifyCompleteTrigger",
+            {
+                "replication_task_arn": self.replication_task_arn,
+                "waiter_delay": self.waiter_delay,
+                "waiter_max_attempts": self.waiter_max_attempts,
+                "aws_conn_id": self.aws_conn_id,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = DmsHook(aws_conn_id=self.aws_conn_id, verify=self.verify, 
config=self.botocore_config)
+        try:
+            for _ in range(self.waiter_max_attempts):
+                status = await 
hook.get_task_status_async(self.replication_task_arn)
+                if status != "modifying":

Review Comment:
   I would use the task-status enum instead of raw string literals such as `"modifying"`, for consistency with the operator.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to