This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b3382f638 Add null check for host in Amazon Redshift connection (#31567)
5b3382f638 is described below

commit 5b3382f63898e497d482870636ed156ce861afbc
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Tue May 30 23:48:02 2023 +0530

    Add null check for host in Amazon Redshift connection (#31567)
    
    * Add null check for the redshift connection if cluster_identifier in get_iam_token
    
    * Add null check for the redshift connection if cluster_identifier in get_iam_token
    
    * Add exception and test cases
    
    * Remove this check for conn.host as it's checked already
    
    * Update airflow/providers/amazon/aws/hooks/redshift_sql.py
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    * Change the conditional checks
    
    * refactor unit tests
    
    * reorder imports
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    Co-authored-by: Hussein Awala <[email protected]>
---
 airflow/providers/amazon/aws/hooks/redshift_sql.py |  8 ++-
 .../amazon/aws/hooks/test_redshift_sql.py          | 58 ++++++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 11c7dbce26..61832e67aa 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -24,6 +24,7 @@ from redshift_connector import Connection as RedshiftConnection
 from sqlalchemy import create_engine
 from sqlalchemy.engine.url import URL
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
@@ -104,7 +105,12 @@ class RedshiftSQLHook(DbApiHook):
         port = conn.port or 5439
         # Pull the custer-identifier from the beginning of the Redshift URL
         # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
-        cluster_identifier = conn.extra_dejson.get("cluster_identifier", conn.host.split(".")[0])
+        cluster_identifier = conn.extra_dejson.get("cluster_identifier")
+        if not cluster_identifier:
+            if conn.host:
+                cluster_identifier = conn.host.split(".", 1)[0]
+            else:
+                raise AirflowException("Please set cluster_identifier or host in redshift connection.")
         redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="redshift").conn
         # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
         cluster_creds = redshift_client.get_cluster_credentials(
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index 335c2f28f4..e3c0ea80e8 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -21,6 +21,7 @@ from unittest import mock
 
 import pytest
 
+from airflow import AirflowException
 from airflow.models import Connection
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.utils.types import NOTSET
@@ -137,3 +138,60 @@ class TestRedshiftSQLHookConn:
         ):
             self.db_hook.get_conn()
             mock_connect.assert_called_once_with(**expected_call_args)
+
+    @pytest.mark.parametrize(
+        "connection_host, connection_extra, expected_cluster_identifier, expected_exception_msg",
+        [
+            # test without a connection host and without a cluster_identifier in connection extra
+            (None, {"iam": True}, None, "Please set cluster_identifier or host in redshift connection."),
+            # test without a connection host but with a cluster_identifier in connection extra
+            (
+                None,
+                {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
+                "cluster_identifier_from_extra",
+                None,
+            ),
+            # test with a connection host and without a cluster_identifier in connection extra
+            ("cluster_identifier_from_host.x.y", {"iam": True}, "cluster_identifier_from_host", None),
+            # test with both connection host and cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.x.y",
+                {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
+                "cluster_identifier_from_extra",
+                None,
+            ),
+        ],
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+    def test_get_iam_token(
+        self,
+        mock_connect,
+        mock_aws_hook_conn,
+        connection_host,
+        connection_extra,
+        expected_cluster_identifier,
+        expected_exception_msg,
+    ):
+        self.connection.host = connection_host
+        self.connection.extra = json.dumps(connection_extra)
+
+        mock_db_user = f"IAM:{self.connection.login}"
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_cluster_credentials.return_value = {
+            "DbPassword": mock_db_pass,
+            "DbUser": mock_db_user,
+        }
+        if expected_exception_msg is not None:
+            with pytest.raises(AirflowException, match=expected_exception_msg):
+                self.db_hook.get_conn()
+        else:
+            self.db_hook.get_conn()
+            mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+                DbUser=LOGIN_USER,
+                DbName=LOGIN_SCHEMA,
+                ClusterIdentifier=expected_cluster_identifier,
+                AutoCreate=False,
+            )

Reply via email to