AlejandroMorgante commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3314647660


##########
providers/amazon/docs/changelog.rst:
##########
@@ -32,6 +32,7 @@ Changelog
 Features
 ~~~~~~~~
 
+* ``Add DmsModifyTaskOperator to modify DMS replication tasks with stop/restart lifecycle (#67524)``

Review Comment:
   Good catch, reverted.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py:
##########
@@ -118,6 +121,274 @@ def execute(self, context: Context):
         return task_arn
 
 
+class DmsModifyTaskOperator(AwsBaseOperator[DmsHook]):
+    """
+    Modifies an existing AWS DMS replication task.
+
+    If the task is not already stopped, set ``stop_task_before=True`` to stop it first.
+    To restart the task after modification, set ``restart_task_after=True``.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsModifyTaskOperator`
+
+    :param replication_task_arn: Replication task ARN
+    :param table_mappings: New table mappings. If not provided, existing mappings are kept.
+    :param migration_type: Migration type ('full-load'|'cdc'|'full-load-and-cdc').
+        If not provided, existing type is kept.
+    :param replication_task_settings: Task settings dict. If not provided, existing settings are kept.
+    :param cdc_start_time: Start time for CDC.
+    :param cdc_start_position: Indicates when to start CDC (checkpoint or LSN/SCN format).
+        Mutually exclusive with cdc_start_time.
+    :param cdc_stop_position: Indicates when to stop CDC.
+    :param stop_task_before: If True, stop the task before modifying if it is not already stopped.
+    :param restart_task_after: If True, restart the task after modifying.
+    :param start_replication_task_type: Start type used when restarting the task.
+        One of 'start-replication', 'resume-processing', or 'reload-target'.
+        Defaults to 'resume-processing'. Only used when ``restart_task_after=True``.
+    :param wait_for_completion: Only applies when the task is already in ``modifying`` state
+        when ``execute()`` is called. If True, wait for the modification to finish before
+        proceeding. If False, raises immediately instead of waiting.
+    :param deferrable: Run the operator in deferrable mode.
+    :param waiter_delay: Seconds between waiter polls (default: 30).
+    :param waiter_max_attempts: Maximum waiter poll attempts (default: 60).
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    STOPPED_STATES = ("stopped", "ready", "failed", "created")
+    TERMINAL_STATES = frozenset({"failed", "stopped", "ready", "created", "deleting"})
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "replication_task_arn",
+        "table_mappings",
+        "migration_type",
+        "replication_task_settings",
+        "cdc_start_time",
+        "cdc_start_position",
+        "cdc_stop_position",
+        "start_replication_task_type",
+    )
+    template_fields_renderers: ClassVar[dict] = {
+        "table_mappings": "json",
+        "replication_task_settings": "json",
+    }
+
+    def __init__(
+        self,
+        *,
+        replication_task_arn: str,
+        table_mappings: dict | None = None,
+        migration_type: str | None = None,
+        replication_task_settings: dict | None = None,
+        cdc_start_time: datetime | None = None,
+        cdc_start_position: str | None = None,
+        cdc_stop_position: str | None = None,
+        stop_task_before: bool = False,
+        restart_task_after: bool = False,
+        start_replication_task_type: str = "resume-processing",
+        wait_for_completion: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if cdc_start_time and cdc_start_position:
+            raise ValueError("Only one of cdc_start_time or cdc_start_position can be provided.")
+        self.replication_task_arn = replication_task_arn
+        self.table_mappings = table_mappings
+        self.migration_type = migration_type
+        self.replication_task_settings = replication_task_settings
+        self.cdc_start_time = cdc_start_time
+        self.cdc_start_position = cdc_start_position
+        self.cdc_stop_position = cdc_stop_position
+        self.stop_task_before = stop_task_before
+        self.restart_task_after = restart_task_after
+        self.start_replication_task_type = start_replication_task_type
+        self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def execute(self, context: Context) -> dict:
+        tasks = self.hook.find_replication_tasks_by_arn(
+            replication_task_arn=self.replication_task_arn, without_settings=True
+        )
+        if not tasks:
+            raise AirflowException(f"Replication task {self.replication_task_arn} not found.")
+
+        current_status = tasks[0].get("Status", "").lower()
+        self.log.info(
+            "Current status of replication task(%s) is '%s'.", self.replication_task_arn, current_status
+        )
+
+        if current_status == "modifying":
+            # boto3 stopped/ready waiters treat 'modifying' as a terminal failure — use poll loop.
+            if not self.wait_for_completion:
+                raise AirflowException(

Review Comment:
   Done: ValueError for not-found, RuntimeError for state errors and timeouts.
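
   For illustration, a minimal sketch of how that split could look. This is not the code from the PR: the helper names `resolve_task_status` and `wait_while_modifying`, the `get_status` callable standing in for the hook lookup, and the injectable `sleep` are all assumptions made so the example runs on its own.

```python
import time
from typing import Callable


def resolve_task_status(tasks: list[dict], task_arn: str) -> str:
    """Hypothetical helper: return the lower-cased status of the first task.

    Raises ValueError when the lookup returned no task (the "not-found" case).
    """
    if not tasks:
        raise ValueError(f"Replication task {task_arn} not found.")
    return tasks[0].get("Status", "").lower()


def wait_while_modifying(
    get_status: Callable[[], str],
    *,
    wait_for_completion: bool = True,
    delay: float = 30,
    max_attempts: int = 60,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Hypothetical helper: poll until the task leaves the 'modifying' state.

    Raises RuntimeError for a state error (task is modifying and the caller
    opted out of waiting) and for a timeout (polls exhausted).
    """
    status = get_status()
    if status != "modifying":
        return status
    if not wait_for_completion:
        raise RuntimeError(
            "Replication task is in 'modifying' state and wait_for_completion=False."
        )
    for _ in range(max_attempts):
        sleep(delay)
        status = get_status()
        if status != "modifying":
            return status
    raise RuntimeError(
        f"Timed out after {max_attempts} polls waiting for the task to leave 'modifying'."
    )
```

   The `delay` and `max_attempts` defaults mirror `waiter_delay` and `waiter_max_attempts` from the diff above; passing `sleep` in makes the loop testable without real waiting.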


