SameerMesiah97 commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3329198947


##########
providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py:
##########
@@ -118,6 +120,178 @@ def execute(self, context: Context):
         return task_arn
 
 
+class DmsModifyTaskOperator(AwsBaseOperator[DmsHook]):
+    """
+    Modifies an existing AWS DMS replication task.
+
+    The task must already be stopped before modification. Use :class:`DmsStopTaskOperator`
+    upstream in the Dag to stop it, and :class:`DmsStartTaskOperator` downstream to restart
+    it afterwards if needed.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsModifyTaskOperator`
+
+    :param replication_task_arn: Replication task ARN
+    :param table_mappings: New table mappings. If not provided, existing mappings are kept.
+    :param migration_type: Migration type ('full-load'|'cdc'|'full-load-and-cdc').
+        If not provided, existing type is kept.
+    :param replication_task_settings: Task settings dict. If not provided, existing settings are kept.
+    :param cdc_start_time: Start time for CDC.
+    :param cdc_start_position: Indicates when to start CDC (checkpoint or LSN/SCN format).
+        Mutually exclusive with cdc_start_time.
+    :param cdc_stop_position: Indicates when to stop CDC.
+    :param wait_for_completion: If True, wait for the modification to finish before returning.
+        In deferrable mode the operator defers rather than blocking. Defaults to True.
+    :param deferrable: Run the operator in deferrable mode. Defaults to False.
+    :param waiter_delay: Seconds between waiter polls (default: 30).
+    :param waiter_max_attempts: Maximum waiter poll attempts (default: 60).
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    MODIFIABLE_STATES = frozenset({DmsTaskState.STOPPED, DmsTaskState.READY, DmsTaskState.FAILED})
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "replication_task_arn",
+        "table_mappings",
+        "migration_type",
+        "replication_task_settings",
+        "cdc_start_time",
+        "cdc_start_position",
+        "cdc_stop_position",
+    )
+    template_fields_renderers: ClassVar[dict] = {
+        "table_mappings": "json",
+        "replication_task_settings": "json",
+    }
+
+    def __init__(
+        self,
+        *,
+        replication_task_arn: str,
+        table_mappings: dict | None = None,
+        migration_type: str | None = None,
+        replication_task_settings: dict | None = None,
+        cdc_start_time: datetime | None = None,
+        cdc_start_position: str | None = None,
+        cdc_stop_position: str | None = None,
+        wait_for_completion: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if cdc_start_time and cdc_start_position:
+            raise ValueError("Only one of cdc_start_time or cdc_start_position can be provided.")
+        self.replication_task_arn = replication_task_arn
+        self.table_mappings = table_mappings
+        self.migration_type = migration_type
+        self.replication_task_settings = replication_task_settings
+        self.cdc_start_time = cdc_start_time
+        self.cdc_start_position = cdc_start_position
+        self.cdc_stop_position = cdc_stop_position
+        self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def execute(self, context: Context) -> dict:
+        tasks = self.hook.find_replication_tasks_by_arn(
+            replication_task_arn=self.replication_task_arn, without_settings=True
+        )
+        if not tasks:
+            raise ValueError(f"Replication task {self.replication_task_arn} not found.")
+
+        current_status = tasks[0].get("Status", "").lower()
+        self.log.info(
+            "Current status of replication task(%s) is '%s'.", self.replication_task_arn, current_status
+        )
+
+        if current_status == DmsTaskState.MODIFYING:
+            # boto3 stopped/ready waiters treat 'modifying' as a terminal failure — use poll loop.
+            self._wait_until_not_modifying()
+        elif current_status not in self.MODIFIABLE_STATES:
+            raise RuntimeError(
+                f"Replication task {self.replication_task_arn} is in state '{current_status}' "
+                f"and must be stopped before modification. "
+                f"Use DmsStopTaskOperator to stop it first."

Review Comment:
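   Is this error message accurate? `MODIFIABLE_STATES` includes `READY` and
`FAILED` in addition to `STOPPED`, so a task in either of those states can be
modified as-is, yet the message says the task "must be stopped before
modification".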



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to