SameerMesiah97 commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3329216476


##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 
+class TestDmsModifyTaskOperator:
+    TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+    TABLE_MAPPINGS = {
+        "rules": [
+            {
+                "rule-type": "selection",
+                "rule-id": "1",
+                "rule-name": "1",
+                "object-locator": {"schema-name": "myschema", "table-name": 
"mytable"},
+                "rule-action": "include",
+            }
+        ]
+    }
+
+    def _stopped_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+    def _running_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+    def _modifying_task(self):
+        return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+    def test_init(self):
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+            table_mappings=self.TABLE_MAPPINGS,
+            migration_type="cdc",
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+        )
+        assert op.replication_task_arn == self.TASK_ARN
+        assert op.table_mappings == self.TABLE_MAPPINGS
+        assert op.migration_type == "cdc"
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+
+    def test_init_defaults(self):
+        op = DmsModifyTaskOperator(task_id="modify_task", 
replication_task_arn=self.TASK_ARN)
+        assert op.table_mappings is None
+        assert op.migration_type is None
+        assert op.replication_task_settings is None
+        assert op.cdc_start_time is None
+        assert op.cdc_start_position is None
+        assert op.cdc_stop_position is None
+        assert op.wait_for_completion is True
+        assert op.deferrable is False
+
+    def test_init_raises_if_both_cdc_start_params_provided(self):
+        from datetime import datetime
+
+        with pytest.raises(ValueError, match="Only one of"):
+            DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                cdc_start_time=datetime(2024, 1, 1),
+                cdc_start_position="mysql-bin.000001:4",
+            )
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_already_stopped(self, mock_conn, mock_find):
+        mock_find.return_value = self._stopped_task()
+        expected = {"ReplicationTaskArn": self.TASK_ARN}
+        with mock.patch.object(DmsHook, "modify_replication_task", 
return_value=expected) as mock_modify:
+            op = DmsModifyTaskOperator(
+                task_id="modify_task",
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                wait_for_completion=False,
+            )
+            result = op.execute(None)
+
+            mock_modify.assert_called_once_with(
+                replication_task_arn=self.TASK_ARN,
+                table_mappings=self.TABLE_MAPPINGS,
+                migration_type=None,
+                replication_task_settings=None,
+                cdc_start_time=None,
+                cdc_start_position=None,
+                cdc_stop_position=None,
+            )
+            assert result == expected
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+        mock_find.return_value = self._running_task()
+        op = DmsModifyTaskOperator(
+            task_id="modify_task",
+            replication_task_arn=self.TASK_ARN,
+        )
+        with pytest.raises(RuntimeError, match="must be stopped"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+        mock_find.return_value = []
+        op = DmsModifyTaskOperator(task_id="modify_task", 
replication_task_arn=self.TASK_ARN)
+        with pytest.raises(ValueError, match="not found"):
+            op.execute(None)
+
+    @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+    @mock.patch.object(DmsHook, "get_conn")
+    def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+        from airflow.exceptions import TaskDeferred

Review Comment:
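   Why re-import `TaskDeferred` here? It is already imported at the top of the file:
   
   `from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred`
   
   Please drop the local import and use the module-level one. Besides being redundant, the local import pulls `TaskDeferred` from `airflow.exceptions` directly, bypassing the `airflow.providers.common.compat.sdk` shim that the rest of this file uses for its imports.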



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to