AlejandroMorgante commented on code in PR #67524:
URL: https://github.com/apache/airflow/pull/67524#discussion_r3329271854
##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
assert op.hook.aws_conn_id == DEFAULT_CONN
+class TestDmsModifyTaskOperator:
+ TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+ TABLE_MAPPINGS = {
+ "rules": [
+ {
+ "rule-type": "selection",
+ "rule-id": "1",
+ "rule-name": "1",
+ "object-locator": {"schema-name": "myschema", "table-name":
"mytable"},
+ "rule-action": "include",
+ }
+ ]
+ }
+
+ def _stopped_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+ def _running_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+ def _modifying_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+ def test_init(self):
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type="cdc",
+ aws_conn_id="fake-conn-id",
+ region_name="us-east-1",
+ )
+ assert op.replication_task_arn == self.TASK_ARN
+ assert op.table_mappings == self.TABLE_MAPPINGS
+ assert op.migration_type == "cdc"
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+ assert op.hook.aws_conn_id == "fake-conn-id"
+ assert op.hook._region_name == "us-east-1"
+
+ def test_init_defaults(self):
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ assert op.table_mappings is None
+ assert op.migration_type is None
+ assert op.replication_task_settings is None
+ assert op.cdc_start_time is None
+ assert op.cdc_start_position is None
+ assert op.cdc_stop_position is None
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+
+ def test_init_raises_if_both_cdc_start_params_provided(self):
+ from datetime import datetime
+
+ with pytest.raises(ValueError, match="Only one of"):
+ DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ cdc_start_time=datetime(2024, 1, 1),
+ cdc_start_position="mysql-bin.000001:4",
+ )
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_already_stopped(self, mock_conn, mock_find):
+ mock_find.return_value = self._stopped_task()
+ expected = {"ReplicationTaskArn": self.TASK_ARN}
+ with mock.patch.object(DmsHook, "modify_replication_task",
return_value=expected) as mock_modify:
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ wait_for_completion=False,
+ )
+ result = op.execute(None)
+
+ mock_modify.assert_called_once_with(
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type=None,
+ replication_task_settings=None,
+ cdc_start_time=None,
+ cdc_start_position=None,
+ cdc_stop_position=None,
+ )
+ assert result == expected
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+ mock_find.return_value = self._running_task()
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ )
+ with pytest.raises(RuntimeError, match="must be stopped"):
+ op.execute(None)
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+ mock_find.return_value = []
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ with pytest.raises(ValueError, match="not found"):
+ op.execute(None)
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+ from airflow.exceptions import TaskDeferred
Review Comment:
Removed the local import; the test now uses the `TaskDeferred` already
imported at the top of the file.
##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
assert op.hook.aws_conn_id == DEFAULT_CONN
+class TestDmsModifyTaskOperator:
+ TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+ TABLE_MAPPINGS = {
+ "rules": [
+ {
+ "rule-type": "selection",
+ "rule-id": "1",
+ "rule-name": "1",
+ "object-locator": {"schema-name": "myschema", "table-name":
"mytable"},
+ "rule-action": "include",
+ }
+ ]
+ }
+
+ def _stopped_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+ def _running_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+ def _modifying_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+ def test_init(self):
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type="cdc",
+ aws_conn_id="fake-conn-id",
+ region_name="us-east-1",
+ )
+ assert op.replication_task_arn == self.TASK_ARN
+ assert op.table_mappings == self.TABLE_MAPPINGS
+ assert op.migration_type == "cdc"
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+ assert op.hook.aws_conn_id == "fake-conn-id"
+ assert op.hook._region_name == "us-east-1"
+
+ def test_init_defaults(self):
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ assert op.table_mappings is None
+ assert op.migration_type is None
+ assert op.replication_task_settings is None
+ assert op.cdc_start_time is None
+ assert op.cdc_start_position is None
+ assert op.cdc_stop_position is None
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+
+ def test_init_raises_if_both_cdc_start_params_provided(self):
+ from datetime import datetime
+
+ with pytest.raises(ValueError, match="Only one of"):
+ DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ cdc_start_time=datetime(2024, 1, 1),
+ cdc_start_position="mysql-bin.000001:4",
+ )
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_already_stopped(self, mock_conn, mock_find):
+ mock_find.return_value = self._stopped_task()
+ expected = {"ReplicationTaskArn": self.TASK_ARN}
+ with mock.patch.object(DmsHook, "modify_replication_task",
return_value=expected) as mock_modify:
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ wait_for_completion=False,
+ )
+ result = op.execute(None)
+
+ mock_modify.assert_called_once_with(
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type=None,
+ replication_task_settings=None,
+ cdc_start_time=None,
+ cdc_start_position=None,
+ cdc_stop_position=None,
+ )
+ assert result == expected
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+ mock_find.return_value = self._running_task()
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ )
+ with pytest.raises(RuntimeError, match="must be stopped"):
+ op.execute(None)
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+ mock_find.return_value = []
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ with pytest.raises(ValueError, match="not found"):
+ op.execute(None)
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+ from airflow.exceptions import TaskDeferred
+ from airflow.providers.amazon.aws.triggers.dms import
DmsTaskModifyCompleteTrigger
Review Comment:
Moved `DmsTaskModifyCompleteTrigger` to the top-level imports block
alongside the other trigger imports.
##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
assert op.hook.aws_conn_id == DEFAULT_CONN
+class TestDmsModifyTaskOperator:
+ TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+ TABLE_MAPPINGS = {
+ "rules": [
+ {
+ "rule-type": "selection",
+ "rule-id": "1",
+ "rule-name": "1",
+ "object-locator": {"schema-name": "myschema", "table-name":
"mytable"},
+ "rule-action": "include",
+ }
+ ]
+ }
+
+ def _stopped_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+ def _running_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+ def _modifying_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+ def test_init(self):
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type="cdc",
+ aws_conn_id="fake-conn-id",
+ region_name="us-east-1",
+ )
+ assert op.replication_task_arn == self.TASK_ARN
+ assert op.table_mappings == self.TABLE_MAPPINGS
+ assert op.migration_type == "cdc"
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+ assert op.hook.aws_conn_id == "fake-conn-id"
+ assert op.hook._region_name == "us-east-1"
+
+ def test_init_defaults(self):
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ assert op.table_mappings is None
+ assert op.migration_type is None
+ assert op.replication_task_settings is None
+ assert op.cdc_start_time is None
+ assert op.cdc_start_position is None
+ assert op.cdc_stop_position is None
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+
+ def test_init_raises_if_both_cdc_start_params_provided(self):
+ from datetime import datetime
+
+ with pytest.raises(ValueError, match="Only one of"):
+ DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ cdc_start_time=datetime(2024, 1, 1),
+ cdc_start_position="mysql-bin.000001:4",
+ )
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_already_stopped(self, mock_conn, mock_find):
+ mock_find.return_value = self._stopped_task()
+ expected = {"ReplicationTaskArn": self.TASK_ARN}
+ with mock.patch.object(DmsHook, "modify_replication_task",
return_value=expected) as mock_modify:
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ wait_for_completion=False,
+ )
+ result = op.execute(None)
+
+ mock_modify.assert_called_once_with(
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type=None,
+ replication_task_settings=None,
+ cdc_start_time=None,
+ cdc_start_position=None,
+ cdc_stop_position=None,
+ )
+ assert result == expected
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_raises_if_running(self, mock_conn, mock_find):
+ mock_find.return_value = self._running_task()
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ )
+ with pytest.raises(RuntimeError, match="must be stopped"):
+ op.execute(None)
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_raises_if_not_found(self, mock_conn, mock_find):
+ mock_find.return_value = []
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ with pytest.raises(ValueError, match="not found"):
+ op.execute(None)
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_defers_for_completion(self, mock_conn, mock_find):
+ from airflow.exceptions import TaskDeferred
+ from airflow.providers.amazon.aws.triggers.dms import
DmsTaskModifyCompleteTrigger
+
+ mock_find.return_value = self._stopped_task()
+ expected = {"ReplicationTaskArn": self.TASK_ARN}
+ with mock.patch.object(DmsHook, "modify_replication_task",
return_value=expected):
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ deferrable=True,
+ wait_for_completion=True,
+ )
+ with pytest.raises(TaskDeferred) as exc_info:
+ op.execute(None)
+
+ assert isinstance(exc_info.value.trigger, DmsTaskModifyCompleteTrigger)
+ assert exc_info.value.method_name == "execute_complete"
+
+ @mock.patch.object(DmsHook, "find_replication_tasks_by_arn")
+ @mock.patch.object(DmsHook, "get_conn")
+ def test_modify_task_with_all_params(self, mock_conn, mock_find):
+ from datetime import datetime
Review Comment:
Moved the `datetime` import to the top of the file.
##########
providers/amazon/tests/unit/amazon/aws/operators/test_dms.py:
##########
@@ -180,6 +181,214 @@ def test_default_conn_passed_to_hook(self):
assert op.hook.aws_conn_id == DEFAULT_CONN
+class TestDmsModifyTaskOperator:
+ TASK_ARN = "arn:aws:dms:us-east-1:123456789012:task:EXAMPLE"
+ TABLE_MAPPINGS = {
+ "rules": [
+ {
+ "rule-type": "selection",
+ "rule-id": "1",
+ "rule-name": "1",
+ "object-locator": {"schema-name": "myschema", "table-name":
"mytable"},
+ "rule-action": "include",
+ }
+ ]
+ }
+
+ def _stopped_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "stopped"}]
+
+ def _running_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "running"}]
+
+ def _modifying_task(self):
+ return [{"ReplicationTaskArn": self.TASK_ARN, "Status": "modifying"}]
+
+ def test_init(self):
+ op = DmsModifyTaskOperator(
+ task_id="modify_task",
+ replication_task_arn=self.TASK_ARN,
+ table_mappings=self.TABLE_MAPPINGS,
+ migration_type="cdc",
+ aws_conn_id="fake-conn-id",
+ region_name="us-east-1",
+ )
+ assert op.replication_task_arn == self.TASK_ARN
+ assert op.table_mappings == self.TABLE_MAPPINGS
+ assert op.migration_type == "cdc"
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+ assert op.hook.aws_conn_id == "fake-conn-id"
+ assert op.hook._region_name == "us-east-1"
+
+ def test_init_defaults(self):
+ op = DmsModifyTaskOperator(task_id="modify_task",
replication_task_arn=self.TASK_ARN)
+ assert op.table_mappings is None
+ assert op.migration_type is None
+ assert op.replication_task_settings is None
+ assert op.cdc_start_time is None
+ assert op.cdc_start_position is None
+ assert op.cdc_stop_position is None
+ assert op.wait_for_completion is True
+ assert op.deferrable is False
+
+ def test_init_raises_if_both_cdc_start_params_provided(self):
+ from datetime import datetime
Review Comment:
Moved the `datetime` import to the top of the file.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py:
##########
@@ -118,6 +120,178 @@ def execute(self, context: Context):
return task_arn
+class DmsModifyTaskOperator(AwsBaseOperator[DmsHook]):
+ """
+ Modifies an existing AWS DMS replication task.
+
+ The task must already be stopped before modification. Use
:class:`DmsStopTaskOperator`
+ upstream in the Dag to stop it, and :class:`DmsStartTaskOperator`
downstream to restart
+ it afterwards if needed.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:DmsModifyTaskOperator`
+
+ :param replication_task_arn: Replication task ARN
+ :param table_mappings: New table mappings. If not provided, existing
mappings are kept.
+ :param migration_type: Migration type
('full-load'|'cdc'|'full-load-and-cdc').
+ If not provided, existing type is kept.
+ :param replication_task_settings: Task settings dict. If not provided,
existing settings are kept.
+ :param cdc_start_time: Start time for CDC.
+ :param cdc_start_position: Indicates when to start CDC (checkpoint or
LSN/SCN format).
+ Mutually exclusive with cdc_start_time.
+ :param cdc_stop_position: Indicates when to stop CDC.
+ :param wait_for_completion: If True, wait for the modification to finish
before returning.
+ In deferrable mode the operator defers rather than blocking. Defaults
to True.
+ :param deferrable: Run the operator in deferrable mode. Defaults to False.
+ :param waiter_delay: Seconds between waiter polls (default: 30).
+ :param waiter_max_attempts: Maximum waiter poll attempts (default: 60).
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ MODIFIABLE_STATES = frozenset({DmsTaskState.STOPPED, DmsTaskState.READY,
DmsTaskState.FAILED})
+
+ aws_hook_class = DmsHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "replication_task_arn",
+ "table_mappings",
+ "migration_type",
+ "replication_task_settings",
+ "cdc_start_time",
+ "cdc_start_position",
+ "cdc_stop_position",
+ )
+ template_fields_renderers: ClassVar[dict] = {
+ "table_mappings": "json",
+ "replication_task_settings": "json",
+ }
+
+ def __init__(
+ self,
+ *,
+ replication_task_arn: str,
+ table_mappings: dict | None = None,
+ migration_type: str | None = None,
+ replication_task_settings: dict | None = None,
+ cdc_start_time: datetime | None = None,
+ cdc_start_position: str | None = None,
+ cdc_stop_position: str | None = None,
+ wait_for_completion: bool = True,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if cdc_start_time and cdc_start_position:
+ raise ValueError("Only one of cdc_start_time or cdc_start_position
can be provided.")
+ self.replication_task_arn = replication_task_arn
+ self.table_mappings = table_mappings
+ self.migration_type = migration_type
+ self.replication_task_settings = replication_task_settings
+ self.cdc_start_time = cdc_start_time
+ self.cdc_start_position = cdc_start_position
+ self.cdc_stop_position = cdc_stop_position
+ self.wait_for_completion = wait_for_completion
+ self.deferrable = deferrable
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+
+ def execute(self, context: Context) -> dict:
+ tasks = self.hook.find_replication_tasks_by_arn(
+ replication_task_arn=self.replication_task_arn,
without_settings=True
+ )
+ if not tasks:
+ raise ValueError(f"Replication task {self.replication_task_arn}
not found.")
+
+ current_status = tasks[0].get("Status", "").lower()
+ self.log.info(
+ "Current status of replication task(%s) is '%s'.",
self.replication_task_arn, current_status
+ )
+
+ if current_status == DmsTaskState.MODIFYING:
+ # boto3 stopped/ready waiters treat 'modifying' as a terminal
failure — use poll loop.
+ self._wait_until_not_modifying()
+ elif current_status not in self.MODIFIABLE_STATES:
+ raise RuntimeError(
+ f"Replication task {self.replication_task_arn} is in state
'{current_status}' "
+ f"and must be stopped before modification. "
+ f"Use DmsStopTaskOperator to stop it first."
+ )
+
+ result = self.hook.modify_replication_task(
+ replication_task_arn=self.replication_task_arn,
+ table_mappings=self.table_mappings,
+ migration_type=self.migration_type,
+ replication_task_settings=self.replication_task_settings,
+ cdc_start_time=self.cdc_start_time,
+ cdc_start_position=self.cdc_start_position,
+ cdc_stop_position=self.cdc_stop_position,
+ )
+ self.log.info("DMS replication task(%s) has been modified.",
self.replication_task_arn)
+
+ if self.wait_for_completion:
+ if self.deferrable:
+ self.defer(
+ trigger=DmsTaskModifyCompleteTrigger(
+ replication_task_arn=self.replication_task_arn,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
Review Comment:
Added an `else` branch so that `_wait_until_not_modifying()` runs only on
the non-deferrable path.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]