This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0604033829 Add Amazon Redshift-data to S3<>RS Transfer Operators (#27947)
0604033829 is described below
commit 0604033829787ebed59b9982bf08c1a68d93b120
Author: Josh Dimarsky <[email protected]>
AuthorDate: Mon Feb 20 00:19:33 2023 -0500
Add Amazon Redshift-data to S3<>RS Transfer Operators (#27947)
* refactored Amazon Redshift-data functionality into the hook
---------
Co-authored-by: eladkal <[email protected]>
---
.../providers/amazon/aws/hooks/redshift_data.py | 145 ++++++++++++++++++-
.../amazon/aws/operators/redshift_data.py | 122 +++++++++-------
.../amazon/aws/transfers/redshift_to_s3.py | 25 +++-
.../amazon/aws/transfers/s3_to_redshift.py | 31 +++-
.../amazon/aws/hooks/test_redshift_data.py | 160 ++++++++++++++++++++-
.../amazon/aws/operators/test_redshift_data.py | 80 +++--------
.../amazon/aws/transfers/test_redshift_to_s3.py | 111 ++++++++++++++
.../amazon/aws/transfers/test_s3_to_redshift.py | 99 +++++++++++++
8 files changed, 647 insertions(+), 126 deletions(-)
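The net effect of this change: the Redshift Data API calls (execute_statement, batch_execute_statement, and status polling) move out of RedshiftDataOperator and into RedshiftDataHook.execute_query / wait_for_results, and both S3<>Redshift transfer operators can now route their UNLOAD/COPY statements through that hook via a new redshift_data_api_kwargs argument. For illustration, a minimal sketch of the new hook-level entry point; the connection ID, cluster, and database names below are placeholders:

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

# Placeholder identifiers for illustration only.
hook = RedshiftDataHook(aws_conn_id="aws_default", region_name="us-east-1")

# Submit a statement without blocking, then poll for completion separately.
statement_id = hook.execute_query(
    database="dev",
    sql="SELECT 1",
    cluster_identifier="example-cluster",
    wait_for_completion=False,
)
hook.wait_for_results(statement_id, poll_interval=10)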
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 75efc9e822..e73c5a943a 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -17,9 +17,11 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING
+from time import sleep
+from typing import TYPE_CHECKING, Any, Iterable
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa
@@ -43,3 +45,144 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "redshift-data"
super().__init__(*args, **kwargs)
+
+ def execute_query(
+ self,
+ database: str,
+ sql: str | list[str],
+ cluster_identifier: str | None = None,
+ db_user: str | None = None,
+ parameters: Iterable | None = None,
+ secret_arn: str | None = None,
+ statement_name: str | None = None,
+ with_event: bool = False,
+ wait_for_completion: bool = True,
+ poll_interval: int = 10,
+ ) -> str:
+ """
+ Execute a statement against Amazon Redshift
+
+ :param database: the name of the database
+ :param sql: the SQL statement or list of SQL statements to run
+ :param cluster_identifier: unique identifier of a cluster
+ :param db_user: the database username
+ :param parameters: the parameters for the SQL statement
+ :param secret_arn: the name or ARN of the secret that enables db access
+ :param statement_name: the name of the SQL statement
+ :param with_event: indicates whether to send an event to EventBridge
+ :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
+ :param poll_interval: how often in seconds to check the query status
+
+ :returns statement_id: str, the UUID of the statement
+ """
+ kwargs: dict[str, Any] = {
+ "ClusterIdentifier": cluster_identifier,
+ "Database": database,
+ "DbUser": db_user,
+ "Parameters": parameters,
+ "WithEvent": with_event,
+ "SecretArn": secret_arn,
+ "StatementName": statement_name,
+ }
+ if isinstance(sql, list):
+ kwargs["Sqls"] = sql
+ resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
+ else:
+ kwargs["Sql"] = sql
+ resp = self.conn.execute_statement(**trim_none_values(kwargs))
+
+ statement_id = resp["Id"]
+
+ if wait_for_completion:
+ self.wait_for_results(statement_id, poll_interval=poll_interval)
+
+ return statement_id
+
+ def wait_for_results(self, statement_id, poll_interval):
+ while True:
+ self.log.info("Polling statement %s", statement_id)
+ resp = self.conn.describe_statement(
+ Id=statement_id,
+ )
+ status = resp["Status"]
+ if status == "FINISHED":
+ return status
+ elif status == "FAILED" or status == "ABORTED":
+ raise ValueError(
+ f"Statement {statement_id!r} terminated with status {status}, "
+ f"error msg: {resp.get('Error')}"
+ )
+ else:
+ self.log.info("Query %s", status)
+ sleep(poll_interval)
+
+ def get_table_primary_key(
+ self,
+ table: str,
+ database: str,
+ schema: str | None = "public",
+ cluster_identifier: str | None = None,
+ db_user: str | None = None,
+ secret_arn: str | None = None,
+ statement_name: str | None = None,
+ with_event: bool = False,
+ wait_for_completion: bool = True,
+ poll_interval: int = 10,
+ ) -> list[str] | None:
+ """
+ Helper method that returns the table primary key.
+
+ Copied from ``RedshiftSQLHook.get_table_primary_key()``
+
+ :param table: Name of the target table
+ :param database: the name of the database
+ :param schema: Name of the target schema, public by default
+ :param sql: the SQL statement or list of SQL statements to run
+ :param cluster_identifier: unique identifier of a cluster
+ :param db_user: the database username
+ :param secret_arn: the name or ARN of the secret that enables db access
+ :param statement_name: the name of the SQL statement
+ :param with_event: indicates whether to send an event to EventBridge
+ :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
+ :param poll_interval: how often in seconds to check the query status
+
+ :return: Primary key columns list
+ """
+ sql = f"""
+ select kcu.column_name
+ from information_schema.table_constraints tco
+ join information_schema.key_column_usage kcu
+ on kcu.constraint_name = tco.constraint_name
+ and kcu.constraint_schema = tco.constraint_schema
+ and kcu.constraint_name = tco.constraint_name
+ where tco.constraint_type = 'PRIMARY KEY'
+ and kcu.table_schema = {schema}
+ and kcu.table_name = {table}
+ """
+ stmt_id = self.execute_query(
+ sql=sql,
+ database=database,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ with_event=with_event,
+ wait_for_completion=wait_for_completion,
+ poll_interval=poll_interval,
+ )
+ pk_columns = []
+ token = ""
+ while True:
+ kwargs = dict(Id=stmt_id)
+ if token:
+ kwargs["NextToken"] = token
+ response = self.conn.get_statement_result(**kwargs)
+ # we only select a single column (that is a string),
+ # so safe to assume that there is only a single col in the record
+ pk_columns += [y["stringValue"] for x in response["Records"] for y in x]
+ if "NextToken" not in response.keys():
+ break
+ else:
+ token = response["NextToken"]
+
+ return pk_columns or None
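For reference, a sketch of how the new get_table_primary_key helper can be called; it issues the information_schema query above through execute_query and pages through get_statement_result using NextToken. The table, database, and cluster names are placeholders:

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")
# Returns the primary-key column names, or None when the table has no primary key.
pk_columns = hook.get_table_primary_key(
    table="my_table",
    database="dev",
    schema="public",
    cluster_identifier="example-cluster",
)
print(pk_columns)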
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 416ce8146e..9b5204b767 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,13 +17,12 @@
# under the License.
from __future__ import annotations
-from time import sleep
-from typing import TYPE_CHECKING, Any
+import warnings
+from typing import TYPE_CHECKING
from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
-from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -99,67 +98,80 @@ class RedshiftDataOperator(BaseOperator):
)
self.aws_conn_id = aws_conn_id
self.region = region
- self.statement_id = None
+ self.statement_id: str | None = None
@cached_property
def hook(self) -> RedshiftDataHook:
"""Create and return an RedshiftDataHook."""
return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
- def execute_query(self):
- kwargs: dict[str, Any] = {
- "ClusterIdentifier": self.cluster_identifier,
- "Database": self.database,
- "Sql": self.sql,
- "DbUser": self.db_user,
- "Parameters": self.parameters,
- "WithEvent": self.with_event,
- "SecretArn": self.secret_arn,
- "StatementName": self.statement_name,
- }
-
- resp = self.hook.conn.execute_statement(**trim_none_values(kwargs))
- return resp["Id"]
-
- def execute_batch_query(self):
- kwargs: dict[str, Any] = {
- "ClusterIdentifier": self.cluster_identifier,
- "Database": self.database,
- "Sqls": self.sql,
- "DbUser": self.db_user,
- "Parameters": self.parameters,
- "WithEvent": self.with_event,
- "SecretArn": self.secret_arn,
- "StatementName": self.statement_name,
- }
- resp = self.hook.conn.batch_execute_statement(**trim_none_values(kwargs))
- return resp["Id"]
-
- def wait_for_results(self, statement_id):
- while True:
- self.log.info("Polling statement %s", statement_id)
- resp = self.hook.conn.describe_statement(
- Id=statement_id,
- )
- status = resp["Status"]
- if status == "FINISHED":
- return status
- elif status == "FAILED" or status == "ABORTED":
- raise ValueError(f"Statement {statement_id!r} terminated with status {status}.")
- else:
- self.log.info("Query %s", status)
- sleep(self.poll_interval)
-
- def execute(self, context: Context) -> None:
+ def execute_query(self) -> str:
+ warnings.warn(
+ "This method is deprecated and has been moved to the hook "
+ "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.statement_id = self.hook.execute_query(
+ database=self.database,
+ sql=self.sql,
+ cluster_identifier=self.cluster_identifier,
+ db_user=self.db_user,
+ parameters=self.parameters,
+ secret_arn=self.secret_arn,
+ statement_name=self.statement_name,
+ with_event=self.with_event,
+ wait_for_completion=self.await_result,
+ poll_interval=self.poll_interval,
+ )
+ return self.statement_id
+
+ def execute_batch_query(self) -> str:
+ warnings.warn(
+ "This method is deprecated and has been moved to the hook "
+ "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.statement_id = self.hook.execute_query(
+ database=self.database,
+ sql=self.sql,
+ cluster_identifier=self.cluster_identifier,
+ db_user=self.db_user,
+ parameters=self.parameters,
+ secret_arn=self.secret_arn,
+ statement_name=self.statement_name,
+ with_event=self.with_event,
+ wait_for_completion=self.await_result,
+ poll_interval=self.poll_interval,
+ )
+ return self.statement_id
+
+ def wait_for_results(self, statement_id: str):
+ warnings.warn(
+ "This method is deprecated and has been moved to the hook "
+ "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.hook.wait_for_results(statement_id=statement_id, poll_interval=self.poll_interval)
+
+ def execute(self, context: Context) -> str:
"""Execute a statement against Amazon Redshift"""
self.log.info("Executing statement: %s", self.sql)
- if isinstance(self.sql, list):
- self.statement_id = self.execute_batch_query()
- else:
- self.statement_id = self.execute_query()
- if self.await_result:
- self.wait_for_results(self.statement_id)
+ self.statement_id = self.hook.execute_query(
+ database=self.database,
+ sql=self.sql,
+ cluster_identifier=self.cluster_identifier,
+ db_user=self.db_user,
+ parameters=self.parameters,
+ secret_arn=self.secret_arn,
+ statement_name=self.statement_name,
+ with_event=self.with_event,
+ wait_for_completion=self.await_result,
+ poll_interval=self.poll_interval,
+ )
return self.statement_id
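After this change the operator forwards everything to RedshiftDataHook.execute_query (await_result becomes the hook's wait_for_completion), while its own execute_query, execute_batch_query, and wait_for_results methods emit a DeprecationWarning and delegate to the hook. A hedged usage sketch, assuming a current Airflow 2.x DAG API and placeholder identifiers:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG(dag_id="example_redshift_data", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    run_sql = RedshiftDataOperator(
        task_id="run_sql",
        aws_conn_id="aws_default",
        cluster_identifier="example-cluster",  # placeholder cluster
        database="dev",                        # placeholder database
        db_user="awsuser",                     # placeholder user
        sql="SELECT count(*) FROM public.my_table;",
        await_result=True,   # forwarded to the hook as wait_for_completion
        poll_interval=10,
    )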
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index ed57df0984..45f1e1783c 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -20,7 +20,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
@@ -67,6 +69,9 @@ class RedshiftToS3Operator(BaseOperator):
:param parameters: (optional) the parameters to render the SQL query with.
:param table_as_file_name: If set to True, the s3 file will be named as the table.
Applicable when ``table`` param provided.
+ :param redshift_data_api_kwargs: If using the Redshift Data API instead of the SQL-based connection,
+ dict of arguments for the hook's ``execute_query`` method.
+ Cannot include any of these kwargs: ``{'sql', 'parameters'}``
"""
template_fields: Sequence[str] = (
@@ -98,6 +103,7 @@ class RedshiftToS3Operator(BaseOperator):
include_header: bool = False,
parameters: Iterable | Mapping | None = None,
table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
+ redshift_data_api_kwargs: dict = {},
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -113,6 +119,7 @@ class RedshiftToS3Operator(BaseOperator):
self.include_header = include_header
self.parameters = parameters
self.table_as_file_name = table_as_file_name
+ self.redshift_data_api_kwargs = redshift_data_api_kwargs
if select_query:
self.select_query = select_query
@@ -128,6 +135,11 @@ class RedshiftToS3Operator(BaseOperator):
"HEADER",
]
+ if self.redshift_data_api_kwargs:
+ for arg in ["sql", "parameters"]:
+ if arg in self.redshift_data_api_kwargs.keys():
+ raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
+
def _build_unload_query(
self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
) -> str:
@@ -140,7 +152,11 @@ class RedshiftToS3Operator(BaseOperator):
"""
def execute(self, context: Context) -> None:
- redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+ redshift_hook: RedshiftDataHook | RedshiftSQLHook
+ if self.redshift_data_api_kwargs:
+ redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+ else:
+ redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
if conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
@@ -156,5 +172,10 @@ class RedshiftToS3Operator(BaseOperator):
)
self.log.info("Executing UNLOAD command...")
- redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters)
+ if isinstance(redshift_hook, RedshiftDataHook):
+ redshift_hook.execute_query(
+ sql=unload_query, parameters=self.parameters, **self.redshift_data_api_kwargs
+ )
+ else:
+ redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters)
self.log.info("UNLOAD command complete...")
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 1dcb6e7648..dc99939021 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
@@ -43,7 +44,7 @@ class S3ToRedshiftOperator(BaseOperator):
:param table: reference to a specific table in redshift database
:param s3_bucket: reference to a specific S3 bucket
:param s3_key: key prefix that selects single or multiple objects from S3
- :param redshift_conn_id: reference to a specific redshift database
+ :param redshift_conn_id: reference to a specific redshift database OR a redshift data-api connection
:param aws_conn_id: reference to a specific S3 connection
If the AWS connection contains 'aws_iam_role' in ``extras``
the operator will use AWS STS credentials with a token
@@ -62,6 +63,9 @@ class S3ToRedshiftOperator(BaseOperator):
:param copy_options: reference to a list of COPY options
:param method: Action to be performed on execution. Available ``APPEND``, ``UPSERT`` and ``REPLACE``.
:param upsert_keys: List of fields to use as key on upsert action
+ :param redshift_data_api_kwargs: If using the Redshift Data API instead of the SQL-based connection,
+ dict of arguments for the hook's ``execute_query`` method.
+ Cannot include any of these kwargs: ``{'sql', 'parameters'}``
"""
template_fields: Sequence[str] = (
@@ -91,6 +95,7 @@ class S3ToRedshiftOperator(BaseOperator):
autocommit: bool = False,
method: str = "APPEND",
upsert_keys: list[str] | None = None,
+ redshift_data_api_kwargs: dict = {},
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -106,10 +111,16 @@ class S3ToRedshiftOperator(BaseOperator):
self.autocommit = autocommit
self.method = method
self.upsert_keys = upsert_keys
+ self.redshift_data_api_kwargs = redshift_data_api_kwargs
if self.method not in AVAILABLE_METHODS:
raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}")
+ if self.redshift_data_api_kwargs:
+ for arg in ["sql", "parameters"]:
+ if arg in self.redshift_data_api_kwargs.keys():
+ raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
+
def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_options: str) -> str:
column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else ""
return f"""
@@ -121,7 +132,11 @@ class S3ToRedshiftOperator(BaseOperator):
"""
def execute(self, context: Context) -> None:
- redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+ redshift_hook: RedshiftDataHook | RedshiftSQLHook
+ if self.redshift_data_api_kwargs:
+ redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+ else:
+ redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
if conn.extra_dejson.get("role_arn", False):
@@ -142,7 +157,12 @@ class S3ToRedshiftOperator(BaseOperator):
if self.method == "REPLACE":
sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
elif self.method == "UPSERT":
- keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema)
+ if isinstance(redshift_hook, RedshiftDataHook):
+ keys = self.upsert_keys or redshift_hook.get_table_primary_key(
+ table=self.table, schema=self.schema, **self.redshift_data_api_kwargs
+ )
+ else:
+ keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema)
if not keys:
raise AirflowException(
f"No primary key on {self.schema}.{self.table}. Please
provide keys on 'upsert_keys'"
@@ -162,5 +182,8 @@ class S3ToRedshiftOperator(BaseOperator):
sql = copy_statement
self.log.info("Executing COPY command...")
- redshift_hook.run(sql, autocommit=self.autocommit)
+ if isinstance(redshift_hook, RedshiftDataHook):
+ redshift_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs)
+ else:
+ redshift_hook.run(sql, autocommit=self.autocommit)
self.log.info("COPY command complete...")
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index f204703170..29816442c4 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -17,14 +17,172 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+CONN_ID = "aws_conn_test"
+SQL = "sql"
+DATABASE = "database"
+STATEMENT_ID = "statement_id"
+
class TestRedshiftDataHook:
def test_conn_attribute(self):
- hook = RedshiftDataHook(aws_conn_id="aws_default", region_name="us-east-1")
+ hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
assert hasattr(hook, "conn")
assert hook.conn.__class__.__name__ == "RedshiftDataAPIService"
conn = hook.conn
assert conn is hook.conn # Cached property
assert conn is hook.get_conn() # Same object as returned by `conn` property
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_without_waiting(self, mock_conn):
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+
+ hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+ hook.execute_query(
+ database=DATABASE,
+ sql=SQL,
+ wait_for_completion=False,
+ )
+ mock_conn.execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sql=SQL,
+ WithEvent=False,
+ )
+ mock_conn.describe_statement.assert_not_called()
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_with_all_parameters(self, mock_conn):
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
+ hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+ hook.execute_query(
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ )
+
+ mock_conn.execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sql=SQL,
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ Parameters=parameters,
+ WithEvent=False,
+ )
+ mock_conn.describe_statement.assert_called_once_with(
+ Id=STATEMENT_ID,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_batch_execute(self, mock_conn):
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+
+ hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+ hook.execute_query(
+ cluster_identifier=cluster_identifier,
+ database=DATABASE,
+ db_user=db_user,
+ sql=[SQL],
+ statement_name=statement_name,
+ secret_arn=secret_arn,
+ )
+
+ mock_conn.batch_execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sqls=[SQL],
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ WithEvent=False,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_get_table_primary_key_no_token(self, mock_conn):
+ table = "table"
+ schema = "schema"
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ mock_conn.get_statement_result.return_value = {
+ "Records": [[{"stringValue": "string"}]],
+ }
+
+ hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+
+ hook.get_table_primary_key(
+ table=table,
+ database=DATABASE,
+ schema=schema,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ )
+
+ mock_conn.get_statement_result.assert_called_once_with(Id=STATEMENT_ID)
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_get_table_primary_key_with_token(self, mock_conn):
+ table = "table"
+ schema = "schema"
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ mock_conn.get_statement_result.side_effect = [
+ {
+ "NextToken": "token1",
+ "Records": [[{"stringValue": "string1"}]],
+ },
+ {
+ "NextToken": "token2",
+ "Records": [[{"stringValue": "string2"}]],
+ },
+ {
+ "Records": [[{"stringValue": "string3"}]],
+ },
+ ]
+
+ hook = RedshiftDataHook(aws_conn_id=CONN_ID, region_name="us-east-1")
+
+ hook.get_table_primary_key(
+ table=table,
+ database=DATABASE,
+ schema=schema,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ )
+
+ assert mock_conn.get_statement_result.call_args_list == [
+ (dict(Id=STATEMENT_ID),),
+ (dict(Id=STATEMENT_ID, NextToken="token1"),),
+ (dict(Id=STATEMENT_ID, NextToken="token2"),),
+ ]
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index fa527ffc09..efa88cb026 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -29,33 +29,15 @@ STATEMENT_ID = "statement_id"
class TestRedshiftDataOperator:
- @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- def test_execute_without_waiting(self, mock_conn):
- mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
- operator = RedshiftDataOperator(
- aws_conn_id=CONN_ID,
- task_id=TASK_ID,
- sql=SQL,
- database=DATABASE,
- await_result=False,
- )
- operator.execute(None)
- mock_conn.execute_statement.assert_called_once_with(
- Database=DATABASE,
- Sql=SQL,
- WithEvent=False,
- )
- mock_conn.describe_statement.assert_not_called()
-
- @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- def test_execute_with_all_parameters(self, mock_conn):
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+ def test_execute(self, mock_exec_query):
cluster_identifier = "cluster_identifier"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
- mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
- mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ poll_interval = 5
+ await_result = True
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
@@ -67,20 +49,21 @@ class TestRedshiftDataOperator:
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
+ await_result=True,
+ poll_interval=poll_interval,
)
operator.execute(None)
- mock_conn.execute_statement.assert_called_once_with(
- Database=DATABASE,
- Sql=SQL,
- ClusterIdentifier=cluster_identifier,
- DbUser=db_user,
- SecretArn=secret_arn,
- StatementName=statement_name,
- Parameters=parameters,
- WithEvent=False,
- )
- mock_conn.describe_statement.assert_called_once_with(
- Id=STATEMENT_ID,
+ mock_exec_query.assert_called_once_with(
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ with_event=False,
+ wait_for_completion=await_result,
+ poll_interval=poll_interval,
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@@ -111,32 +94,3 @@ class TestRedshiftDataOperator:
mock_conn.cancel_statement.assert_called_once_with(
Id=STATEMENT_ID,
)
-
- @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- def test_batch_execute(self, mock_conn):
- mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
- mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
- cluster_identifier = "cluster_identifier"
- db_user = "db_user"
- secret_arn = "secret_arn"
- statement_name = "statement_name"
- operator = RedshiftDataOperator(
- task_id=TASK_ID,
- cluster_identifier=cluster_identifier,
- database=DATABASE,
- db_user=db_user,
- sql=[SQL],
- statement_name=statement_name,
- secret_arn=secret_arn,
- aws_conn_id=CONN_ID,
- )
- operator.execute(None)
- mock_conn.batch_execute_statement.assert_called_once_with(
- Database=DATABASE,
- Sqls=[SQL],
- ClusterIdentifier=cluster_identifier,
- DbUser=db_user,
- SecretArn=secret_arn,
- StatementName=statement_name,
- WithEvent=False,
- )
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index 91a1ae32da..67982821e8 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -17,11 +17,13 @@
# under the License.
from __future__ import annotations
+from copy import deepcopy
from unittest import mock
import pytest
from boto3.session import Session
+from airflow.exceptions import AirflowException
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
@@ -284,3 +286,112 @@ class TestRedshiftToS3Transfer:
"select_query",
"redshift_conn_id",
)
+
+ @pytest.mark.parametrize("param", ["sql", "parameters"])
+ def test_invalid_param_in_redshift_data_api_kwargs(self, param):
+ """
+ Test passing invalid param in RS Data API kwargs raises an error
+ """
+ with pytest.raises(AirflowException):
+ RedshiftToS3Operator(
+ s3_bucket="s3_bucket",
+ s3_key="s3_key",
+ select_query="select_query",
+ task_id="task_id",
+ dag=None,
+ redshift_data_api_kwargs={param: "param"},
+ )
+
+ @pytest.mark.parametrize("table_as_file_name, expected_s3_key", [[True, "key/table_"], [False, "key"]])
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_table_unloading_using_redshift_data_api(
+ self,
+ mock_rs,
+ mock_run,
+ mock_session,
+ mock_connection,
+ mock_hook,
+ table_as_file_name,
+ expected_s3_key,
+ ):
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ mock_session.return_value = Session(access_key, secret_key)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+ mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
+ mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ unload_options = [
+ "HEADER",
+ ]
+ # RS Data API params
+ database = "database"
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+
+ op = RedshiftToS3Operator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ unload_options=unload_options,
+ include_header=True,
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ table_as_file_name=table_as_file_name,
+ dag=None,
+ redshift_data_api_kwargs=dict(
+ database=database,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ ),
+ )
+
+ op.execute(None)
+
+ unload_options = "\n\t\t\t".join(unload_options)
+ select_query = f"SELECT * FROM {schema}.{table}"
+ credentials_block = build_credentials_block(mock_session.return_value)
+
+ unload_query = op._build_unload_query(
+ credentials_block, select_query, expected_s3_key, unload_options
+ )
+
+ mock_run.assert_not_called()
+ assert access_key in unload_query
+ assert secret_key in unload_query
+
+ mock_rs.execute_statement.assert_called_once()
+ # test with all args besides sql
+ _call = deepcopy(mock_rs.execute_statement.call_args[1])
+ _call.pop("Sql")
+ assert _call == dict(
+ Database=database,
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ WithEvent=False,
+ )
+ mock_rs.describe_statement.assert_called_once_with(
+ Id="STATEMENT_ID",
+ )
+ # test sql arg
+ assert_equal_ignore_multiple_spaces(self, mock_rs.execute_statement.call_args[1]["Sql"], unload_query)
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index e69673b27e..60b8bee249 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from copy import deepcopy
from unittest import mock
import pytest
@@ -348,3 +349,101 @@ class TestS3ToRedshiftTransfer:
task_id="task_id",
dag=None,
)
+
+ @pytest.mark.parametrize("param", ["sql", "parameters"])
+ def test_invalid_param_in_redshift_data_api_kwargs(self, param):
+ """
+ Test passing invalid param in RS Data API kwargs raises an error
+ """
+ with pytest.raises(AirflowException):
+ S3ToRedshiftOperator(
+ schema="schema",
+ table="table",
+ s3_bucket="bucket",
+ s3_key="key",
+ task_id="task_id",
+ dag=None,
+ redshift_data_api_kwargs={param: "param"},
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_connection, mock_hook):
+ """
+ Using the Redshift Data API instead of the SQL-based connection
+ """
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ mock_session.return_value = Session(access_key, secret_key)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = None
+
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+ mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
+ mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ copy_options = ""
+
+ # RS Data API params
+ database = "database"
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+
+ op = S3ToRedshiftOperator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ copy_options=copy_options,
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ dag=None,
+ redshift_data_api_kwargs=dict(
+ database=database,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ ),
+ )
+ op.execute(None)
+ copy_query = """
+ COPY schema.table
+ FROM 's3://bucket/key'
+ credentials
+ 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+ ;
+ """
+ mock_run.assert_not_called()
+ assert access_key in copy_query
+ assert secret_key in copy_query
+
+ mock_rs.execute_statement.assert_called_once()
+ # test with all args besides sql
+ _call = deepcopy(mock_rs.execute_statement.call_args[1])
+ _call.pop("Sql")
+ assert _call == dict(
+ Database=database,
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ WithEvent=False,
+ )
+ mock_rs.describe_statement.assert_called_once_with(
+ Id="STATEMENT_ID",
+ )
+ # test sql arg
+ assert_equal_ignore_multiple_spaces(self, mock_rs.execute_statement.call_args[1]["Sql"], copy_query)