This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f6962a929b Support IAM authentication for Redshift serverless (#35897)
f6962a929b is described below
commit f6962a929b839215613d1b6f99f43511759c1e5b
Author: Hussein Awala <[email protected]>
AuthorDate: Tue Nov 28 19:31:24 2023 +0200
Support IAM authentication for Redshift serverless (#35897)
* Support IAM authentication for Redshift serverless
* Comments from review
---
airflow/providers/amazon/aws/hooks/redshift_sql.py | 53 +++++++++++++++-------
.../connections/redshift.rst | 20 ++++++++
.../amazon/aws/hooks/test_redshift_sql.py | 45 ++++++++++++++++++
3 files changed, 102 insertions(+), 16 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py
b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 0b1c26fff5..87bc6be478 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -102,22 +102,43 @@ class RedshiftSQLHook(DbApiHook):
Port is required. If none is provided, default is used for each
service.
"""
port = conn.port or 5439
- # Pull the cluster-identifier from the beginning of the Redshift URL
- # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns
my-cluster
- cluster_identifier = conn.extra_dejson.get("cluster_identifier")
- if not cluster_identifier:
- if conn.host:
- cluster_identifier = conn.host.split(".", 1)[0]
- else:
- raise AirflowException("Please set cluster_identifier or host
in redshift connection.")
- redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id,
client_type="redshift").conn
- #
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
- cluster_creds = redshift_client.get_cluster_credentials(
- DbUser=conn.login,
- DbName=conn.schema,
- ClusterIdentifier=cluster_identifier,
- AutoCreate=False,
- )
+ is_serverless = conn.extra_dejson.get("is_serverless", False)
+ if is_serverless:
+ serverless_work_group =
conn.extra_dejson.get("serverless_work_group")
+ if not serverless_work_group:
+ raise AirflowException(
+ "Please set serverless_work_group in redshift connection
to use IAM with"
+ " Redshift Serverless."
+ )
+ serverless_token_duration_seconds = conn.extra_dejson.get(
+ "serverless_token_duration_seconds", 3600
+ )
+ redshift_client = AwsBaseHook(
+ aws_conn_id=self.aws_conn_id, client_type="redshift-serverless"
+ ).conn
+ #
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#get-credentials
+            cluster_creds = redshift_client.get_credentials(
+                dbName=conn.schema,
+ workgroupName=serverless_work_group,
+ durationSeconds=serverless_token_duration_seconds,
+ )
+ else:
+            # Pull the cluster-identifier from the beginning of the Redshift URL
+ # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com
returns my-cluster
+ cluster_identifier = conn.extra_dejson.get("cluster_identifier")
+ if not cluster_identifier:
+ if conn.host:
+ cluster_identifier = conn.host.split(".", 1)[0]
+ else:
+ raise AirflowException("Please set cluster_identifier or
host in redshift connection.")
+ redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id,
client_type="redshift").conn
+ #
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
+ cluster_creds = redshift_client.get_cluster_credentials(
+ DbUser=conn.login,
+ DbName=conn.schema,
+ ClusterIdentifier=cluster_identifier,
+ AutoCreate=False,
+ )
token = cluster_creds["DbPassword"]
login = cluster_creds["DbUser"]
return login, token, port
diff --git a/docs/apache-airflow-providers-amazon/connections/redshift.rst
b/docs/apache-airflow-providers-amazon/connections/redshift.rst
index ccd48805c4..24817e8d38 100644
--- a/docs/apache-airflow-providers-amazon/connections/redshift.rst
+++ b/docs/apache-airflow-providers-amazon/connections/redshift.rst
@@ -97,3 +97,23 @@ inferred by the **Host** field in Connection.
"database": "dev",
"profile": "default"
}
+
+If you want to use IAM with Amazon Redshift Serverless, you need to set
**is_serverless** to true and provide
+**serverless_work_group**. You can also set
**serverless_token_duration_seconds** to specify the number of seconds
+until the returned temporary password expires; the minimum is 900 seconds, the
maximum is 3600 seconds and by default
+it's 3600 seconds.
+
+* **Extra**:
+
+.. code-block:: json
+
+ {
+ "iam": true,
+ "is_serverless": true,
+ "serverless_work_group": "default",
+ "serverless_token_duration_seconds": 3600,
+ "port": 5439,
+ "region": "us-east-1",
+ "database": "dev",
+ "profile": "default"
+ }
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index 5316372173..cab6053eff 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -120,6 +120,51 @@ class TestRedshiftSQLHookConn:
iam=True,
)
+ @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
+ @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+ def test_get_conn_iam_serverless_redshift(self, mock_connect,
mock_aws_hook_conn, aws_conn_id):
+ mock_work_group = "my-test-workgroup"
+ mock_conn_extra = {
+ "iam": True,
+ "is_serverless": True,
+ "profile": "default",
+ "serverless_work_group": mock_work_group,
+ }
+ if aws_conn_id is not NOTSET:
+ self.db_hook.aws_conn_id = aws_conn_id
+ self.connection.extra = json.dumps(mock_conn_extra)
+
+ mock_db_user = f"IAM:{self.connection.login}"
+ mock_db_pass = "aws_token"
+
+ # Mock AWS Connection
+        mock_aws_hook_conn.get_credentials.return_value = {
+ "DbPassword": mock_db_pass,
+ "DbUser": mock_db_user,
+ }
+
+ self.db_hook.get_conn()
+
+ # Check boto3 'redshift' client method `get_cluster_credentials` call
args
+        mock_aws_hook_conn.get_credentials.assert_called_once_with(
+            dbName=LOGIN_SCHEMA,
+ workgroupName=mock_work_group,
+ durationSeconds=3600,
+ )
+
+ mock_connect.assert_called_once_with(
+ user=mock_db_user,
+ password=mock_db_pass,
+ host=LOGIN_HOST,
+ port=LOGIN_PORT,
+ serverless_work_group=mock_work_group,
+ profile="default",
+ database=LOGIN_SCHEMA,
+ iam=True,
+ is_serverless=True,
+ )
+
@pytest.mark.parametrize(
"conn_params, conn_extra, expected_call_args",
[