This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 869f3a93a8 Remove ability to specify arbitrary hook params in AWS RDS 
trigger (#32386)
869f3a93a8 is described below

commit 869f3a93a8873381a57382f8a0ab88879ca43f9a
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Jul 7 11:53:54 2023 -0700

    Remove ability to specify arbitrary hook params in AWS RDS trigger (#32386)
---
 airflow/providers/amazon/aws/operators/rds.py   | 23 ++++++++++++++++-------
 airflow/providers/amazon/aws/triggers/rds.py    |  8 ++++----
 tests/providers/amazon/aws/triggers/test_rds.py | 13 +++++++------
 3 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index c58961db2e..2630098bc4 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -43,18 +43,27 @@ class RdsBaseOperator(BaseOperator):
     ui_color = "#eeaa88"
     ui_fgcolor = "#ffffff"
 
-    def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs):
+    def __init__(
+        self,
+        *args,
+        aws_conn_id: str = "aws_conn_id",
+        region_name: str | None = None,
+        hook_params: dict | None = None,
+        **kwargs,
+    ):
         if hook_params is not None:
             warnings.warn(
                 "The parameter hook_params is deprecated and will be removed. "
-                "If you were using it, please get in touch either on airflow slack, "
-                "or by opening a github issue on the project. "
+                "Note that it is also incompatible with deferrable mode. "
+                "You can use the region_name parameter to specify the region. "
+                "If you were using hook_params for other purposes, please get in touch either on "
+                "airflow slack, or by opening a github issue on the project. "
                 "You can mention https://github.com/apache/airflow/pull/32352",
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,  # 2 is in the operator's init, 3 is in the user code creating the operator
             )
-        self.hook_params = hook_params or {}
-        self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params)
+        self.region_name = region_name
+        self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=region_name, **(hook_params or {}))
         super().__init__(*args, **kwargs)
 
         self._await_interval = 60  # seconds
@@ -588,7 +597,7 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    hook_params=self.hook_params,
+                    region_name=self.region_name,
                     waiter_name="db_instance_available",
                     # ignoring type because create_db_instance is a dict
                     response=create_db_instance,  # type: ignore[arg-type]
@@ -674,7 +683,7 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    hook_params=self.hook_params,
+                    region_name=self.region_name,
                     waiter_name="db_instance_deleted",
                     # ignoring type because delete_db_instance is a dict
                     response=delete_db_instance,  # type: ignore[arg-type]
diff --git a/airflow/providers/amazon/aws/triggers/rds.py b/airflow/providers/amazon/aws/triggers/rds.py
index 0897f764be..0551d67591 100644
--- a/airflow/providers/amazon/aws/triggers/rds.py
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -47,14 +47,14 @@ class RdsDbInstanceTrigger(BaseTrigger):
         waiter_delay: int,
         waiter_max_attempts: int,
         aws_conn_id: str,
-        hook_params: dict[str, Any],
+        region_name: str | None,
         response: dict[str, Any],
     ):
         self.db_instance_identifier = db_instance_identifier
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.aws_conn_id = aws_conn_id
-        self.hook_params = hook_params
+        self.region_name = region_name
         self.waiter_name = waiter_name
         self.response = response
 
@@ -67,14 +67,14 @@ class RdsDbInstanceTrigger(BaseTrigger):
                 "waiter_delay": str(self.waiter_delay),
                 "waiter_max_attempts": str(self.waiter_max_attempts),
                 "aws_conn_id": self.aws_conn_id,
-                "hook_params": self.hook_params,
+                "region_name": self.region_name,
                 "waiter_name": self.waiter_name,
                 "response": self.response,
             },
         )
 
     async def run(self):
-        self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
+        self.hook = RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
         async with self.hook.async_conn as client:
             waiter = client.get_waiter(self.waiter_name)
             await async_wait(
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py b/tests/providers/amazon/aws/triggers/test_rds.py
index 5ae64b83c3..9c518c8eee 100644
--- a/tests/providers/amazon/aws/triggers/test_rds.py
+++ b/tests/providers/amazon/aws/triggers/test_rds.py
@@ -31,6 +31,7 @@ TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
 TEST_WAITER_DELAY = 10
 TEST_WAITER_MAX_ATTEMPTS = 10
 TEST_AWS_CONN_ID = "test-aws-id"
+TEST_REGION = "sa-east-1"
 TEST_RESPONSE = {
     "DBInstance": {
         "DBInstanceIdentifier": "test-db-instance-identifier",
@@ -47,7 +48,7 @@ class TestRdsDbInstanceTrigger:
             waiter_delay=TEST_WAITER_DELAY,
             waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
             aws_conn_id=TEST_AWS_CONN_ID,
-            hook_params={},
+            region_name=TEST_REGION,
             response=TEST_RESPONSE,
         )
         class_path, args = rds_db_instance_trigger.serialize()
@@ -58,7 +59,7 @@ class TestRdsDbInstanceTrigger:
         assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
         assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
         assert args["aws_conn_id"] == TEST_AWS_CONN_ID
-        assert args["hook_params"] == {}
+        assert args["region_name"] == TEST_REGION
         assert args["response"] == TEST_RESPONSE
 
     @pytest.mark.asyncio
@@ -75,7 +76,7 @@ class TestRdsDbInstanceTrigger:
             waiter_delay=TEST_WAITER_DELAY,
             waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
             aws_conn_id=TEST_AWS_CONN_ID,
-            hook_params={},
+            region_name=TEST_REGION,
             response=TEST_RESPONSE,
         )
 
@@ -104,7 +105,7 @@ class TestRdsDbInstanceTrigger:
             waiter_delay=TEST_WAITER_DELAY,
             waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
             aws_conn_id=TEST_AWS_CONN_ID,
-            hook_params={},
+            region_name=TEST_REGION,
             response=TEST_RESPONSE,
         )
 
@@ -135,7 +136,7 @@ class TestRdsDbInstanceTrigger:
             waiter_delay=TEST_WAITER_DELAY,
             waiter_max_attempts=2,
             aws_conn_id=TEST_AWS_CONN_ID,
-            hook_params={},
+            region_name=TEST_REGION,
             response=TEST_RESPONSE,
         )
 
@@ -173,7 +174,7 @@ class TestRdsDbInstanceTrigger:
             waiter_delay=TEST_WAITER_DELAY,
             waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
             aws_conn_id=TEST_AWS_CONN_ID,
-            hook_params={},
+            region_name=TEST_REGION,
             response=TEST_RESPONSE,
         )
 

Reply via email to