AlejandroMorgante commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3329271826


##########
providers/amazon/tests/system/amazon/aws/example_dms.py:
##########
@@ -374,6 +375,24 @@ def delete_security_group(security_group_id: str, security_group_name: str):
     )
     # [END howto_operator_dms_stop_task]
 
+    # [START howto_operator_dms_modify_task]
+    modify_task = DmsModifyTaskOperator(
+        task_id="modify_task",
+        replication_task_arn=task_arn,
+        table_mappings={
+            "rules": [
+                {
+                    "rule-type": "selection",
+                    "rule-id": "1",
+                    "rule-name": "1",
+                    "object-locator": {"schema-name": "%", "table-name": "%"},
+                    "rule-action": "include",
+                }
+            ]
+        },
+    )
+    # [END howto_operator_dms_modify_task]

Review Comment:
   Done, reordered to `stop_task -> await_task_stop -> modify_task`.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+

Review Comment:
   Removed both.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime
+
+        with pytest.raises(ValueError, match="Only one of"):
+            DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                cdc_start_time=datetime(2024, 1, 1),
+                cdc_start_position="mysql-bin.000001:4",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_already_stopped(self, mock_conn, mock_find):
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                wait_for_completion=False,
+            )
+            result = op.execute(None)
+
+            mock_modify.assert_called_once_with(
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type=None,
+                replication_task_settings=None,
+                cdc_start_time=None,
+                cdc_start_position=None,
+                cdc_stop_position=None,
+            )
+            assert result == expected
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+        mock_find.return_value = self._running_task()
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+        )
+        with pytest.raises(RuntimeError, match="must be stopped"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+        mock_find.return_value = []
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        with pytest.raises(ValueError, match="not found"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+        from airflow.exceptions import TaskDeferred

Review Comment:
   Done.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime
+
+        with pytest.raises(ValueError, match="Only one of"):
+            DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                cdc_start_time=datetime(2024, 1, 1),
+                cdc_start_position="mysql-bin.000001:4",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_already_stopped(self, mock_conn, mock_find):
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                wait_for_completion=False,
+            )
+            result = op.execute(None)
+
+            mock_modify.assert_called_once_with(
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type=None,
+                replication_task_settings=None,
+                cdc_start_time=None,
+                cdc_start_position=None,
+                cdc_stop_position=None,
+            )
+            assert result == expected
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+        mock_find.return_value = self._running_task()
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+        )
+        with pytest.raises(RuntimeError, match="must be stopped"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+        mock_find.return_value = []
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        with pytest.raises(ValueError, match="not found"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+        from airflow.exceptions import TaskDeferred
+        from airflow.providers.amazon.aws.triggers.dms import DmsTaskModifyCompleteTrigger

Review Comment:
   Done.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime
+
+        with pytest.raises(ValueError, match="Only one of"):
+            DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                cdc_start_time=datetime(2024, 1, 1),
+                cdc_start_position="mysql-bin.000001:4",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_already_stopped(self, mock_conn, mock_find):
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                wait_for_completion=False,
+            )
+            result = op.execute(None)
+
+            mock_modify.assert_called_once_with(
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type=None,
+                replication_task_settings=None,
+                cdc_start_time=None,
+                cdc_start_position=None,
+                cdc_stop_position=None,
+            )
+            assert result == expected
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+        mock_find.return_value = self._running_task()
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+        )
+        with pytest.raises(RuntimeError, match="must be stopped"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+        mock_find.return_value = []
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        with pytest.raises(ValueError, match="not found"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+        from airflow.exceptions import TaskDeferred
+        from airflow.providers.amazon.aws.triggers.dms import DmsTaskModifyCompleteTrigger
+
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected):
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                deferrable=True,
+                wait_for_completion=True,
+            )
+            with pytest.raises(TaskDeferred) as exc_info:
+                op.execute(None)
+
+        assert isinstance(exc_info.value.trigger, DmsTaskModifyCompleteTrigger)
+        assert exc_info.value.method_name == "execute_complete"
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_with_all_params(self, mock_conn, mock_find):
+        from datetime import datetime

Review Comment:
   Done.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime

Review Comment:
   Done.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime
+
+        with pytest.raises(ValueError, match="Only one of"):
+            DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                cdc_start_time=datetime(2024, 1, 1),
+                cdc_start_position="mysql-bin.000001:4",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_already_stopped(self, mock_conn, mock_find):
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                wait_for_completion=False,
+            )
+            result = op.execute(None)
+
+            mock_modify.assert_called_once_with(
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type=None,
+                replication_task_settings=None,
+                cdc_start_time=None,
+                cdc_start_position=None,
+                cdc_stop_position=None,
+            )
+            assert result == expected
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+        mock_find.return_value = self._running_task()
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+        )
+        with pytest.raises(RuntimeError, match="must be stopped"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+        mock_find.return_value = []
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        with pytest.raises(ValueError, match="not found"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+        from airflow.exceptions import TaskDeferred
+        from airflow.providers.amazon.aws.triggers.dms import DmsTaskModifyCompleteTrigger
+
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected):
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                deferrable=True,
+                wait_for_completion=True,
+            )
+            with pytest.raises(TaskDeferred) as exc_info:
+                op.execute(None)
+
+        assert isinstance(exc_info.value.trigger, DmsTaskModifyCompleteTrigger)
+        assert exc_info.value.method_name == "execute_complete"
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_with_all_params(self, mock_conn, mock_find):
+        from datetime import datetime
+
+        mock_find.return_value = self._stopped_task()
+        cdc_start = datetime(2024, 1, 1)
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type="full-load-and-cdc",
+                replication_task_settings={"TargetMetadata": {}},
+                cdc_start_time=cdc_start,
+                cdc_stop_position="2024-01-31:00:00:00",
+            )
+            op.execute(None)
+
+            mock_modify.assert_called_once_with(
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type="full-load-and-cdc",
+                replication_task_settings={"TargetMetadata": {}},
+                cdc_start_time=cdc_start,
+                cdc_start_position=None,
+                cdc_stop_position="2024-01-31:00:00:00",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_waits_if_modifying(self, mock_conn, mock_find):
+        stopped_task = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        # execute() sees 'modifying' → _wait_until_not_modifying polls once → 'stopped';
+        # modify is called; wait_for_completion=True → _wait_until_not_modifying polls once → 'stopped'.
+        mock_find.side_effect = [
+            self._modifying_task(),
+            stopped_task,
+            stopped_task,
+        ]
+        with mock.patch.object(DmsHook, "modify_replication_task", return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                wait_for_completion=True,
+                waiter_delay=0,
+            )
+            op.execute(None)
+
+            mock_modify.assert_called_once()
+
+    def test_execute_complete_success(self):

Review Comment:
   Added `TestDmsTaskModifyCompleteTrigger` in 
`tests/unit/amazon/aws/triggers/test_dms.py` covering serialization, success, 
task not found, timeout, unexpected state, and exception.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": "mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime
+
+        with pytest.raises(ValueError, match="Only one of"):
+            DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                cdc_start_time=datetime(2024, 1, 1),
+                cdc_start_position="mysql-bin.000001:4",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_already_stopped(self, mock_conn, mock_find):

Review Comment:
   Done, generalized to `test_modify_task_modifiable_states` parametrized over 
stopped, ready, and failed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to