sunank200 commented on code in PR #31567:
URL: https://github.com/apache/airflow/pull/31567#discussion_r1210142541
##########
tests/providers/amazon/aws/hooks/test_redshift_sql.py:
##########
@@ -137,3 +137,130 @@ def test_get_conn_overrides_correctly(self, mock_connect,
conn_params, conn_extr
):
self.db_hook.get_conn()
mock_connect.assert_called_once_with(**expected_call_args)
+
+
+class TestRedshiftSQLHookConnectionScenario:
+ def test_get_iam_token_without_both_cluster_identifier_and_host(self):
+ """
+ Tests if it raises exception when both cluster_identifier and host are
not set in redshift connection.
+ """
+ self.connection = Connection(
+ conn_type="redshift",
+ login=LOGIN_USER,
+ password=LOGIN_PASSWORD,
+ port=LOGIN_PORT,
+ schema=LOGIN_SCHEMA,
+ )
+ self.connection.extra = json.dumps(
+ {
+ "iam": True,
+ }
+ )
+ self.db_hook = RedshiftSQLHook()
+ self.db_hook.get_connection = mock.Mock()
+ self.db_hook.get_connection.return_value = self.connection
+ with pytest.raises(Exception, match="Please set cluster_identifier or
host in redshift connection."):
+ self.db_hook.get_uri()
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+ @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+ def test_get_conn_iam_with_cluster_identifier_without_host(
+ self, mock_connect, mock_aws_hook_conn, aws_conn_id
+ ):
+ self.connection = Connection(
+ conn_type="redshift",
+ login=LOGIN_USER,
+ password=LOGIN_PASSWORD,
+ port=LOGIN_PORT,
+ schema=LOGIN_SCHEMA,
+ )
+
+ self.db_hook = RedshiftSQLHook()
+ self.db_hook.get_connection = mock.Mock()
+ self.db_hook.get_connection.return_value = self.connection
+ mock_conn_extra = {"iam": True, "profile": "default",
"cluster_identifier": "my-test-cluster"}
+ if aws_conn_id is not NOTSET:
+ self.db_hook.aws_conn_id = aws_conn_id
+ self.connection.extra = json.dumps(mock_conn_extra)
+
+ mock_db_user = f"IAM:{self.connection.login}"
+ mock_db_pass = "aws_token"
+
+ # Mock AWS Connection
+ mock_aws_hook_conn.get_cluster_credentials.return_value = {
+ "DbPassword": mock_db_pass,
+ "DbUser": mock_db_user,
+ }
+
+ self.db_hook.get_conn()
+
+ # Check boto3 'redshift' client method `get_cluster_credentials` call
args
+ mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+ DbUser=LOGIN_USER,
+ DbName=LOGIN_SCHEMA,
+ ClusterIdentifier="my-test-cluster",
+ AutoCreate=False,
+ )
+
+ mock_connect.assert_called_once_with(
+ user=mock_db_user,
+ password=mock_db_pass,
+ port=LOGIN_PORT,
+ cluster_identifier="my-test-cluster",
+ profile="default",
+ database=LOGIN_SCHEMA,
+ iam=True,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+ @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+ def test_get_conn_iam_without_cluster_identifier_with_host(
+ self, mock_connect, mock_aws_hook_conn, aws_conn_id
+ ):
+ self.connection = Connection(
Review Comment:
Changed it in commit
[0b106d4](https://github.com/apache/airflow/pull/31567/commits/0b106d459f353195f8dd0df6ed5954e54adeea48).
##########
tests/providers/amazon/aws/hooks/test_redshift_sql.py:
##########
@@ -137,3 +137,130 @@ def test_get_conn_overrides_correctly(self, mock_connect,
conn_params, conn_extr
):
self.db_hook.get_conn()
mock_connect.assert_called_once_with(**expected_call_args)
+
+
+class TestRedshiftSQLHookConnectionScenario:
+ def test_get_iam_token_without_both_cluster_identifier_and_host(self):
+ """
+ Tests if it raises exception when both cluster_identifier and host are
not set in redshift connection.
+ """
+ self.connection = Connection(
+ conn_type="redshift",
+ login=LOGIN_USER,
+ password=LOGIN_PASSWORD,
+ port=LOGIN_PORT,
+ schema=LOGIN_SCHEMA,
+ )
+ self.connection.extra = json.dumps(
+ {
+ "iam": True,
+ }
+ )
+ self.db_hook = RedshiftSQLHook()
+ self.db_hook.get_connection = mock.Mock()
+ self.db_hook.get_connection.return_value = self.connection
+ with pytest.raises(Exception, match="Please set cluster_identifier or
host in redshift connection."):
+ self.db_hook.get_uri()
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+ @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+ def test_get_conn_iam_with_cluster_identifier_without_host(
+ self, mock_connect, mock_aws_hook_conn, aws_conn_id
+ ):
+ self.connection = Connection(
Review Comment:
Changed it in commit
[0b106d4](https://github.com/apache/airflow/pull/31567/commits/0b106d459f353195f8dd0df6ed5954e54adeea48).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]