This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8012c9fce6 Add support for querying Redshift Serverless clusters (#32785)
8012c9fce6 is described below

commit 8012c9fce64f152b006f88497d65ea81d29571b8
Author: Ivica Kolenkaš <[email protected]>
AuthorDate: Mon Jul 24 19:09:44 2023 +0200

    Add support for querying Redshift Serverless clusters (#32785)
---
 .../providers/amazon/aws/hooks/redshift_data.py    | 10 ++++
 .../amazon/aws/operators/redshift_data.py          |  7 +++
 .../amazon/aws/hooks/test_redshift_data.py         | 67 +++++++++++++++++++++-
 .../amazon/aws/operators/test_redshift_data.py     | 42 ++++++++++++++
 4 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index fddd42bd61..a522d3e8c9 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -60,6 +60,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         with_event: bool = False,
         wait_for_completion: bool = True,
         poll_interval: int = 10,
+        workgroup_name: str | None = None,
     ) -> str:
         """
         Execute a statement against Amazon Redshift.
@@ -74,6 +75,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         :param with_event: indicates whether to send an event to EventBridge
         :param wait_for_completion: indicates whether to wait for a result, if 
True wait, if False don't wait
         :param poll_interval: how often in seconds to check the query status
+        :param workgroup_name: name of the Redshift Serverless workgroup. 
Mutually exclusive with
+            `cluster_identifier`. Specify this parameter to query Redshift 
Serverless. More info
+            
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
 
         :returns statement_id: str, the UUID of the statement
         """
@@ -85,6 +89,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             "WithEvent": with_event,
             "SecretArn": secret_arn,
             "StatementName": statement_name,
+            "WorkgroupName": workgroup_name,
         }
         if isinstance(sql, list):
             kwargs["Sqls"] = sql
@@ -95,6 +100,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         statement_id = resp["Id"]
 
+        if bool(cluster_identifier) is bool(workgroup_name):
+            raise ValueError("Either 'cluster_identifier' or 'workgroup_name' 
must be specified.")
+
         if wait_for_completion:
             self.wait_for_results(statement_id, poll_interval=poll_interval)
 
@@ -127,6 +135,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         database: str,
         schema: str | None = "public",
         cluster_identifier: str | None = None,
+        workgroup_name: str | None = None,
         db_user: str | None = None,
         secret_arn: str | None = None,
         statement_name: str | None = None,
@@ -168,6 +177,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             sql=sql,
             database=database,
             cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
             db_user=db_user,
             secret_arn=secret_arn,
             statement_name=statement_name,
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index bf560c9973..126c585e3f 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
         if False (default) will return statement ID
     :param aws_conn_id: aws connection to use
     :param region: aws region to use
+    :param workgroup_name: name of the Redshift Serverless workgroup. Mutually 
exclusive with
+        `cluster_identifier`. Specify this parameter to query Redshift 
Serverless. More info
+        
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
     """
 
     template_fields = (
@@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
         "statement_name",
         "aws_conn_id",
         "region",
+        "workgroup_name",
     )
     template_ext = (".sql",)
     template_fields_renderers = {"sql": "sql"}
@@ -82,12 +86,14 @@ class RedshiftDataOperator(BaseOperator):
         return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
+        workgroup_name: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.database = database
         self.sql = sql
         self.cluster_identifier = cluster_identifier
+        self.workgroup_name = workgroup_name
         self.db_user = db_user
         self.parameters = parameters
         self.secret_arn = secret_arn
@@ -119,6 +125,7 @@ class RedshiftDataOperator(BaseOperator):
             database=self.database,
             sql=self.sql,
             cluster_identifier=self.cluster_identifier,
+            workgroup_name=self.workgroup_name,
             db_user=self.db_user,
             parameters=self.parameters,
             secret_arn=self.secret_arn,
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index 92920f7042..cc174a872c 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -20,6 +20,8 @@ from __future__ import annotations
 import logging
 from unittest import mock
 
+import pytest
+
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 
 SQL = "sql"
@@ -39,22 +41,50 @@ class TestRedshiftDataHook:
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
     def test_execute_without_waiting(self, mock_conn):
         mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        cluster_identifier = "cluster_identifier"
 
         hook = RedshiftDataHook()
         hook.execute_query(
             database=DATABASE,
+            cluster_identifier=cluster_identifier,
             sql=SQL,
             wait_for_completion=False,
         )
         mock_conn.execute_statement.assert_called_once_with(
             Database=DATABASE,
+            ClusterIdentifier=cluster_identifier,
             Sql=SQL,
             WithEvent=False,
         )
         mock_conn.describe_statement.assert_not_called()
 
+    @pytest.mark.parametrize(
+        "cluster_identifier, workgroup_name",
+        [
+            (None, None),
+            ("some_cluster", "some_workgroup"),
+        ],
+    )
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    def test_execute_with_all_parameters(self, mock_conn):
+    def test_execute_requires_either_cluster_identifier_or_workgroup_name(
+        self, mock_conn, cluster_identifier, workgroup_name
+    ):
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        cluster_identifier = "cluster_identifier"
+        workgroup_name = "workgroup_name"
+
+        with pytest.raises(ValueError):
+            hook = RedshiftDataHook()
+            hook.execute_query(
+                database=DATABASE,
+                cluster_identifier=cluster_identifier,
+                workgroup_name=workgroup_name,
+                sql=SQL,
+                wait_for_completion=False,
+            )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_with_all_parameters_cluster_identifier(self, mock_conn):
         cluster_identifier = "cluster_identifier"
         db_user = "db_user"
         secret_arn = "secret_arn"
@@ -88,6 +118,41 @@ class TestRedshiftDataHook:
             Id=STATEMENT_ID,
         )
 
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_with_all_parameters_workgroup_name(self, mock_conn):
+        workgroup_name = "workgroup_name"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
+        hook = RedshiftDataHook()
+        hook.execute_query(
+            sql=SQL,
+            database=DATABASE,
+            workgroup_name=workgroup_name,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+        )
+
+        mock_conn.execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sql=SQL,
+            WorkgroupName=workgroup_name,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            Parameters=parameters,
+            WithEvent=False,
+        )
+        mock_conn.describe_statement.assert_called_once_with(
+            Id=STATEMENT_ID,
+        )
+
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
     def test_batch_execute(self, mock_conn):
         mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index be77f96c30..e5a851fe73 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -32,6 +32,7 @@ class TestRedshiftDataOperator:
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
     def test_execute(self, mock_exec_query):
         cluster_identifier = "cluster_identifier"
+        workgroup_name = None
         db_user = "db_user"
         secret_arn = "secret_arn"
         statement_name = "statement_name"
@@ -57,6 +58,46 @@ class TestRedshiftDataOperator:
             sql=SQL,
             database=DATABASE,
             cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            with_event=False,
+            wait_for_completion=wait_for_completion,
+            poll_interval=poll_interval,
+        )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+    def test_execute_with_workgroup_name(self, mock_exec_query):
+        cluster_identifier = None
+        workgroup_name = "workgroup_name"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        poll_interval = 5
+        wait_for_completion = True
+
+        operator = RedshiftDataOperator(
+            aws_conn_id=CONN_ID,
+            task_id=TASK_ID,
+            sql=SQL,
+            database=DATABASE,
+            workgroup_name=workgroup_name,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            wait_for_completion=True,
+            poll_interval=poll_interval,
+        )
+        operator.execute(None)
+        mock_exec_query.assert_called_once_with(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
             db_user=db_user,
             secret_arn=secret_arn,
             statement_name=statement_name,
@@ -85,6 +126,7 @@ class TestRedshiftDataOperator:
         operator = RedshiftDataOperator(
             aws_conn_id=CONN_ID,
             task_id=TASK_ID,
+            cluster_identifier="cluster_identifier",
             sql=SQL,
             database=DATABASE,
             wait_for_completion=False,

Reply via email to