This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0604033829 Add Amazon Redshift-data to S3<>RS Transfer Operators 
(#27947)
0604033829 is described below

commit 0604033829787ebed59b9982bf08c1a68d93b120
Author: Josh Dimarsky <[email protected]>
AuthorDate: Mon Feb 20 00:19:33 2023 -0500

    Add Amazon Redshift-data to S3<>RS Transfer Operators (#27947)
    
    * refactored Amazon Redshift-data functionality into the hook
    ---------
    
    Co-authored-by: eladkal <[email protected]>
---
 .../providers/amazon/aws/hooks/redshift_data.py    | 145 ++++++++++++++++++-
 .../amazon/aws/operators/redshift_data.py          | 122 +++++++++-------
 .../amazon/aws/transfers/redshift_to_s3.py         |  25 +++-
 .../amazon/aws/transfers/s3_to_redshift.py         |  31 +++-
 .../amazon/aws/hooks/test_redshift_data.py         | 160 ++++++++++++++++++++-
 .../amazon/aws/operators/test_redshift_data.py     |  80 +++--------
 .../amazon/aws/transfers/test_redshift_to_s3.py    | 111 ++++++++++++++
 .../amazon/aws/transfers/test_s3_to_redshift.py    |  99 +++++++++++++
 8 files changed, 647 insertions(+), 126 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py 
b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 75efc9e822..e73c5a943a 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -17,9 +17,11 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from time import sleep
+from typing import TYPE_CHECKING, Any, Iterable
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
     from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa
@@ -43,3 +45,144 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "redshift-data"
         super().__init__(*args, **kwargs)
+
+    def execute_query(
+        self,
+        database: str,
+        sql: str | list[str],
+        cluster_identifier: str | None = None,
+        db_user: str | None = None,
+        parameters: Iterable | None = None,
+        secret_arn: str | None = None,
+        statement_name: str | None = None,
+        with_event: bool = False,
+        wait_for_completion: bool = True,
+        poll_interval: int = 10,
+    ) -> str:
+        """
+        Execute a statement against Amazon Redshift
+
+        :param database: the name of the database
+        :param sql: the SQL statement or list of SQL statements to run
+        :param cluster_identifier: unique identifier of a cluster
+        :param db_user: the database username
+        :param parameters: the parameters for the SQL statement
+        :param secret_arn: the name or ARN of the secret that enables db access
+        :param statement_name: the name of the SQL statement
+        :param with_event: indicates whether to send an event to EventBridge
+        :param wait_for_completion: indicates whether to wait for a result, if 
True wait, if False don't wait
+        :param poll_interval: how often in seconds to check the query status
+
+        :return: statement_id (str), the UUID of the statement
+        """
+        kwargs: dict[str, Any] = {
+            "ClusterIdentifier": cluster_identifier,
+            "Database": database,
+            "DbUser": db_user,
+            "Parameters": parameters,
+            "WithEvent": with_event,
+            "SecretArn": secret_arn,
+            "StatementName": statement_name,
+        }
+        if isinstance(sql, list):
+            kwargs["Sqls"] = sql
+            resp = 
self.conn.batch_execute_statement(**trim_none_values(kwargs))
+        else:
+            kwargs["Sql"] = sql
+            resp = self.conn.execute_statement(**trim_none_values(kwargs))
+
+        statement_id = resp["Id"]
+
+        if wait_for_completion:
+            self.wait_for_results(statement_id, poll_interval=poll_interval)
+
+        return statement_id
+
+    def wait_for_results(self, statement_id, poll_interval):
+        while True:
+            self.log.info("Polling statement %s", statement_id)
+            resp = self.conn.describe_statement(
+                Id=statement_id,
+            )
+            status = resp["Status"]
+            if status == "FINISHED":
+                return status
+            elif status == "FAILED" or status == "ABORTED":
+                raise ValueError(
+                    f"Statement {statement_id!r} terminated with status 
{status}, "
+                    f"error msg: {resp.get('Error')}"
+                )
+            else:
+                self.log.info("Query %s", status)
+            sleep(poll_interval)
+
+    def get_table_primary_key(
+        self,
+        table: str,
+        database: str,
+        schema: str | None = "public",
+        cluster_identifier: str | None = None,
+        db_user: str | None = None,
+        secret_arn: str | None = None,
+        statement_name: str | None = None,
+        with_event: bool = False,
+        wait_for_completion: bool = True,
+        poll_interval: int = 10,
+    ) -> list[str] | None:
+        """
+        Helper method that returns the table primary key.
+
+        Copied from ``RedshiftSQLHook.get_table_primary_key()``
+
+        :param table: Name of the target table
+        :param database: the name of the database
+        :param schema: Name of the target schema, public by default
+        :param sql: the SQL statement or list of SQL statements to run
+        :param cluster_identifier: unique identifier of a cluster
+        :param db_user: the database username
+        :param secret_arn: the name or ARN of the secret that enables db access
+        :param statement_name: the name of the SQL statement
+        :param with_event: indicates whether to send an event to EventBridge
+        :param wait_for_completion: indicates whether to wait for a result, if 
True wait, if False don't wait
+        :param poll_interval: how often in seconds to check the query status
+
+        :return: Primary key columns list
+        """
+        sql = f"""
+            select kcu.column_name
+            from information_schema.table_constraints tco
+                    join information_schema.key_column_usage kcu
+                        on kcu.constraint_name = tco.constraint_name
+                            and kcu.constraint_schema = tco.constraint_schema
+                            and kcu.constraint_name = tco.constraint_name
+            where tco.constraint_type = 'PRIMARY KEY'
+            and kcu.table_schema = {schema}
+            and kcu.table_name = {table}
+        """
+        stmt_id = self.execute_query(
+            sql=sql,
+            database=database,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            with_event=with_event,
+            wait_for_completion=wait_for_completion,
+            poll_interval=poll_interval,
+        )
+        pk_columns = []
+        token = ""
+        while True:
+            kwargs = dict(Id=stmt_id)
+            if token:
+                kwargs["NextToken"] = token
+            response = self.conn.get_statement_result(**kwargs)
+            # we only select a single column (that is a string),
+            # so safe to assume that there is only a single col in the record
+            pk_columns += [y["stringValue"] for x in response["Records"] for y 
in x]
+            if "NextToken" not in response.keys():
+                break
+            else:
+                token = response["NextToken"]
+
+        return pk_columns or None
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py 
b/airflow/providers/amazon/aws/operators/redshift_data.py
index 416ce8146e..9b5204b767 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,13 +17,12 @@
 # under the License.
 from __future__ import annotations
 
-from time import sleep
-from typing import TYPE_CHECKING, Any
+import warnings
+from typing import TYPE_CHECKING
 
 from airflow.compat.functools import cached_property
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
-from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -99,67 +98,80 @@ class RedshiftDataOperator(BaseOperator):
             )
         self.aws_conn_id = aws_conn_id
         self.region = region
-        self.statement_id = None
+        self.statement_id: str | None = None
 
     @cached_property
     def hook(self) -> RedshiftDataHook:
         """Create and return an RedshiftDataHook."""
         return RedshiftDataHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region)
 
-    def execute_query(self):
-        kwargs: dict[str, Any] = {
-            "ClusterIdentifier": self.cluster_identifier,
-            "Database": self.database,
-            "Sql": self.sql,
-            "DbUser": self.db_user,
-            "Parameters": self.parameters,
-            "WithEvent": self.with_event,
-            "SecretArn": self.secret_arn,
-            "StatementName": self.statement_name,
-        }
-
-        resp = self.hook.conn.execute_statement(**trim_none_values(kwargs))
-        return resp["Id"]
-
-    def execute_batch_query(self):
-        kwargs: dict[str, Any] = {
-            "ClusterIdentifier": self.cluster_identifier,
-            "Database": self.database,
-            "Sqls": self.sql,
-            "DbUser": self.db_user,
-            "Parameters": self.parameters,
-            "WithEvent": self.with_event,
-            "SecretArn": self.secret_arn,
-            "StatementName": self.statement_name,
-        }
-        resp = 
self.hook.conn.batch_execute_statement(**trim_none_values(kwargs))
-        return resp["Id"]
-
-    def wait_for_results(self, statement_id):
-        while True:
-            self.log.info("Polling statement %s", statement_id)
-            resp = self.hook.conn.describe_statement(
-                Id=statement_id,
-            )
-            status = resp["Status"]
-            if status == "FINISHED":
-                return status
-            elif status == "FAILED" or status == "ABORTED":
-                raise ValueError(f"Statement {statement_id!r} terminated with 
status {status}.")
-            else:
-                self.log.info("Query %s", status)
-            sleep(self.poll_interval)
-
-    def execute(self, context: Context) -> None:
+    def execute_query(self) -> str:
+        warnings.warn(
+            "This method is deprecated and has been moved to the hook "
+            
"`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self.statement_id = self.hook.execute_query(
+            database=self.database,
+            sql=self.sql,
+            cluster_identifier=self.cluster_identifier,
+            db_user=self.db_user,
+            parameters=self.parameters,
+            secret_arn=self.secret_arn,
+            statement_name=self.statement_name,
+            with_event=self.with_event,
+            wait_for_completion=self.await_result,
+            poll_interval=self.poll_interval,
+        )
+        return self.statement_id
+
+    def execute_batch_query(self) -> str:
+        warnings.warn(
+            "This method is deprecated and has been moved to the hook "
+            
"`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self.statement_id = self.hook.execute_query(
+            database=self.database,
+            sql=self.sql,
+            cluster_identifier=self.cluster_identifier,
+            db_user=self.db_user,
+            parameters=self.parameters,
+            secret_arn=self.secret_arn,
+            statement_name=self.statement_name,
+            with_event=self.with_event,
+            wait_for_completion=self.await_result,
+            poll_interval=self.poll_interval,
+        )
+        return self.statement_id
+
+    def wait_for_results(self, statement_id: str):
+        warnings.warn(
+            "This method is deprecated and has been moved to the hook "
+            
"`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.hook.wait_for_results(statement_id=statement_id, 
poll_interval=self.poll_interval)
+
+    def execute(self, context: Context) -> str:
         """Execute a statement against Amazon Redshift"""
         self.log.info("Executing statement: %s", self.sql)
-        if isinstance(self.sql, list):
-            self.statement_id = self.execute_batch_query()
-        else:
-            self.statement_id = self.execute_query()
 
-        if self.await_result:
-            self.wait_for_results(self.statement_id)
+        self.statement_id = self.hook.execute_query(
+            database=self.database,
+            sql=self.sql,
+            cluster_identifier=self.cluster_identifier,
+            db_user=self.db_user,
+            parameters=self.parameters,
+            secret_arn=self.secret_arn,
+            statement_name=self.statement_name,
+            with_event=self.with_event,
+            wait_for_completion=self.await_result,
+            poll_interval=self.poll_interval,
+        )
 
         return self.statement_id
 
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py 
b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index ed57df0984..45f1e1783c 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -20,7 +20,9 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
 
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
@@ -67,6 +69,9 @@ class RedshiftToS3Operator(BaseOperator):
     :param parameters: (optional) the parameters to render the SQL query with.
     :param table_as_file_name: If set to True, the s3 file will be named as 
the table.
         Applicable when ``table`` param provided.
+    :param redshift_data_api_kwargs: If using the Redshift Data API instead of 
the SQL-based connection,
+        dict of arguments for the hook's ``execute_query`` method.
+        Cannot include any of these kwargs: ``{'sql', 'parameters'}``
     """
 
     template_fields: Sequence[str] = (
@@ -98,6 +103,7 @@ class RedshiftToS3Operator(BaseOperator):
         include_header: bool = False,
         parameters: Iterable | Mapping | None = None,
         table_as_file_name: bool = True,  # Set to True by default for not 
breaking current workflows
+        redshift_data_api_kwargs: dict = {},
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -113,6 +119,7 @@ class RedshiftToS3Operator(BaseOperator):
         self.include_header = include_header
         self.parameters = parameters
         self.table_as_file_name = table_as_file_name
+        self.redshift_data_api_kwargs = redshift_data_api_kwargs
 
         if select_query:
             self.select_query = select_query
@@ -128,6 +135,11 @@ class RedshiftToS3Operator(BaseOperator):
                 "HEADER",
             ]
 
+        if self.redshift_data_api_kwargs:
+            for arg in ["sql", "parameters"]:
+                if arg in self.redshift_data_api_kwargs.keys():
+                    raise AirflowException(f"Cannot include param '{arg}' in 
Redshift Data API kwargs")
+
     def _build_unload_query(
         self, credentials_block: str, select_query: str, s3_key: str, 
unload_options: str
     ) -> str:
@@ -140,7 +152,11 @@ class RedshiftToS3Operator(BaseOperator):
         """
 
     def execute(self, context: Context) -> None:
-        redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+        redshift_hook: RedshiftDataHook | RedshiftSQLHook
+        if self.redshift_data_api_kwargs:
+            redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+        else:
+            redshift_hook = 
RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
         conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
         if conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
@@ -156,5 +172,10 @@ class RedshiftToS3Operator(BaseOperator):
         )
 
         self.log.info("Executing UNLOAD command...")
-        redshift_hook.run(unload_query, self.autocommit, 
parameters=self.parameters)
+        if isinstance(redshift_hook, RedshiftDataHook):
+            redshift_hook.execute_query(
+                sql=unload_query, parameters=self.parameters, 
**self.redshift_data_api_kwargs
+            )
+        else:
+            redshift_hook.run(unload_query, self.autocommit, 
parameters=self.parameters)
         self.log.info("UNLOAD command complete...")
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py 
b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 1dcb6e7648..dc99939021 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
@@ -43,7 +44,7 @@ class S3ToRedshiftOperator(BaseOperator):
     :param table: reference to a specific table in redshift database
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: key prefix that selects single or multiple objects from S3
-    :param redshift_conn_id: reference to a specific redshift database
+    :param redshift_conn_id: reference to a specific redshift database OR a 
redshift data-api connection
     :param aws_conn_id: reference to a specific S3 connection
         If the AWS connection contains 'aws_iam_role' in ``extras``
         the operator will use AWS STS credentials with a token
@@ -62,6 +63,9 @@ class S3ToRedshiftOperator(BaseOperator):
     :param copy_options: reference to a list of COPY options
     :param method: Action to be performed on execution. Available ``APPEND``, 
``UPSERT`` and ``REPLACE``.
     :param upsert_keys: List of fields to use as key on upsert action
+    :param redshift_data_api_kwargs: If using the Redshift Data API instead of 
the SQL-based connection,
+        dict of arguments for the hook's ``execute_query`` method.
+        Cannot include any of these kwargs: ``{'sql', 'parameters'}``
     """
 
     template_fields: Sequence[str] = (
@@ -91,6 +95,7 @@ class S3ToRedshiftOperator(BaseOperator):
         autocommit: bool = False,
         method: str = "APPEND",
         upsert_keys: list[str] | None = None,
+        redshift_data_api_kwargs: dict = {},
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -106,10 +111,16 @@ class S3ToRedshiftOperator(BaseOperator):
         self.autocommit = autocommit
         self.method = method
         self.upsert_keys = upsert_keys
+        self.redshift_data_api_kwargs = redshift_data_api_kwargs
 
         if self.method not in AVAILABLE_METHODS:
             raise AirflowException(f"Method not found! Available methods: 
{AVAILABLE_METHODS}")
 
+        if self.redshift_data_api_kwargs:
+            for arg in ["sql", "parameters"]:
+                if arg in self.redshift_data_api_kwargs.keys():
+                    raise AirflowException(f"Cannot include param '{arg}' in 
Redshift Data API kwargs")
+
     def _build_copy_query(self, copy_destination: str, credentials_block: str, 
copy_options: str) -> str:
         column_names = "(" + ", ".join(self.column_list) + ")" if 
self.column_list else ""
         return f"""
@@ -121,7 +132,11 @@ class S3ToRedshiftOperator(BaseOperator):
         """
 
     def execute(self, context: Context) -> None:
-        redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+        redshift_hook: RedshiftDataHook | RedshiftSQLHook
+        if self.redshift_data_api_kwargs:
+            redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+        else:
+            redshift_hook = 
RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
         conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
 
         if conn.extra_dejson.get("role_arn", False):
@@ -142,7 +157,12 @@ class S3ToRedshiftOperator(BaseOperator):
         if self.method == "REPLACE":
             sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, 
"COMMIT"]
         elif self.method == "UPSERT":
-            keys = self.upsert_keys or 
redshift_hook.get_table_primary_key(self.table, self.schema)
+            if isinstance(redshift_hook, RedshiftDataHook):
+                keys = self.upsert_keys or redshift_hook.get_table_primary_key(
+                    table=self.table, schema=self.schema, 
**self.redshift_data_api_kwargs
+                )
+            else:
+                keys = self.upsert_keys or 
redshift_hook.get_table_primary_key(self.table, self.schema)
             if not keys:
                 raise AirflowException(
                     f"No primary key on {self.schema}.{self.table}. Please 
provide keys on 'upsert_keys'"
@@ -162,5 +182,8 @@ class S3ToRedshiftOperator(BaseOperator):
             sql = copy_statement
 
         self.log.info("Executing COPY command...")
-        redshift_hook.run(sql, autocommit=self.autocommit)
+        if isinstance(redshift_hook, RedshiftDataHook):
+            redshift_hook.execute_query(sql=sql, 
**self.redshift_data_api_kwargs)
+        else:
+            redshift_hook.run(sql, autocommit=self.autocommit)
         self.log.info("COPY command complete...")
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py 
b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index f204703170..29816442c4 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -17,14 +17,172 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
+
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 
+CONN_ID = "aws_conn_test"
+SQL = "sql"
+DATABASE = "database"
+STATEMENT_ID = "statement_id"
+
 
 class TestRedshiftDataHook:
     def test_conn_attribute(self):
-        hook = RedshiftDataHook(aws_conn_id="aws_default", 
region_name="us-east-1")
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
         assert hasattr(hook, "conn")
         assert hook.conn.__class__.__name__ == "RedshiftDataAPIService"
         conn = hook.conn
         assert conn is hook.conn  # Cached property
         assert conn is hook.get_conn()  # Same object as returned by `conn` 
property
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_without_waiting(self, mock_conn):
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+        hook.execute_query(
+            database=DATABASE,
+            sql=SQL,
+            wait_for_completion=False,
+        )
+        mock_conn.execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sql=SQL,
+            WithEvent=False,
+        )
+        mock_conn.describe_statement.assert_not_called()
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_with_all_parameters(self, mock_conn):
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+        hook.execute_query(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+        )
+
+        mock_conn.execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sql=SQL,
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            Parameters=parameters,
+            WithEvent=False,
+        )
+        mock_conn.describe_statement.assert_called_once_with(
+            Id=STATEMENT_ID,
+        )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_batch_execute(self, mock_conn):
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+        hook.execute_query(
+            cluster_identifier=cluster_identifier,
+            database=DATABASE,
+            db_user=db_user,
+            sql=[SQL],
+            statement_name=statement_name,
+            secret_arn=secret_arn,
+        )
+
+        mock_conn.batch_execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sqls=[SQL],
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            WithEvent=False,
+        )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_get_table_primary_key_no_token(self, mock_conn):
+        table = "table"
+        schema = "schema"
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        mock_conn.get_statement_result.return_value = {
+            "Records": [[{"stringValue": "string"}]],
+        }
+
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+
+        hook.get_table_primary_key(
+            table=table,
+            database=DATABASE,
+            schema=schema,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+        )
+
+        mock_conn.get_statement_result.assert_called_once_with(Id=STATEMENT_ID)
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_get_table_primary_key_with_token(self, mock_conn):
+        table = "table"
+        schema = "schema"
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        mock_conn.get_statement_result.side_effect = [
+            {
+                "NextToken": "token1",
+                "Records": [[{"stringValue": "string1"}]],
+            },
+            {
+                "NextToken": "token2",
+                "Records": [[{"stringValue": "string2"}]],
+            },
+            {
+                "Records": [[{"stringValue": "string3"}]],
+            },
+        ]
+
+        hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+
+        hook.get_table_primary_key(
+            table=table,
+            database=DATABASE,
+            schema=schema,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+        )
+
+        assert mock_conn.get_statement_result.call_args_list == [
+            (dict(Id=STATEMENT_ID),),
+            (dict(Id=STATEMENT_ID, NextToken="token1"),),
+            (dict(Id=STATEMENT_ID, NextToken="token2"),),
+        ]
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py 
b/tests/providers/amazon/aws/operators/test_redshift_data.py
index fa527ffc09..efa88cb026 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -29,33 +29,15 @@ STATEMENT_ID = "statement_id"
 
 
 class TestRedshiftDataOperator:
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    def test_execute_without_waiting(self, mock_conn):
-        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
-        operator = RedshiftDataOperator(
-            aws_conn_id=CONN_ID,
-            task_id=TASK_ID,
-            sql=SQL,
-            database=DATABASE,
-            await_result=False,
-        )
-        operator.execute(None)
-        mock_conn.execute_statement.assert_called_once_with(
-            Database=DATABASE,
-            Sql=SQL,
-            WithEvent=False,
-        )
-        mock_conn.describe_statement.assert_not_called()
-
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    def test_execute_with_all_parameters(self, mock_conn):
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+    def test_execute(self, mock_exec_query):
         cluster_identifier = "cluster_identifier"
         db_user = "db_user"
         secret_arn = "secret_arn"
         statement_name = "statement_name"
         parameters = [{"name": "id", "value": "1"}]
-        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
-        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        poll_interval = 5
+        await_result = True
 
         operator = RedshiftDataOperator(
             aws_conn_id=CONN_ID,
@@ -67,20 +49,21 @@ class TestRedshiftDataOperator:
             secret_arn=secret_arn,
             statement_name=statement_name,
             parameters=parameters,
+            await_result=True,
+            poll_interval=poll_interval,
         )
         operator.execute(None)
-        mock_conn.execute_statement.assert_called_once_with(
-            Database=DATABASE,
-            Sql=SQL,
-            ClusterIdentifier=cluster_identifier,
-            DbUser=db_user,
-            SecretArn=secret_arn,
-            StatementName=statement_name,
-            Parameters=parameters,
-            WithEvent=False,
-        )
-        mock_conn.describe_statement.assert_called_once_with(
-            Id=STATEMENT_ID,
+        mock_exec_query.assert_called_once_with(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            with_event=False,
+            wait_for_completion=await_result,
+            poll_interval=poll_interval,
         )
 
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@@ -111,32 +94,3 @@ class TestRedshiftDataOperator:
         mock_conn.cancel_statement.assert_called_once_with(
             Id=STATEMENT_ID,
         )
-
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    def test_batch_execute(self, mock_conn):
-        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
-        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
-        cluster_identifier = "cluster_identifier"
-        db_user = "db_user"
-        secret_arn = "secret_arn"
-        statement_name = "statement_name"
-        operator = RedshiftDataOperator(
-            task_id=TASK_ID,
-            cluster_identifier=cluster_identifier,
-            database=DATABASE,
-            db_user=db_user,
-            sql=[SQL],
-            statement_name=statement_name,
-            secret_arn=secret_arn,
-            aws_conn_id=CONN_ID,
-        )
-        operator.execute(None)
-        mock_conn.batch_execute_statement.assert_called_once_with(
-            Database=DATABASE,
-            Sqls=[SQL],
-            ClusterIdentifier=cluster_identifier,
-            DbUser=db_user,
-            SecretArn=secret_arn,
-            StatementName=statement_name,
-            WithEvent=False,
-        )
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index 91a1ae32da..67982821e8 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -17,11 +17,13 @@
 # under the License.
 from __future__ import annotations
 
+from copy import deepcopy
 from unittest import mock
 
 import pytest
 from boto3.session import Session
 
+from airflow.exceptions import AirflowException
 from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
@@ -284,3 +286,112 @@ class TestRedshiftToS3Transfer:
             "select_query",
             "redshift_conn_id",
         )
+
+    @pytest.mark.parametrize("param", ["sql", "parameters"])
+    def test_invalid_param_in_redshift_data_api_kwargs(self, param):
+        """
+        Test passing invalid param in RS Data API kwargs raises an error
+        """
+        with pytest.raises(AirflowException):
+            RedshiftToS3Operator(
+                s3_bucket="s3_bucket",
+                s3_key="s3_key",
+                select_query="select_query",
+                task_id="task_id",
+                dag=None,
+                redshift_data_api_kwargs={param: "param"},
+            )
+
+    @pytest.mark.parametrize("table_as_file_name, expected_s3_key", [[True, "key/table_"], [False, "key"]])
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_table_unloading_using_redshift_data_api(
+        self,
+        mock_rs,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+        table_as_file_name,
+        expected_s3_key,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
+        mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        # RS Data API params
+        database = "database"
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+
+        op = RedshiftToS3Operator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            table_as_file_name=table_as_file_name,
+            dag=None,
+            redshift_data_api_kwargs=dict(
+                database=database,
+                cluster_identifier=cluster_identifier,
+                db_user=db_user,
+                secret_arn=secret_arn,
+                statement_name=statement_name,
+            ),
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        select_query = f"SELECT * FROM {schema}.{table}"
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(
+            credentials_block, select_query, expected_s3_key, unload_options
+        )
+
+        mock_run.assert_not_called()
+        assert access_key in unload_query
+        assert secret_key in unload_query
+
+        mock_rs.execute_statement.assert_called_once()
+        # test with all args besides sql
+        _call = deepcopy(mock_rs.execute_statement.call_args[1])
+        _call.pop("Sql")
+        assert _call == dict(
+            Database=database,
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            WithEvent=False,
+        )
+        mock_rs.describe_statement.assert_called_once_with(
+            Id="STATEMENT_ID",
+        )
+        # test sql arg
+        assert_equal_ignore_multiple_spaces(self, mock_rs.execute_statement.call_args[1]["Sql"], unload_query)
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index e69673b27e..60b8bee249 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from copy import deepcopy
 from unittest import mock
 
 import pytest
@@ -348,3 +349,101 @@ class TestS3ToRedshiftTransfer:
                 task_id="task_id",
                 dag=None,
             )
+
+    @pytest.mark.parametrize("param", ["sql", "parameters"])
+    def test_invalid_param_in_redshift_data_api_kwargs(self, param):
+        """
+        Test passing invalid param in RS Data API kwargs raises an error
+        """
+        with pytest.raises(AirflowException):
+            S3ToRedshiftOperator(
+                schema="schema",
+                table="table",
+                s3_bucket="bucket",
+                s3_key="key",
+                task_id="task_id",
+                dag=None,
+                redshift_data_api_kwargs={param: "param"},
+            )
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_connection, mock_hook):
+        """
+        Using the Redshift Data API instead of the SQL-based connection
+        """
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
+        mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        # RS Data API params
+        database = "database"
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+            redshift_data_api_kwargs=dict(
+                database=database,
+                cluster_identifier=cluster_identifier,
+                db_user=db_user,
+                secret_arn=secret_arn,
+                statement_name=statement_name,
+            ),
+        )
+        op.execute(None)
+        copy_query = """
+                        COPY schema.table
+                        FROM 's3://bucket/key'
+                        credentials
+                        'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+                        ;
+                     """
+        mock_run.assert_not_called()
+        assert access_key in copy_query
+        assert secret_key in copy_query
+
+        mock_rs.execute_statement.assert_called_once()
+        # test with all args besides sql
+        _call = deepcopy(mock_rs.execute_statement.call_args[1])
+        _call.pop("Sql")
+        assert _call == dict(
+            Database=database,
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            WithEvent=False,
+        )
+        mock_rs.describe_statement.assert_called_once_with(
+            Id="STATEMENT_ID",
+        )
+        # test sql arg
+        assert_equal_ignore_multiple_spaces(self, mock_rs.execute_statement.call_args[1]["Sql"], copy_query)


Reply via email to