This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8580e6d046 Support session reuse in `RedshiftDataOperator` (#42218)
8580e6d046 is described below
commit 8580e6d046b11d159e3260ec4015981387e94a57
Author: Boris Morel <[email protected]>
AuthorDate: Wed Sep 25 02:22:54 2024 +0800
Support session reuse in `RedshiftDataOperator` (#42218)
---
airflow/providers/amazon/CHANGELOG.rst | 27 +++++
.../providers/amazon/aws/hooks/redshift_data.py | 66 ++++++++---
.../amazon/aws/operators/redshift_data.py | 21 +++-
.../amazon/aws/transfers/redshift_to_s3.py | 25 ++--
.../amazon/aws/transfers/s3_to_redshift.py | 12 +-
airflow/providers/amazon/aws/utils/openlineage.py | 4 +-
.../operators/redshift/redshift_data.rst | 12 ++
.../amazon/aws/hooks/test_redshift_data.py | 127 ++++++++++++++++++++-
.../amazon/aws/operators/test_redshift_data.py | 86 ++++++++++++--
.../providers/amazon/aws/utils/test_openlineage.py | 4 +-
.../providers/amazon/aws/example_redshift.py | 41 ++++++-
.../amazon/aws/example_redshift_s3_transfers.py | 109 +++++++++++++++---
12 files changed, 468 insertions(+), 66 deletions(-)
diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst
index 126da03ad6..7596ad3886 100644
--- a/airflow/providers/amazon/CHANGELOG.rst
+++ b/airflow/providers/amazon/CHANGELOG.rst
@@ -26,6 +26,33 @@
Changelog
---------
+Main
+......
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+.. warning::
+ In order to support session reuse in RedshiftData operators, the following breaking changes were introduced:
+
+ The ``database`` argument is now optional and as a result was moved after the ``sql`` argument which is a positional
+ one. Update your DAGs accordingly if they rely on argument order. Applies to:
+ * ``RedshiftDataHook``'s ``execute_query`` method
+ * ``RedshiftDataOperator``
+
+ ``RedshiftDataHook``'s ``execute_query`` method now returns a ``QueryExecutionOutput`` object instead of just the
+ statement ID as a string.
+
+ ``RedshiftDataHook``'s ``parse_statement_resposne`` method was renamed to ``parse_statement_response``.
+
+ ``S3ToRedshiftOperator``'s ``schema`` argument is now optional and was moved after the ``s3_key`` positional argument.
+ Update your DAGs accordingly if they rely on argument order.
+
+Features
+~~~~~~~~
+
+* ``Support session reuse in RedshiftDataOperator, RedshiftToS3Operator and S3ToRedshiftOperator (#42218)``
+
8.29.0
......
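To make the argument reorder concrete, here is a minimal sketch (task ids, cluster, bucket and SQL are hypothetical placeholders, not taken from this commit) of the keyword-argument form that is unaffected by the new parameter order:

    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    run_query = RedshiftDataOperator(
        task_id="run_query",
        sql="SELECT 1;",                 # sql is now the first positional parameter
        database="dev",                  # database is now optional, pass it by keyword
        cluster_identifier="my-cluster",
        db_user="awsuser",
    )

    load_table = S3ToRedshiftOperator(
        task_id="load_table",
        table="my_table",
        s3_bucket="my-bucket",
        s3_key="my_key",
        schema="public",                 # schema is now optional and follows s3_key
    )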
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 3c1f84b1f6..b2f46c0ef6 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -18,8 +18,12 @@
from __future__ import annotations
import time
+from dataclasses import dataclass
from pprint import pformat
from typing import TYPE_CHECKING, Any, Iterable
+from uuid import UUID
+
+from pendulum import duration
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values
@@ -35,6 +39,14 @@ FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
+@dataclass
+class QueryExecutionOutput:
+ """Describes the output of a query execution."""
+
+ statement_id: str
+ session_id: str | None
+
+
class RedshiftDataQueryFailedError(ValueError):
"""Raise an error that redshift data query failed."""
@@ -65,8 +77,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
def execute_query(
self,
- database: str,
sql: str | list[str],
+ database: str | None = None,
cluster_identifier: str | None = None,
db_user: str | None = None,
parameters: Iterable | None = None,
@@ -76,23 +88,28 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
wait_for_completion: bool = True,
poll_interval: int = 10,
workgroup_name: str | None = None,
- ) -> str:
+ session_id: str | None = None,
+ session_keep_alive_seconds: int | None = None,
+ ) -> QueryExecutionOutput:
"""
Execute a statement against Amazon Redshift.
- :param database: the name of the database
:param sql: the SQL statement or list of SQL statement to run
+ :param database: the name of the database
:param cluster_identifier: unique identifier of a cluster
:param db_user: the database username
:param parameters: the parameters for the SQL statement
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
- :param with_event: indicates whether to send an event to EventBridge
- :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
+ :param with_event: whether to send an event to EventBridge
+ :param wait_for_completion: whether to wait for a result
:param poll_interval: how often in seconds to check the query status
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
+ :param session_id: the session identifier of the query
+ :param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
+ finishes. The maximum time a session can keep alive is 24 hours
:returns statement_id: str, the UUID of the statement
"""
@@ -105,7 +122,28 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
"SecretArn": secret_arn,
"StatementName": statement_name,
"WorkgroupName": workgroup_name,
+ "SessionId": session_id,
+ "SessionKeepAliveSeconds": session_keep_alive_seconds,
}
+
+ if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1:
+ raise ValueError(
+ "Exactly one of cluster_identifier, workgroup_name, or
session_id must be provided"
+ )
+
+ if session_id is not None:
+ msg = "session_id must be a valid UUID4"
+ try:
+ if UUID(session_id).version != 4:
+ raise ValueError(msg)
+ except ValueError:
+ raise ValueError(msg)
+
+ if session_keep_alive_seconds is not None and (
+ session_keep_alive_seconds < 0 or duration(seconds=session_keep_alive_seconds).hours > 24
+ ):
+ raise ValueError("Session keep alive duration must be between 0
and 86400 seconds.")
+
if isinstance(sql, list):
kwargs["Sqls"] = sql
resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
@@ -115,13 +153,10 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
statement_id = resp["Id"]
- if bool(cluster_identifier) is bool(workgroup_name):
- raise ValueError("Either 'cluster_identifier' or 'workgroup_name'
must be specified.")
-
if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)
- return statement_id
+ return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId"))
def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
while True:
@@ -135,9 +170,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
def check_query_is_finished(self, statement_id: str) -> bool:
"""Check whether query finished, raise exception is failed."""
resp = self.conn.describe_statement(Id=statement_id)
- return self.parse_statement_resposne(resp)
+ return self.parse_statement_response(resp)
- def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool:
+ def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
"""Parse the response of describe_statement."""
status = resp["Status"]
if status == FINISHED_STATE:
@@ -179,8 +214,10 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
:param table: Name of the target table
:param database: the name of the database
:param schema: Name of the target schema, public by default
- :param sql: the SQL statement or list of SQL statement to run
:param cluster_identifier: unique identifier of a cluster
+ :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+ `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
+ https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param db_user: the database username
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
@@ -212,7 +249,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
with_event=with_event,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
- )
+ ).statement_id
+
pk_columns = []
token = ""
while True:
@@ -251,4 +289,4 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
"""
async with self.async_conn as client:
resp = await client.describe_statement(Id=statement_id)
- return self.parse_statement_resposne(resp)
+ return self.parse_statement_response(resp)
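A hedged sketch of calling the updated hook directly (cluster name, database, user and SQL are hypothetical placeholders): the return value is now a QueryExecutionOutput rather than a bare statement ID, and exactly one of cluster_identifier, workgroup_name or session_id must be supplied:

    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

    hook = RedshiftDataHook()
    output = hook.execute_query(
        sql="SELECT 1;",
        database="dev",
        cluster_identifier="my-cluster",
        db_user="awsuser",
        session_keep_alive_seconds=600,   # keep the session open after the query finishes
    )
    print(output.statement_id)            # formerly the plain string return value
    print(output.session_id)              # pass this as session_id to reuse the session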
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 45fee2a919..3d00c6d22e 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -56,13 +56,16 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
+ :param session_id: the session identifier of the query
+ :param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
+ finishes. The maximum time a session can keep alive is 24 hours
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
- :param verify: Whether or not to verify SSL certificates. See:
+ :param verify: Whether to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
@@ -77,6 +80,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
"parameters",
"statement_name",
"workgroup_name",
+ "session_id",
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
@@ -84,8 +88,8 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
def __init__(
self,
- database: str,
sql: str | list,
+ database: str | None = None,
cluster_identifier: str | None = None,
db_user: str | None = None,
parameters: list | None = None,
@@ -97,6 +101,8 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
return_sql_result: bool = False,
workgroup_name: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ session_id: str | None = None,
+ session_keep_alive_seconds: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -120,6 +126,8 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
self.deferrable = deferrable
+ self.session_id = session_id
+ self.session_keep_alive_seconds = session_keep_alive_seconds
def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift."""
@@ -130,7 +138,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
if self.deferrable:
wait_for_completion = False
- self.statement_id = self.hook.execute_query(
+ query_execution_output = self.hook.execute_query(
database=self.database,
sql=self.sql,
cluster_identifier=self.cluster_identifier,
@@ -142,8 +150,15 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
with_event=self.with_event,
wait_for_completion=wait_for_completion,
poll_interval=self.poll_interval,
+ session_id=self.session_id,
+ session_keep_alive_seconds=self.session_keep_alive_seconds,
)
+ self.statement_id = query_execution_output.statement_id
+
+ if query_execution_output.session_id:
+ self.xcom_push(context, key="session_id", value=query_execution_output.session_id)
+
if self.deferrable and self.wait_for_completion:
is_finished = self.hook.check_query_is_finished(self.statement_id)
if not is_finished:
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index ef3cebdae9..8538b1dfc3 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -45,7 +45,8 @@ class RedshiftToS3Operator(BaseOperator):
:param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
to False, this param must include the desired file name
:param schema: reference to a specific schema in redshift database,
- used when ``table`` param provided and ``select_query`` param not provided
+ used when ``table`` param provided and ``select_query`` param not provided.
+ Do not provide when unloading a temporary table
:param table: reference to a specific table in redshift database,
used when ``schema`` param provided and ``select_query`` param not provided
:param select_query: custom select query to fetch data from redshift database,
@@ -55,8 +56,8 @@ class RedshiftToS3Operator(BaseOperator):
If the AWS connection contains 'aws_iam_role' in ``extras``
the operator will use AWS STS credentials with a token
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
+ :param verify: Whether to verify SSL certificates for S3 connection.
+ By default, SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
@@ -67,7 +68,7 @@ class RedshiftToS3Operator(BaseOperator):
CA cert bundle than the one used by botocore.
:param unload_options: reference to a list of UNLOAD options
:param autocommit: If set to True it will automatically commit the UNLOAD statement.
- Otherwise it will be committed right before the redshift connection gets closed.
+ Otherwise, it will be committed right before the redshift connection gets closed.
:param include_header: If set to True the s3 file contains the header columns.
:param parameters: (optional) the parameters to render the SQL query with.
:param table_as_file_name: If set to True, the s3 file will be named as the table.
@@ -141,9 +142,15 @@ class RedshiftToS3Operator(BaseOperator):
@property
def default_select_query(self) -> str | None:
- if self.schema and self.table:
- return f"SELECT * FROM {self.schema}.{self.table}"
- return None
+ if not self.table:
+ return None
+
+ if self.schema:
+ table = f"{self.schema}.{self.table}"
+ else:
+ # Relevant when unloading a temporary table
+ table = self.table
+ return f"SELECT * FROM {table}"
def execute(self, context: Context) -> None:
if self.table and self.table_as_file_name:
@@ -152,9 +159,7 @@ class RedshiftToS3Operator(BaseOperator):
self.select_query = self.select_query or self.default_select_query
if self.select_query is None:
- raise ValueError(
- "Please provide both `schema` and `table` params or
`select_query` to fetch the data."
- )
+ raise ValueError("Please specify either a table or `select_query`
to fetch the data.")
if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = [*self.unload_options, "HEADER"]
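Because ``schema`` may now be omitted, a temporary table created in a reused Data API session can be unloaded directly. A minimal sketch (bucket, key and task ids are hypothetical placeholders), mirroring the system test added later in this commit:

    from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

    unload_tmp_table = RedshiftToS3Operator(
        task_id="unload_tmp_table",
        s3_bucket="my-bucket",
        s3_key="tmp_table_output_",
        table="tmp_people",                # temporary table, so no schema= is passed
        redshift_data_api_kwargs={
            "wait_for_completion": True,
            # reuse the session opened by an upstream "create_tmp_table" task
            "session_id": "{{ task_instance.xcom_pull(task_ids='create_tmp_table', key='session_id') }}",
        },
    )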
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 127ee07a60..792119bfeb 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -28,7 +28,6 @@ from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
if TYPE_CHECKING:
from airflow.utils.context import Context
-
AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"]
@@ -40,17 +39,18 @@ class S3ToRedshiftOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:S3ToRedshiftOperator`
- :param schema: reference to a specific schema in redshift database
:param table: reference to a specific table in redshift database
:param s3_bucket: reference to a specific S3 bucket
:param s3_key: key prefix that selects single or multiple objects from S3
+ :param schema: reference to a specific schema in redshift database.
+ Do not provide when copying into a temporary table
:param redshift_conn_id: reference to a specific redshift database OR a redshift data-api connection
:param aws_conn_id: reference to a specific S3 connection
If the AWS connection contains 'aws_iam_role' in ``extras``
the operator will use AWS STS credentials with a token
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
+ :param verify: Whether to verify SSL certificates for S3 connection.
+ By default, SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
@@ -87,10 +87,10 @@ class S3ToRedshiftOperator(BaseOperator):
def __init__(
self,
*,
- schema: str,
table: str,
s3_bucket: str,
s3_key: str,
+ schema: str | None = None,
redshift_conn_id: str = "redshift_default",
aws_conn_id: str | None = "aws_default",
verify: bool | str | None = None,
@@ -160,7 +160,7 @@ class S3ToRedshiftOperator(BaseOperator):
credentials_block = build_credentials_block(credentials)
copy_options = "\n\t\t\t".join(self.copy_options)
- destination = f"{self.schema}.{self.table}"
+ destination = f"{self.schema}.{self.table}" if self.schema else self.table
copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination
copy_statement = self._build_copy_query(
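The same pattern works in the other direction: with ``schema`` omitted, COPY targets a temporary table inside the reused session. A minimal sketch (names are hypothetical placeholders), analogous to the system test added later in this commit:

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    copy_into_tmp_table = S3ToRedshiftOperator(
        task_id="copy_into_tmp_table",
        s3_bucket="my-bucket",
        s3_key="my_key",
        table="tmp_table",                 # temporary table, schema is omitted
        copy_options=["csv"],
        redshift_data_api_kwargs={
            "wait_for_completion": True,
            # reuse the session opened by an upstream "create_dest_tmp_table" task
            "session_id": "{{ task_instance.xcom_pull(task_ids='create_dest_tmp_table', key='session_id') }}",
        },
    )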
diff --git a/airflow/providers/amazon/aws/utils/openlineage.py b/airflow/providers/amazon/aws/utils/openlineage.py
index db472a3e46..be5703e2f6 100644
--- a/airflow/providers/amazon/aws/utils/openlineage.py
+++ b/airflow/providers/amazon/aws/utils/openlineage.py
@@ -86,7 +86,9 @@ def get_facets_from_redshift_table(
]
)
else:
- statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs)
+ statement_id = redshift_hook.execute_query(
+ sql=sql, poll_interval=1, **redshift_data_api_kwargs
+ ).statement_id
response = redshift_hook.conn.get_statement_result(Id=statement_id)
table_schema = SchemaDatasetFacet(
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
index 0b314d34f3..2638e1732c 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
@@ -54,6 +54,18 @@ the necessity of a Postgres connection.
:start-after: [START howto_operator_redshift_data]
:end-before: [END howto_operator_redshift_data]
+Reuse a session when executing multiple statements
+==================================================
+
+Specify the ``session_keep_alive_seconds`` parameter on an upstream task. In a downstream task, get the session ID from
+the XCom and pass it to the ``session_id`` parameter. This is useful when you work with temporary tables.
+
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_redshift_data_session_reuse]
+ :end-before: [END howto_operator_redshift_data_session_reuse]
+
Reference
---------
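Condensed from the example DAG referenced above, a minimal sketch of the documented pattern (cluster, database and user names are placeholders): the upstream task keeps its session alive and pushes the session ID to XCom, and the downstream task pulls it into ``session_id``:

    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    create_tmp_table = RedshiftDataOperator(
        task_id="create_tmp_table",
        cluster_identifier="my-cluster",
        database="dev",
        db_user="awsuser",
        sql="CREATE TEMPORARY TABLE tmp_people (id INTEGER, first_name VARCHAR(100), age INTEGER);",
        session_keep_alive_seconds=600,    # keep the session open for the downstream task
    )

    insert_data = RedshiftDataOperator(
        task_id="insert_data",
        sql="INSERT INTO tmp_people VALUES (1, 'Bob', 30);",
        # the upstream task pushes its session id to XCom under the "session_id" key
        session_id="{{ task_instance.xcom_pull(task_ids='create_tmp_table', key='session_id') }}",
    )

    create_tmp_table >> insert_data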
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index a0952e5ba7..d548086449 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import logging
from unittest import mock
+from uuid import uuid4
import pytest
@@ -63,15 +64,18 @@ class TestRedshiftDataHook:
mock_conn.describe_statement.assert_not_called()
@pytest.mark.parametrize(
- "cluster_identifier, workgroup_name",
+ "cluster_identifier, workgroup_name, session_id",
[
- (None, None),
- ("some_cluster", "some_workgroup"),
+ (None, None, None),
+ ("some_cluster", "some_workgroup", None),
+ (None, "some_workgroup", None),
+ ("some_cluster", None, None),
+ (None, None, "some_session_id"),
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- def test_execute_requires_either_cluster_identifier_or_workgroup_name(
- self, mock_conn, cluster_identifier, workgroup_name
+ def test_execute_requires_one_of_cluster_identifier_or_workgroup_name_or_session_id(
+ self, mock_conn, cluster_identifier, workgroup_name, session_id
):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
cluster_identifier = "cluster_identifier"
@@ -84,6 +88,51 @@ class TestRedshiftDataHook:
workgroup_name=workgroup_name,
sql=SQL,
wait_for_completion=False,
+ session_id=session_id,
+ )
+
+ @pytest.mark.parametrize(
+ "cluster_identifier, workgroup_name, session_id",
+ [
+ (None, None, None),
+ ("some_cluster", "some_workgroup", None),
+ (None, "some_workgroup", None),
+ ("some_cluster", None, None),
+ (None, None, "some_session_id"),
+ ],
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_session_keep_alive_seconds_valid(
+ self, mock_conn, cluster_identifier, workgroup_name, session_id
+ ):
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ cluster_identifier = "cluster_identifier"
+ workgroup_name = "workgroup_name"
+ hook = RedshiftDataHook()
+ with pytest.raises(ValueError):
+ hook.execute_query(
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
+ sql=SQL,
+ wait_for_completion=False,
+ session_id=session_id,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_session_id_valid(self, mock_conn):
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ cluster_identifier = "cluster_identifier"
+ workgroup_name = "workgroup_name"
+ hook = RedshiftDataHook()
+ with pytest.raises(ValueError):
+ hook.execute_query(
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
+ sql=SQL,
+ wait_for_completion=False,
+ session_id="not_a_uuid",
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@@ -156,6 +205,74 @@ class TestRedshiftDataHook:
Id=STATEMENT_ID,
)
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_with_new_session(self, mock_conn):
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": "session_id"}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
+ hook = RedshiftDataHook()
+ output = hook.execute_query(
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ session_keep_alive_seconds=123,
+ )
+ assert output.statement_id == STATEMENT_ID
+ assert output.session_id == "session_id"
+
+ mock_conn.execute_statement.assert_called_once_with(
+ Database=DATABASE,
+ Sql=SQL,
+ ClusterIdentifier=cluster_identifier,
+ DbUser=db_user,
+ SecretArn=secret_arn,
+ StatementName=statement_name,
+ Parameters=parameters,
+ WithEvent=False,
+ SessionKeepAliveSeconds=123,
+ )
+ mock_conn.describe_statement.assert_called_once_with(
+ Id=STATEMENT_ID,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+ def test_execute_reuse_session(self, mock_conn):
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": "session_id"}
+ mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+ hook = RedshiftDataHook()
+ session_id = str(uuid4())
+ output = hook.execute_query(
+ database=None,
+ sql=SQL,
+ statement_name=statement_name,
+ parameters=parameters,
+ session_id=session_id,
+ )
+ assert output.statement_id == STATEMENT_ID
+ assert output.session_id == "session_id"
+
+ mock_conn.execute_statement.assert_called_once_with(
+ Sql=SQL,
+ StatementName=statement_name,
+ Parameters=parameters,
+ WithEvent=False,
+ SessionId=session_id,
+ )
+ mock_conn.describe_statement.assert_called_once_with(
+ Id=STATEMENT_ID,
+ )
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_batch_execute(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index abfa2b038b..c22d776a94 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -22,6 +22,7 @@ from unittest import mock
import pytest
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
+from airflow.providers.amazon.aws.hooks.redshift_data import QueryExecutionOutput
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields
@@ -31,6 +32,7 @@ TASK_ID = "task_id"
SQL = "sql"
DATABASE = "database"
STATEMENT_ID = "statement_id"
+SESSION_ID = "session_id"
@pytest.fixture
@@ -98,6 +100,8 @@ class TestRedshiftDataOperator:
poll_interval = 5
wait_for_completion = True
+ mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None)
+
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
@@ -111,7 +115,8 @@ class TestRedshiftDataOperator:
wait_for_completion=True,
poll_interval=poll_interval,
)
- operator.execute(None)
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ operator.execute({"ti": mock_ti})
mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
@@ -124,8 +129,12 @@ class TestRedshiftDataOperator:
with_event=False,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
+ session_id=None,
+ session_keep_alive_seconds=None,
)
+ mock_ti.xcom_push.assert_not_called()
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_with_workgroup_name(self, mock_exec_query):
cluster_identifier = None
@@ -150,7 +159,54 @@ class TestRedshiftDataOperator:
wait_for_completion=True,
poll_interval=poll_interval,
)
- operator.execute(None)
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ operator.execute({"ti": mock_ti})
+ mock_exec_query.assert_called_once_with(
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ with_event=False,
+ wait_for_completion=wait_for_completion,
+ poll_interval=poll_interval,
+ session_id=None,
+ session_keep_alive_seconds=None,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+ def test_execute_new_session(self, mock_exec_query):
+ cluster_identifier = "cluster_identifier"
+ workgroup_name = None
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+ wait_for_completion = True
+
+ mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID)
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=True,
+ poll_interval=poll_interval,
+ session_keep_alive_seconds=123,
+ )
+
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ operator.execute({"ti": mock_ti})
mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
@@ -163,7 +219,11 @@ class TestRedshiftDataOperator:
with_event=False,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
+ session_id=None,
+ session_keep_alive_seconds=123,
)
+ assert mock_ti.xcom_push.call_args.kwargs["key"] == "session_id"
+ assert mock_ti.xcom_push.call_args.kwargs["value"] == SESSION_ID
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_on_kill_without_query(self, mock_conn):
@@ -180,7 +240,7 @@ class TestRedshiftDataOperator:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_on_kill_with_query(self, mock_conn):
- mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID}
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
@@ -189,7 +249,8 @@ class TestRedshiftDataOperator:
database=DATABASE,
wait_for_completion=False,
)
- operator.execute(None)
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ operator.execute({"ti": mock_ti})
operator.on_kill()
mock_conn.cancel_statement.assert_called_once_with(
Id=STATEMENT_ID,
@@ -198,7 +259,7 @@ class TestRedshiftDataOperator:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_return_sql_result(self, mock_conn):
expected_result = {"Result": True}
- mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+ mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_conn.get_statement_result.return_value = expected_result
cluster_identifier = "cluster_identifier"
@@ -216,7 +277,8 @@ class TestRedshiftDataOperator:
aws_conn_id=CONN_ID,
return_sql_result=True,
)
- actual_result = operator.execute(None)
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ actual_result = operator.execute({"ti": mock_ti})
assert actual_result == expected_result
mock_conn.execute_statement.assert_called_once_with(
Database=DATABASE,
@@ -260,7 +322,9 @@ class TestRedshiftDataOperator:
poll_interval=poll_interval,
deferrable=True,
)
- operator.execute(None)
+
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ operator.execute({"ti": mock_ti})
assert not mock_defer.called
mock_exec_query.assert_called_once_with(
@@ -275,6 +339,8 @@ class TestRedshiftDataOperator:
with_event=False,
wait_for_completion=False,
poll_interval=poll_interval,
+ session_id=None,
+ session_keep_alive_seconds=None,
)
@mock.patch(
@@ -283,8 +349,9 @@ class TestRedshiftDataOperator:
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_defer(self, mock_exec_query, check_query_is_finished, deferrable_operator):
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
with pytest.raises(TaskDeferred) as exc:
- deferrable_operator.execute(None)
+ deferrable_operator.execute({"ti": mock_ti})
assert isinstance(exc.value.trigger, RedshiftDataTrigger)
@@ -346,7 +413,8 @@ class TestRedshiftDataOperator:
poll_interval=poll_interval,
deferrable=deferrable,
)
- operator.execute(None)
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ operator.execute({"ti": mock_ti})
assert not mock_check_query_is_finished.called
assert not mock_defer.called
diff --git a/tests/providers/amazon/aws/utils/test_openlineage.py b/tests/providers/amazon/aws/utils/test_openlineage.py
index b3e820b581..195db068d3 100644
--- a/tests/providers/amazon/aws/utils/test_openlineage.py
+++ b/tests/providers/amazon/aws/utils/test_openlineage.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import QueryExecutionOutput, RedshiftDataHook
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.utils.openlineage import (
get_facets_from_redshift_table,
@@ -58,7 +58,7 @@ def test_get_facets_from_redshift_table_sql_hook(mock_get_records):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_get_facets_from_redshift_table_data_hook(mock_connection, mock_execute_query):
- mock_execute_query.return_value = "statement_id"
+    mock_execute_query.return_value = QueryExecutionOutput(statement_id="statement_id", session_id=None)
mock_connection.get_statement_result.return_value = {
"Records": [
[
diff --git a/tests/system/providers/amazon/aws/example_redshift.py b/tests/system/providers/amazon/aws/example_redshift.py
index cc92076dcb..67b822d41e 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -50,7 +50,6 @@ DB_PASS = "MyAmazonPassword1"
DB_NAME = "dev"
POLL_INTERVAL = 10
-
with DAG(
dag_id=DAG_ID,
start_date=datetime(2021, 1, 1),
@@ -175,6 +174,37 @@ with DAG(
wait_for_completion=True,
)
+ # [START howto_operator_redshift_data_session_reuse]
+ create_tmp_table_data_api = RedshiftDataOperator(
+ task_id="create_tmp_table_data_api",
+ cluster_identifier=redshift_cluster_identifier,
+ database=DB_NAME,
+ db_user=DB_LOGIN,
+ sql="""
+ CREATE TEMPORARY TABLE tmp_people (
+ id INTEGER,
+ first_name VARCHAR(100),
+ age INTEGER
+ );
+ """,
+ poll_interval=POLL_INTERVAL,
+ wait_for_completion=True,
+ session_keep_alive_seconds=600,
+ )
+
+ insert_data_reuse_session = RedshiftDataOperator(
+ task_id="insert_data_reuse_session",
+ sql="""
+ INSERT INTO tmp_people VALUES ( 1, 'Bob', 30);
+ INSERT INTO tmp_people VALUES ( 2, 'Alice', 35);
+ INSERT INTO tmp_people VALUES ( 3, 'Charlie', 40);
+ """,
+ poll_interval=POLL_INTERVAL,
+ wait_for_completion=True,
+ session_id="{{
task_instance.xcom_pull(task_ids='create_tmp_table_data_api', key='session_id')
}}",
+ )
+ # [END howto_operator_redshift_data_session_reuse]
+
# [START howto_operator_redshift_delete_cluster]
delete_cluster = RedshiftDeleteClusterOperator(
task_id="delete_cluster",
@@ -209,13 +239,20 @@ with DAG(
delete_cluster,
)
+ # Test session reuse in parallel
+ chain(
+ wait_cluster_available_after_resume,
+ create_tmp_table_data_api,
+ insert_data_reuse_session,
+ delete_cluster_snapshot,
+ )
+
from tests.system.utils.watcher import watcher
# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()
-
from tests.system.utils import get_test_run # noqa: E402
# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
diff --git a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
index 0691046190..9fb989ec53 100644
--- a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
+++ b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
@@ -53,22 +53,34 @@ DB_NAME = "dev"
S3_KEY = "s3_output_"
S3_KEY_2 = "s3_key_2"
+S3_KEY_3 = "s3_output_tmp_table_"
S3_KEY_PREFIX = "s3_k"
REDSHIFT_TABLE = "test_table"
+REDSHIFT_TMP_TABLE = "tmp_table"
-SQL_CREATE_TABLE = f"""
- CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE} (
- fruit_id INTEGER,
- name VARCHAR NOT NULL,
- color VARCHAR NOT NULL
- );
-"""
+DATA = "0, 'Airflow', 'testing'"
-SQL_INSERT_DATA = f"INSERT INTO {REDSHIFT_TABLE} VALUES ( 1, 'Banana', 'Yellow');"
-SQL_DROP_TABLE = f"DROP TABLE IF EXISTS {REDSHIFT_TABLE};"
+def _drop_table(table_name: str) -> str:
+ return f"DROP TABLE IF EXISTS {table_name};"
-DATA = "0, 'Airflow', 'testing'"
+
+def _create_table(table_name: str, is_temp: bool = False) -> str:
+ temp_keyword = "TEMPORARY" if is_temp else ""
+ return (
+ _drop_table(table_name)
+ + f"""
+ CREATE {temp_keyword} TABLE {table_name} (
+ fruit_id INTEGER,
+ name VARCHAR NOT NULL,
+ color VARCHAR NOT NULL
+ );
+ """
+ )
+
+
+def _insert_data(table_name: str) -> str:
+ return f"INSERT INTO {table_name} VALUES ( 1, 'Banana', 'Yellow');"
with DAG(
@@ -124,7 +136,7 @@ with DAG(
cluster_identifier=redshift_cluster_identifier,
database=DB_NAME,
db_user=DB_LOGIN,
- sql=SQL_CREATE_TABLE,
+ sql=_create_table(REDSHIFT_TABLE),
wait_for_completion=True,
)
@@ -133,7 +145,7 @@ with DAG(
cluster_identifier=redshift_cluster_identifier,
database=DB_NAME,
db_user=DB_LOGIN,
- sql=SQL_INSERT_DATA,
+ sql=_insert_data(REDSHIFT_TABLE),
wait_for_completion=True,
)
@@ -159,6 +171,33 @@ with DAG(
bucket_key=f"{S3_KEY}/{REDSHIFT_TABLE}_0000_part_00",
)
+ create_tmp_table = RedshiftDataOperator(
+ task_id="create_tmp_table",
+ cluster_identifier=redshift_cluster_identifier,
+ database=DB_NAME,
+ db_user=DB_LOGIN,
+        sql=_create_table(REDSHIFT_TMP_TABLE, is_temp=True) + _insert_data(REDSHIFT_TMP_TABLE),
+ wait_for_completion=True,
+ session_keep_alive_seconds=600,
+ )
+
+ transfer_redshift_to_s3_reuse_session = RedshiftToS3Operator(
+ task_id="transfer_redshift_to_s3_reuse_session",
+ redshift_data_api_kwargs={
+ "wait_for_completion": True,
+ "session_id": "{{
task_instance.xcom_pull(task_ids='create_tmp_table', key='session_id') }}",
+ },
+ s3_bucket=bucket_name,
+ s3_key=S3_KEY_3,
+ table=REDSHIFT_TMP_TABLE,
+ )
+
+ check_if_tmp_table_key_exists = S3KeySensor(
+ task_id="check_if_tmp_table_key_exists",
+ bucket_name=bucket_name,
+ bucket_key=f"{S3_KEY_3}/{REDSHIFT_TMP_TABLE}_0000_part_00",
+ )
+
# [START howto_transfer_s3_to_redshift]
transfer_s3_to_redshift = S3ToRedshiftOperator(
task_id="transfer_s3_to_redshift",
@@ -176,6 +215,28 @@ with DAG(
)
# [END howto_transfer_s3_to_redshift]
+ create_dest_tmp_table = RedshiftDataOperator(
+ task_id="create_dest_tmp_table",
+ cluster_identifier=redshift_cluster_identifier,
+ database=DB_NAME,
+ db_user=DB_LOGIN,
+ sql=_create_table(REDSHIFT_TMP_TABLE, is_temp=True),
+ wait_for_completion=True,
+ session_keep_alive_seconds=600,
+ )
+
+ transfer_s3_to_redshift_tmp_table = S3ToRedshiftOperator(
+ task_id="transfer_s3_to_redshift_tmp_table",
+ redshift_data_api_kwargs={
+ "session_id": "{{
task_instance.xcom_pull(task_ids='create_dest_tmp_table', key='session_id') }}",
+ "wait_for_completion": True,
+ },
+ s3_bucket=bucket_name,
+ s3_key=S3_KEY_2,
+ table=REDSHIFT_TMP_TABLE,
+ copy_options=["csv"],
+ )
+
# [START howto_transfer_s3_to_redshift_multiple_keys]
transfer_s3_to_redshift_multiple = S3ToRedshiftOperator(
task_id="transfer_s3_to_redshift_multiple",
@@ -198,7 +259,7 @@ with DAG(
cluster_identifier=redshift_cluster_identifier,
database=DB_NAME,
db_user=DB_LOGIN,
- sql=SQL_DROP_TABLE,
+ sql=_drop_table(REDSHIFT_TABLE),
wait_for_completion=True,
trigger_rule=TriggerRule.ALL_DONE,
)
@@ -235,13 +296,33 @@ with DAG(
delete_bucket,
)
+ chain(
+ # TEST SETUP
+ wait_cluster_available,
+ create_tmp_table,
+ # TEST BODY
+ transfer_redshift_to_s3_reuse_session,
+ check_if_tmp_table_key_exists,
+ # TEST TEARDOWN
+ delete_cluster,
+ )
+
+ chain(
+ # TEST SETUP
+ wait_cluster_available,
+ create_dest_tmp_table,
+ # TEST BODY
+ transfer_s3_to_redshift_tmp_table,
+ # TEST TEARDOWN
+ delete_cluster,
+ )
+
from tests.system.utils.watcher import watcher
# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()
-
from tests.system.utils import get_test_run # noqa: E402
# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)