This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f6962a929b Support IAM authentication for Redshift serverless (#35897)
f6962a929b is described below

commit f6962a929b839215613d1b6f99f43511759c1e5b
Author: Hussein Awala <[email protected]>
AuthorDate: Tue Nov 28 19:31:24 2023 +0200

    Support IAM authentication for Redshift serverless (#35897)
    
    * Support IAM authentication for Redshift serverless
    
    * Comments from review
---
 airflow/providers/amazon/aws/hooks/redshift_sql.py | 53 +++++++++++++++-------
 .../connections/redshift.rst                       | 20 ++++++++
 .../amazon/aws/hooks/test_redshift_sql.py          | 45 ++++++++++++++++++
 3 files changed, 102 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py 
b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 0b1c26fff5..87bc6be478 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -102,22 +102,43 @@ class RedshiftSQLHook(DbApiHook):
         Port is required. If none is provided, default is used for each 
service.
         """
         port = conn.port or 5439
-        # Pull the custer-identifier from the beginning of the Redshift URL
-        # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns 
my-cluster
-        cluster_identifier = conn.extra_dejson.get("cluster_identifier")
-        if not cluster_identifier:
-            if conn.host:
-                cluster_identifier = conn.host.split(".", 1)[0]
-            else:
-                raise AirflowException("Please set cluster_identifier or host 
in redshift connection.")
-        redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, 
client_type="redshift").conn
-        # 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
-        cluster_creds = redshift_client.get_cluster_credentials(
-            DbUser=conn.login,
-            DbName=conn.schema,
-            ClusterIdentifier=cluster_identifier,
-            AutoCreate=False,
-        )
+        is_serverless = conn.extra_dejson.get("is_serverless", False)
+        if is_serverless:
+            serverless_work_group = 
conn.extra_dejson.get("serverless_work_group")
+            if not serverless_work_group:
+                raise AirflowException(
+                    "Please set serverless_work_group in redshift connection 
to use IAM with"
+                    " Redshift Serverless."
+                )
+            serverless_token_duration_seconds = conn.extra_dejson.get(
+                "serverless_token_duration_seconds", 3600
+            )
+            redshift_client = AwsBaseHook(
+                aws_conn_id=self.aws_conn_id, client_type="redshift-serverless"
+            ).conn
+            # 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#get-credentials
+            cluster_creds = redshift_client.get_cluster_credentials(
+                DbName=conn.schema,
+                workgroupName=serverless_work_group,
+                durationSeconds=serverless_token_duration_seconds,
+            )
+        else:
+            # Pull the custer-identifier from the beginning of the Redshift URL
+            # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com 
returns my-cluster
+            cluster_identifier = conn.extra_dejson.get("cluster_identifier")
+            if not cluster_identifier:
+                if conn.host:
+                    cluster_identifier = conn.host.split(".", 1)[0]
+                else:
+                    raise AirflowException("Please set cluster_identifier or 
host in redshift connection.")
+            redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, 
client_type="redshift").conn
+            # 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
+            cluster_creds = redshift_client.get_cluster_credentials(
+                DbUser=conn.login,
+                DbName=conn.schema,
+                ClusterIdentifier=cluster_identifier,
+                AutoCreate=False,
+            )
         token = cluster_creds["DbPassword"]
         login = cluster_creds["DbUser"]
         return login, token, port
diff --git a/docs/apache-airflow-providers-amazon/connections/redshift.rst 
b/docs/apache-airflow-providers-amazon/connections/redshift.rst
index ccd48805c4..24817e8d38 100644
--- a/docs/apache-airflow-providers-amazon/connections/redshift.rst
+++ b/docs/apache-airflow-providers-amazon/connections/redshift.rst
@@ -97,3 +97,23 @@ inferred by the **Host** field in Connection.
       "database": "dev",
       "profile": "default"
     }
+
+If you want to use IAM with Amazon Redshift Serverless, you need to set 
**is_serverless** to true and provide
+**serverless_work_group**. You can also set 
**serverless_token_duration_seconds** to specify the number of seconds
until the returned temporary password expires; the minimum is 900 seconds, the 
maximum is 3600 seconds, and the default
+is 3600 seconds.
+
+* **Extra**:
+
+.. code-block:: json
+
+    {
+      "iam": true,
+      "is_serverless": true,
+      "serverless_work_group": "default",
+      "serverless_token_duration_seconds": 3600,
+      "port": 5439,
+      "region": "us-east-1",
+      "database": "dev",
+      "profile": "default"
+    }
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py 
b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index 5316372173..cab6053eff 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -120,6 +120,51 @@ class TestRedshiftSQLHookConn:
             iam=True,
         )
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    def test_get_conn_iam_serverless_redshift(self, mock_connect, 
mock_aws_hook_conn, aws_conn_id):
+        mock_work_group = "my-test-workgroup"
+        mock_conn_extra = {
+            "iam": True,
+            "is_serverless": True,
+            "profile": "default",
+            "serverless_work_group": mock_work_group,
+        }
+        if aws_conn_id is not NOTSET:
+            self.db_hook.aws_conn_id = aws_conn_id
+        self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_db_user = f"IAM:{self.connection.login}"
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_credentials.return_value = {
+            "dbPassword": mock_db_pass,
+            "dbUser": mock_db_user,
+        }
+
+        self.db_hook.get_conn()
+
+        # Check boto3 'redshift-serverless' client method `get_credentials` 
call args
+        mock_aws_hook_conn.get_credentials.assert_called_once_with(
+            dbName=LOGIN_SCHEMA,
+            workgroupName=mock_work_group,
+            durationSeconds=3600,
+        )
+
+        mock_connect.assert_called_once_with(
+            user=mock_db_user,
+            password=mock_db_pass,
+            host=LOGIN_HOST,
+            port=LOGIN_PORT,
+            serverless_work_group=mock_work_group,
+            profile="default",
+            database=LOGIN_SCHEMA,
+            iam=True,
+            is_serverless=True,
+        )
+
     @pytest.mark.parametrize(
         "conn_params, conn_extra, expected_call_args",
         [

Reply via email to