sunank200 commented on code in PR #31567:
URL: https://github.com/apache/airflow/pull/31567#discussion_r1210135927


##########
tests/providers/amazon/aws/hooks/test_redshift_sql.py:
##########
@@ -137,3 +137,130 @@ def test_get_conn_overrides_correctly(self, mock_connect, 
conn_params, conn_extr
         ):
             self.db_hook.get_conn()
             mock_connect.assert_called_once_with(**expected_call_args)
+
+
+class TestRedshiftSQLHookConnectionScenario:
+    def test_get_iam_token_without_both_cluster_identifier_and_host(self):
+        """
+        Tests if it raises exception when both cluster_identifier and host are 
not set in redshift connection.
+        """
+        self.connection = Connection(
+            conn_type="redshift",
+            login=LOGIN_USER,
+            password=LOGIN_PASSWORD,
+            port=LOGIN_PORT,
+            schema=LOGIN_SCHEMA,
+        )
+        self.connection.extra = json.dumps(
+            {
+                "iam": True,
+            }
+        )
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+        with pytest.raises(Exception, match="Please set cluster_identifier or 
host in redshift connection."):
+            self.db_hook.get_uri()
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    def test_get_conn_iam_with_cluster_identifier_without_host(
+        self, mock_connect, mock_aws_hook_conn, aws_conn_id
+    ):
+        self.connection = Connection(
+            conn_type="redshift",
+            login=LOGIN_USER,
+            password=LOGIN_PASSWORD,
+            port=LOGIN_PORT,
+            schema=LOGIN_SCHEMA,
+        )
+
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+        mock_conn_extra = {"iam": True, "profile": "default", 
"cluster_identifier": "my-test-cluster"}
+        if aws_conn_id is not NOTSET:
+            self.db_hook.aws_conn_id = aws_conn_id
+        self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_db_user = f"IAM:{self.connection.login}"
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_cluster_credentials.return_value = {
+            "DbPassword": mock_db_pass,
+            "DbUser": mock_db_user,
+        }
+
+        self.db_hook.get_conn()
+
+        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
+        mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+            DbUser=LOGIN_USER,
+            DbName=LOGIN_SCHEMA,
+            ClusterIdentifier="my-test-cluster",
+            AutoCreate=False,
+        )
+
+        mock_connect.assert_called_once_with(
+            user=mock_db_user,
+            password=mock_db_pass,
+            port=LOGIN_PORT,
+            cluster_identifier="my-test-cluster",
+            profile="default",
+            database=LOGIN_SCHEMA,
+            iam=True,
+        )
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    def test_get_conn_iam_without_cluster_identifier_with_host(
+        self, mock_connect, mock_aws_hook_conn, aws_conn_id
+    ):
+        self.connection = Connection(
+            conn_type="redshift",
+            login=LOGIN_USER,
+            password=LOGIN_PASSWORD,
+            host=LOGIN_HOST,
+            port=LOGIN_PORT,
+            schema=LOGIN_SCHEMA,
+        )
+
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+        mock_conn_extra = {"iam": True, "profile": "default"}
+        if aws_conn_id is not NOTSET:
+            self.db_hook.aws_conn_id = aws_conn_id
+        self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_db_user = f"IAM:{self.connection.login}"
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_cluster_credentials.return_value = {
+            "DbPassword": mock_db_pass,
+            "DbUser": mock_db_user,
+        }
+
+        self.db_hook.get_conn()
+
+        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
+        mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+            DbUser=LOGIN_USER,
+            DbName=LOGIN_SCHEMA,
+            ClusterIdentifier="host",
+            AutoCreate=False,
+        )
+
+        mock_connect.assert_called_once_with(
+            user=mock_db_user,
+            password=mock_db_pass,
+            port=LOGIN_PORT,
+            host=LOGIN_HOST,
+            profile="default",
+            database=LOGIN_SCHEMA,
+            iam=True,
+        )

Review Comment:
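   @hussein-awala All four scenarios are already covered. The two tests in the diff above cover "cluster identifier without host" and "host without cluster identifier"; the remaining two are:
   - without cluster identifier and without host: `TestRedshiftSQLHookConnectionScenario::test_get_iam_token_without_both_cluster_identifier_and_host`
   - with cluster identifier and with host: `TestRedshiftSQLHookConn::test_get_conn_iam`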
   
   These are the tests you mentioned as missing. `TestRedshiftSQLHookConn::test_get_conn_iam` already existed before this PR, while `TestRedshiftSQLHookConnectionScenario::test_get_iam_token_without_both_cluster_identifier_and_host` is added in this PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to