This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5b3382f638 Add null check for host in Amazon Redshift connection
(#31567)
5b3382f638 is described below
commit 5b3382f63898e497d482870636ed156ce861afbc
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Tue May 30 23:48:02 2023 +0530
Add null check for host in Amazon Redshift connection (#31567)
* Add null check for the redshift connection if cluster_identifier in
get_iam_token
* Add null check for the redshift connection if cluster_identifier in
get_iam_token
* Add exception and test cases
* Remove this check for conn.host as it's already checked
* Update airflow/providers/amazon/aws/hooks/redshift_sql.py
Co-authored-by: Tzu-ping Chung <[email protected]>
* Change the conditional checks
* refactor unit tests
* reorder imports
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Hussein Awala <[email protected]>
---
airflow/providers/amazon/aws/hooks/redshift_sql.py | 8 ++-
.../amazon/aws/hooks/test_redshift_sql.py | 58 ++++++++++++++++++++++
2 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py
b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 11c7dbce26..61832e67aa 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -24,6 +24,7 @@ from redshift_connector import Connection as
RedshiftConnection
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -104,7 +105,12 @@ class RedshiftSQLHook(DbApiHook):
port = conn.port or 5439
# Pull the custer-identifier from the beginning of the Redshift URL
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns
my-cluster
- cluster_identifier = conn.extra_dejson.get("cluster_identifier",
conn.host.split(".")[0])
+ cluster_identifier = conn.extra_dejson.get("cluster_identifier")
+ if not cluster_identifier:
+ if conn.host:
+ cluster_identifier = conn.host.split(".", 1)[0]
+ else:
+ raise AirflowException("Please set cluster_identifier or host
in redshift connection.")
redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id,
client_type="redshift").conn
#
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index 335c2f28f4..e3c0ea80e8 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -21,6 +21,7 @@ from unittest import mock
import pytest
+from airflow import AirflowException
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.utils.types import NOTSET
@@ -137,3 +138,60 @@ class TestRedshiftSQLHookConn:
):
self.db_hook.get_conn()
mock_connect.assert_called_once_with(**expected_call_args)
+
+ @pytest.mark.parametrize(
+ "connection_host, connection_extra, expected_cluster_identifier,
expected_exception_msg",
+ [
+ # test without a connection host and without a cluster_identifier
in connection extra
+ (None, {"iam": True}, None, "Please set cluster_identifier or host
in redshift connection."),
+ # test without a connection host but with a cluster_identifier in
connection extra
+ (
+ None,
+ {"iam": True, "cluster_identifier":
"cluster_identifier_from_extra"},
+ "cluster_identifier_from_extra",
+ None,
+ ),
+ # test with a connection host and without a cluster_identifier in
connection extra
+ ("cluster_identifier_from_host.x.y", {"iam": True},
"cluster_identifier_from_host", None),
+ # test with both connection host and cluster_identifier in
connection extra
+ (
+ "cluster_identifier_from_host.x.y",
+ {"iam": True, "cluster_identifier":
"cluster_identifier_from_extra"},
+ "cluster_identifier_from_extra",
+ None,
+ ),
+ ],
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+ def test_get_iam_token(
+ self,
+ mock_connect,
+ mock_aws_hook_conn,
+ connection_host,
+ connection_extra,
+ expected_cluster_identifier,
+ expected_exception_msg,
+ ):
+ self.connection.host = connection_host
+ self.connection.extra = json.dumps(connection_extra)
+
+ mock_db_user = f"IAM:{self.connection.login}"
+ mock_db_pass = "aws_token"
+
+ # Mock AWS Connection
+ mock_aws_hook_conn.get_cluster_credentials.return_value = {
+ "DbPassword": mock_db_pass,
+ "DbUser": mock_db_user,
+ }
+ if expected_exception_msg is not None:
+ with pytest.raises(AirflowException, match=expected_exception_msg):
+ self.db_hook.get_conn()
+ else:
+ self.db_hook.get_conn()
+ mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
+ DbUser=LOGIN_USER,
+ DbName=LOGIN_SCHEMA,
+ ClusterIdentifier=expected_cluster_identifier,
+ AutoCreate=False,
+ )