This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 869f3a93a8 Remove ability to specify arbitrary hook params in AWS RDS
trigger (#32386)
869f3a93a8 is described below
commit 869f3a93a8873381a57382f8a0ab88879ca43f9a
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Jul 7 11:53:54 2023 -0700
Remove ability to specify arbitrary hook params in AWS RDS trigger (#32386)
---
airflow/providers/amazon/aws/operators/rds.py | 23 ++++++++++++++++-------
airflow/providers/amazon/aws/triggers/rds.py | 8 ++++----
tests/providers/amazon/aws/triggers/test_rds.py | 13 +++++++------
3 files changed, 27 insertions(+), 17 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/rds.py
b/airflow/providers/amazon/aws/operators/rds.py
index c58961db2e..2630098bc4 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -43,18 +43,27 @@ class RdsBaseOperator(BaseOperator):
ui_color = "#eeaa88"
ui_fgcolor = "#ffffff"
- def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params:
dict | None = None, **kwargs):
+ def __init__(
+ self,
+ *args,
+ aws_conn_id: str = "aws_conn_id",
+ region_name: str | None = None,
+ hook_params: dict | None = None,
+ **kwargs,
+ ):
if hook_params is not None:
warnings.warn(
"The parameter hook_params is deprecated and will be removed. "
- "If you were using it, please get in touch either on airflow
slack, "
- "or by opening a github issue on the project. "
+ "Note that it is also incompatible with deferrable mode. "
+ "You can use the region_name parameter to specify the region. "
+ "If you were using hook_params for other purposes, please get
in touch either on "
+ "airflow slack, or by opening a github issue on the project. "
"You can mention https://github.com/apache/airflow/pull/32352",
AirflowProviderDeprecationWarning,
stacklevel=3, # 2 is in the operator's init, 3 is in the user
code creating the operator
)
- self.hook_params = hook_params or {}
- self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params)
+ self.region_name = region_name
+ self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=region_name,
**(hook_params or {}))
super().__init__(*args, **kwargs)
self._await_interval = 60 # seconds
@@ -588,7 +597,7 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
- hook_params=self.hook_params,
+ region_name=self.region_name,
waiter_name="db_instance_available",
# ignoring type because create_db_instance is a dict
response=create_db_instance, # type: ignore[arg-type]
@@ -674,7 +683,7 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
- hook_params=self.hook_params,
+ region_name=self.region_name,
waiter_name="db_instance_deleted",
# ignoring type because delete_db_instance is a dict
response=delete_db_instance, # type: ignore[arg-type]
diff --git a/airflow/providers/amazon/aws/triggers/rds.py
b/airflow/providers/amazon/aws/triggers/rds.py
index 0897f764be..0551d67591 100644
--- a/airflow/providers/amazon/aws/triggers/rds.py
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -47,14 +47,14 @@ class RdsDbInstanceTrigger(BaseTrigger):
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str,
- hook_params: dict[str, Any],
+ region_name: str | None,
response: dict[str, Any],
):
self.db_instance_identifier = db_instance_identifier
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
- self.hook_params = hook_params
+ self.region_name = region_name
self.waiter_name = waiter_name
self.response = response
@@ -67,14 +67,14 @@ class RdsDbInstanceTrigger(BaseTrigger):
"waiter_delay": str(self.waiter_delay),
"waiter_max_attempts": str(self.waiter_max_attempts),
"aws_conn_id": self.aws_conn_id,
- "hook_params": self.hook_params,
+ "region_name": self.region_name,
"waiter_name": self.waiter_name,
"response": self.response,
},
)
async def run(self):
- self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
+ self.hook = RdsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
async with self.hook.async_conn as client:
waiter = client.get_waiter(self.waiter_name)
await async_wait(
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py
b/tests/providers/amazon/aws/triggers/test_rds.py
index 5ae64b83c3..9c518c8eee 100644
--- a/tests/providers/amazon/aws/triggers/test_rds.py
+++ b/tests/providers/amazon/aws/triggers/test_rds.py
@@ -31,6 +31,7 @@ TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
TEST_WAITER_DELAY = 10
TEST_WAITER_MAX_ATTEMPTS = 10
TEST_AWS_CONN_ID = "test-aws-id"
+TEST_REGION = "sa-east-1"
TEST_RESPONSE = {
"DBInstance": {
"DBInstanceIdentifier": "test-db-instance-identifier",
@@ -47,7 +48,7 @@ class TestRdsDbInstanceTrigger:
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
- hook_params={},
+ region_name=TEST_REGION,
response=TEST_RESPONSE,
)
class_path, args = rds_db_instance_trigger.serialize()
@@ -58,7 +59,7 @@ class TestRdsDbInstanceTrigger:
assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
assert args["aws_conn_id"] == TEST_AWS_CONN_ID
- assert args["hook_params"] == {}
+ assert args["region_name"] == TEST_REGION
assert args["response"] == TEST_RESPONSE
@pytest.mark.asyncio
@@ -75,7 +76,7 @@ class TestRdsDbInstanceTrigger:
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
- hook_params={},
+ region_name=TEST_REGION,
response=TEST_RESPONSE,
)
@@ -104,7 +105,7 @@ class TestRdsDbInstanceTrigger:
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
- hook_params={},
+ region_name=TEST_REGION,
response=TEST_RESPONSE,
)
@@ -135,7 +136,7 @@ class TestRdsDbInstanceTrigger:
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=2,
aws_conn_id=TEST_AWS_CONN_ID,
- hook_params={},
+ region_name=TEST_REGION,
response=TEST_RESPONSE,
)
@@ -173,7 +174,7 @@ class TestRdsDbInstanceTrigger:
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
- hook_params={},
+ region_name=TEST_REGION,
response=TEST_RESPONSE,
)