This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8012c9fce6 Add support for querying Redshift Serverless clusters
(#32785)
8012c9fce6 is described below
commit 8012c9fce64f152b006f88497d65ea81d29571b8
Author: Ivica Kolenkaš <[email protected]>
AuthorDate: Mon Jul 24 19:09:44 2023 +0200
Add support for querying Redshift Serverless clusters (#32785)
---
.../providers/amazon/aws/hooks/redshift_data.py | 10 ++++
.../amazon/aws/operators/redshift_data.py | 7 +++
.../amazon/aws/hooks/test_redshift_data.py | 67 +++++++++++++++++++++-
.../amazon/aws/operators/test_redshift_data.py | 42 ++++++++++++++
4 files changed, 125 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py
b/airflow/providers/amazon/aws/hooks/redshift_data.py
index fddd42bd61..a522d3e8c9 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -60,6 +60,7 @@ class
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
with_event: bool = False,
wait_for_completion: bool = True,
poll_interval: int = 10,
+ workgroup_name: str | None = None,
) -> str:
"""
Execute a statement against Amazon Redshift.
@@ -74,6 +75,9 @@ class
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: indicates whether to wait for a result, if
True wait, if False don't wait
:param poll_interval: how often in seconds to check the query status
+ :param workgroup_name: name of the Redshift Serverless workgroup.
Mutually exclusive with
+ `cluster_identifier`. Specify this parameter to query Redshift
Serverless. More info
+
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:returns statement_id: str, the UUID of the statement
"""
@@ -85,6 +89,7 @@ class
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
"WithEvent": with_event,
"SecretArn": secret_arn,
"StatementName": statement_name,
+ "WorkgroupName": workgroup_name,
}
if isinstance(sql, list):
kwargs["Sqls"] = sql
@@ -95,6 +100,9 @@ class
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
statement_id = resp["Id"]
+ if bool(cluster_identifier) is bool(workgroup_name):
+ raise ValueError("Exactly one of 'cluster_identifier' or 'workgroup_name'
must be specified.")
+
if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)
@@ -127,6 +135,7 @@ class
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
database: str,
schema: str | None = "public",
cluster_identifier: str | None = None,
+ workgroup_name: str | None = None,
db_user: str | None = None,
secret_arn: str | None = None,
statement_name: str | None = None,
@@ -168,6 +177,7 @@ class
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
sql=sql,
database=database,
cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py
b/airflow/providers/amazon/aws/operators/redshift_data.py
index bf560c9973..126c585e3f 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
if False (default) will return statement ID
:param aws_conn_id: aws connection to use
:param region: aws region to use
+ :param workgroup_name: name of the Redshift Serverless workgroup. Mutually
exclusive with
+ `cluster_identifier`. Specify this parameter to query Redshift
Serverless. More info
+
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
"""
template_fields = (
@@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
"statement_name",
"aws_conn_id",
"region",
+ "workgroup_name",
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
@@ -82,12 +86,14 @@ class RedshiftDataOperator(BaseOperator):
return_sql_result: bool = False,
aws_conn_id: str = "aws_default",
region: str | None = None,
+ workgroup_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.database = database
self.sql = sql
self.cluster_identifier = cluster_identifier
+ self.workgroup_name = workgroup_name
self.db_user = db_user
self.parameters = parameters
self.secret_arn = secret_arn
@@ -119,6 +125,7 @@ class RedshiftDataOperator(BaseOperator):
database=self.database,
sql=self.sql,
cluster_identifier=self.cluster_identifier,
+ workgroup_name=self.workgroup_name,
db_user=self.db_user,
parameters=self.parameters,
secret_arn=self.secret_arn,
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py
b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index 92920f7042..cc174a872c 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -20,6 +20,8 @@ from __future__ import annotations
import logging
from unittest import mock
+import pytest
+
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
SQL = "sql"
@@ -39,22 +41,50 @@ class TestRedshiftDataHook:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_without_waiting(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ cluster_identifier = "cluster_identifier"
hook = RedshiftDataHook()
hook.execute_query(
database=DATABASE,
+ cluster_identifier=cluster_identifier,
sql=SQL,
wait_for_completion=False,
)
mock_conn.execute_statement.assert_called_once_with(
Database=DATABASE,
+ ClusterIdentifier=cluster_identifier,
Sql=SQL,
WithEvent=False,
)
mock_conn.describe_statement.assert_not_called()
+ @pytest.mark.parametrize(
+ "cluster_identifier, workgroup_name",
+ [
+ (None, None),
+ ("some_cluster", "some_workgroup"),
+ ],
+ )
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- def test_execute_with_all_parameters(self, mock_conn):
+ def test_execute_requires_either_cluster_identifier_or_workgroup_name(
+ self, mock_conn, cluster_identifier, workgroup_name
+ ):
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+
+ with pytest.raises(ValueError):
+ hook = RedshiftDataHook()
+ hook.execute_query(
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
+ sql=SQL,
+ wait_for_completion=False,
+ )
+
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_with_all_parameters_cluster_identifier(self, mock_conn):
cluster_identifier = "cluster_identifier"
db_user = "db_user"
secret_arn = "secret_arn"
@@ -88,6 +118,41 @@ class TestRedshiftDataHook:
Id=STATEMENT_ID,
)
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_with_all_parameters_workgroup_name(self, mock_conn):
+ workgroup_name = "workgroup_name"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
+ hook = RedshiftDataHook()
+ hook.execute_query(
+ sql=SQL,
+ database=DATABASE,
+ workgroup_name=workgroup_name,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ )
+
+ mock_conn.execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sql=SQL,
+ WorkgroupName=workgroup_name,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ Parameters=parameters,
+ WithEvent=False,
+ )
+ mock_conn.describe_statement.assert_called_once_with(
+ Id=STATEMENT_ID,
+ )
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_batch_execute(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py
b/tests/providers/amazon/aws/operators/test_redshift_data.py
index be77f96c30..e5a851fe73 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -32,6 +32,7 @@ class TestRedshiftDataOperator:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute(self, mock_exec_query):
cluster_identifier = "cluster_identifier"
+ workgroup_name = None
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
@@ -57,6 +58,46 @@ class TestRedshiftDataOperator:
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ with_event=False,
+ wait_for_completion=wait_for_completion,
+ poll_interval=poll_interval,
+ )
+
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+ def test_execute_with_workgroup_name(self, mock_exec_query):
+ cluster_identifier = None
+ workgroup_name = "workgroup_name"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+ wait_for_completion = True
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ workgroup_name=workgroup_name,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=True,
+ poll_interval=poll_interval,
+ )
+ operator.execute(None)
+ mock_exec_query.assert_called_once_with(
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
@@ -85,6 +126,7 @@ class TestRedshiftDataOperator:
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
+ cluster_identifier="cluster_identifier",
sql=SQL,
database=DATABASE,
wait_for_completion=False,