hussein-awala commented on code in PR #31567:
URL: https://github.com/apache/airflow/pull/31567#discussion_r1209189059


##########
airflow/providers/amazon/aws/hooks/redshift_sql.py:
##########
@@ -104,7 +104,11 @@ def get_iam_token(self, conn: Connection) -> tuple[str, 
str, int]:
         port = conn.port or 5439
         # Pull the custer-identifier from the beginning of the Redshift URL
         # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns 
my-cluster
-        cluster_identifier = conn.extra_dejson.get("cluster_identifier", 
conn.host.split(".")[0])
+        cluster_identifier = conn.extra_dejson.get("cluster_identifier")
+        if not cluster_identifier and conn.host:
+            cluster_identifier = conn.host.split(".", 1)[0]
+        if not cluster_identifier:
+            raise Exception("Please set cluster_identifier or host in redshift 
connection.")

Review Comment:
   Nit
   ```suggestion
           if not cluster_identifier:
               if conn.host:
                   cluster_identifier = conn.host.split(".", 1)[0]
               else:
                   raise Exception("Please set cluster_identifier or host in 
redshift connection.")
   ```



##########
tests/providers/amazon/aws/hooks/test_redshift_sql.py:
##########
@@ -137,3 +137,130 @@ def test_get_conn_overrides_correctly(self, mock_connect, 
conn_params, conn_extr
         ):
             self.db_hook.get_conn()
             mock_connect.assert_called_once_with(**expected_call_args)
+
+
+class TestRedshiftSQLHookConnectionScenario:
+    def test_get_iam_token_without_both_cluster_identifier_and_host(self):
+        """
+        Tests if it raises exception when both cluster_identifier and host are 
not set in redshift connection.
+        """
+        self.connection = Connection(
+            conn_type="redshift",
+            login=LOGIN_USER,
+            password=LOGIN_PASSWORD,
+            port=LOGIN_PORT,
+            schema=LOGIN_SCHEMA,
+        )
+        self.connection.extra = json.dumps(
+            {
+                "iam": True,
+            }
+        )
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+        with pytest.raises(Exception, match="Please set cluster_identifier or 
host in redshift connection."):
+            self.db_hook.get_uri()
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    def test_get_conn_iam_with_cluster_identifier_without_host(
+        self, mock_connect, mock_aws_hook_conn, aws_conn_id
+    ):
+        self.connection = Connection(
+            conn_type="redshift",
+            login=LOGIN_USER,
+            password=LOGIN_PASSWORD,
+            port=LOGIN_PORT,
+            schema=LOGIN_SCHEMA,
+        )
+
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+        mock_conn_extra = {"iam": True, "profile": "default", 
"cluster_identifier": "my-test-cluster"}
+        if aws_conn_id is not NOTSET:
+            self.db_hook.aws_conn_id = aws_conn_id
+        self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_db_user = f"IAM:{self.connection.login}"
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_cluster_credentials.return_value = {
+            "DbPassword": mock_db_pass,
+            "DbUser": mock_db_user,
+        }
+
+        self.db_hook.get_conn()
+
+        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
+        mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+            DbUser=LOGIN_USER,
+            DbName=LOGIN_SCHEMA,
+            ClusterIdentifier="my-test-cluster",
+            AutoCreate=False,
+        )
+
+        mock_connect.assert_called_once_with(
+            user=mock_db_user,
+            password=mock_db_pass,
+            port=LOGIN_PORT,
+            cluster_identifier="my-test-cluster",
+            profile="default",
+            database=LOGIN_SCHEMA,
+            iam=True,
+        )
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    def test_get_conn_iam_without_cluster_identifier_with_host(
+        self, mock_connect, mock_aws_hook_conn, aws_conn_id
+    ):
+        self.connection = Connection(
+            conn_type="redshift",
+            login=LOGIN_USER,
+            password=LOGIN_PASSWORD,
+            host=LOGIN_HOST,
+            port=LOGIN_PORT,
+            schema=LOGIN_SCHEMA,
+        )
+
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+        mock_conn_extra = {"iam": True, "profile": "default"}
+        if aws_conn_id is not NOTSET:
+            self.db_hook.aws_conn_id = aws_conn_id
+        self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_db_user = f"IAM:{self.connection.login}"
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_cluster_credentials.return_value = {
+            "DbPassword": mock_db_pass,
+            "DbUser": mock_db_user,
+        }
+
+        self.db_hook.get_conn()
+
+        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
+        mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+            DbUser=LOGIN_USER,
+            DbName=LOGIN_SCHEMA,
+            ClusterIdentifier="host",
+            AutoCreate=False,
+        )
+
+        mock_connect.assert_called_once_with(
+            user=mock_db_user,
+            password=mock_db_pass,
+            port=LOGIN_PORT,
+            host=LOGIN_HOST,
+            profile="default",
+            database=LOGIN_SCHEMA,
+            iam=True,
+        )

Review Comment:
   These two tests can be refactored into a single parametrized test.
   
   We also need to cover all four cases:
   - with cluster identifier without host (already exist)
   - without cluster identifier with host (already exist)
   - with cluster identifier with host (missing)
   - without cluster identifier without host (missing)
   
   The main goal of these tests is to prevent this feature from breaking when someone adds a new feature or updates this code in the future.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to