This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8580e6d046 Support session reuse in `RedshiftDataOperator` (#42218)
8580e6d046 is described below

commit 8580e6d046b11d159e3260ec4015981387e94a57
Author: Boris Morel <[email protected]>
AuthorDate: Wed Sep 25 02:22:54 2024 +0800

    Support session reuse in `RedshiftDataOperator` (#42218)
---
 airflow/providers/amazon/CHANGELOG.rst             |  27 +++++
 .../providers/amazon/aws/hooks/redshift_data.py    |  66 ++++++++---
 .../amazon/aws/operators/redshift_data.py          |  21 +++-
 .../amazon/aws/transfers/redshift_to_s3.py         |  25 ++--
 .../amazon/aws/transfers/s3_to_redshift.py         |  12 +-
 airflow/providers/amazon/aws/utils/openlineage.py  |   4 +-
 .../operators/redshift/redshift_data.rst           |  12 ++
 .../amazon/aws/hooks/test_redshift_data.py         | 127 ++++++++++++++++++++-
 .../amazon/aws/operators/test_redshift_data.py     |  86 ++++++++++++--
 .../providers/amazon/aws/utils/test_openlineage.py |   4 +-
 .../providers/amazon/aws/example_redshift.py       |  41 ++++++-
 .../amazon/aws/example_redshift_s3_transfers.py    | 109 +++++++++++++++---
 12 files changed, 468 insertions(+), 66 deletions(-)

diff --git a/airflow/providers/amazon/CHANGELOG.rst 
b/airflow/providers/amazon/CHANGELOG.rst
index 126da03ad6..7596ad3886 100644
--- a/airflow/providers/amazon/CHANGELOG.rst
+++ b/airflow/providers/amazon/CHANGELOG.rst
@@ -26,6 +26,33 @@
 Changelog
 ---------
 
+Main
+......
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+.. warning::
+  In order to support session reuse in RedshiftData operators, the following breaking changes were introduced:
+
+  The ``database`` argument is now optional and, as a result, was moved after the positional ``sql`` argument.
+  Update your DAGs accordingly if they rely on argument order. This applies to
+  ``RedshiftDataHook``'s ``execute_query`` method
+  and ``RedshiftDataOperator``.
+
+  ``RedshiftDataHook``'s ``execute_query`` method now returns a ``QueryExecutionOutput`` object instead of just the
+  statement ID as a string. Use its ``statement_id`` attribute to get the previous return value.
+
+  ``RedshiftDataHook``'s misspelled ``parse_statement_resposne`` method was renamed to ``parse_statement_response``.
+
+  ``S3ToRedshiftOperator``'s ``schema`` argument is now optional and was moved after the ``s3_key`` argument.
+  The operator's constructor arguments are keyword-only, so the new order does not affect existing DAGs.
+
+Features
+~~~~~~~~
+
+* ``Support session reuse in RedshiftDataOperator, RedshiftToS3Operator and S3ToRedshiftOperator (#42218)``
+
 8.29.0
 ......
 
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py 
b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 3c1f84b1f6..b2f46c0ef6 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -18,8 +18,12 @@
 from __future__ import annotations
 
 import time
+from dataclasses import dataclass
 from pprint import pformat
 from typing import TYPE_CHECKING, Any, Iterable
+from uuid import UUID
+
+from pendulum import duration
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -35,6 +39,14 @@ FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
 RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
 
 
+@dataclass
+class QueryExecutionOutput:
+    """Describes the output of a query execution."""
+
+    statement_id: str
+    session_id: str | None
+
+
 class RedshiftDataQueryFailedError(ValueError):
     """Raise an error that redshift data query failed."""
 
@@ -65,8 +77,8 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
     def execute_query(
         self,
-        database: str,
         sql: str | list[str],
+        database: str | None = None,
         cluster_identifier: str | None = None,
         db_user: str | None = None,
         parameters: Iterable | None = None,
@@ -76,23 +88,28 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         wait_for_completion: bool = True,
         poll_interval: int = 10,
         workgroup_name: str | None = None,
-    ) -> str:
+        session_id: str | None = None,
+        session_keep_alive_seconds: int | None = None,
+    ) -> QueryExecutionOutput:
         """
         Execute a statement against Amazon Redshift.
 
-        :param database: the name of the database
         :param sql: the SQL statement or list of  SQL statement to run
+        :param database: the name of the database
         :param cluster_identifier: unique identifier of a cluster
         :param db_user: the database username
         :param parameters: the parameters for the SQL statement
         :param secret_arn: the name or ARN of the secret that enables db access
         :param statement_name: the name of the SQL statement
-        :param with_event: indicates whether to send an event to EventBridge
-        :param wait_for_completion: indicates whether to wait for a result, if 
True wait, if False don't wait
+        :param with_event: whether to send an event to EventBridge
+        :param wait_for_completion: whether to wait for a result
         :param poll_interval: how often in seconds to check the query status
         :param workgroup_name: name of the Redshift Serverless workgroup. 
Mutually exclusive with
             `cluster_identifier`. Specify this parameter to query Redshift 
Serverless. More info
             
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
+        :param session_id: the session identifier of the query
+        :param session_keep_alive_seconds: duration in seconds to keep the 
session alive after the query
+            finishes. The maximum time a session can keep alive is 24 hours
 
         :returns statement_id: str, the UUID of the statement
         """
@@ -105,7 +122,28 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             "SecretArn": secret_arn,
             "StatementName": statement_name,
             "WorkgroupName": workgroup_name,
+            "SessionId": session_id,
+            "SessionKeepAliveSeconds": session_keep_alive_seconds,
         }
+
+        if sum(x is not None for x in (cluster_identifier, workgroup_name, 
session_id)) != 1:
+            raise ValueError(
+                "Exactly one of cluster_identifier, workgroup_name, or 
session_id must be provided"
+            )
+
+        if session_id is not None:
+            msg = "session_id must be a valid UUID4"
+            try:
+                if UUID(session_id).version != 4:
+                    raise ValueError(msg)
+            except ValueError:
+                raise ValueError(msg)
+
+        if session_keep_alive_seconds is not None and (
+            session_keep_alive_seconds < 0 or 
duration(seconds=session_keep_alive_seconds).hours > 24
+        ):
+            raise ValueError("Session keep alive duration must be between 0 
and 86400 seconds.")
+
         if isinstance(sql, list):
             kwargs["Sqls"] = sql
             resp = 
self.conn.batch_execute_statement(**trim_none_values(kwargs))
@@ -115,13 +153,10 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         statement_id = resp["Id"]
 
-        if bool(cluster_identifier) is bool(workgroup_name):
-            raise ValueError("Either 'cluster_identifier' or 'workgroup_name' 
must be specified.")
-
         if wait_for_completion:
             self.wait_for_results(statement_id, poll_interval=poll_interval)
 
-        return statement_id
+        return QueryExecutionOutput(statement_id=statement_id, 
session_id=resp.get("SessionId"))
 
     def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
         while True:
@@ -135,9 +170,9 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
     def check_query_is_finished(self, statement_id: str) -> bool:
         """Check whether query finished, raise exception is failed."""
         resp = self.conn.describe_statement(Id=statement_id)
-        return self.parse_statement_resposne(resp)
+        return self.parse_statement_response(resp)
 
-    def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) 
-> bool:
+    def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) 
-> bool:
         """Parse the response of describe_statement."""
         status = resp["Status"]
         if status == FINISHED_STATE:
@@ -179,8 +214,10 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         :param table: Name of the target table
         :param database: the name of the database
         :param schema: Name of the target schema, public by default
-        :param sql: the SQL statement or list of  SQL statement to run
         :param cluster_identifier: unique identifier of a cluster
+        :param workgroup_name: name of the Redshift Serverless workgroup. 
Mutually exclusive with
+            `cluster_identifier`. Specify this parameter to query Redshift 
Serverless. More info
+            
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
         :param db_user: the database username
         :param secret_arn: the name or ARN of the secret that enables db access
         :param statement_name: the name of the SQL statement
@@ -212,7 +249,8 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             with_event=with_event,
             wait_for_completion=wait_for_completion,
             poll_interval=poll_interval,
-        )
+        ).statement_id
+
         pk_columns = []
         token = ""
         while True:
@@ -251,4 +289,4 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         """
         async with self.async_conn as client:
             resp = await client.describe_statement(Id=statement_id)
-            return self.parse_statement_resposne(resp)
+            return self.parse_statement_response(resp)
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py 
b/airflow/providers/amazon/aws/operators/redshift_data.py
index 45fee2a919..3d00c6d22e 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -56,13 +56,16 @@ class 
RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
     :param workgroup_name: name of the Redshift Serverless workgroup. Mutually 
exclusive with
         `cluster_identifier`. Specify this parameter to query Redshift 
Serverless. More info
         
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
+    :param session_id: the session identifier of the query
+    :param session_keep_alive_seconds: duration in seconds to keep the session 
alive after the query
+        finishes. The maximum time a session can keep alive is 24 hours
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         If this is ``None`` or empty then the default boto3 behaviour is used. 
If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
     :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
-    :param verify: Whether or not to verify SSL certificates. See:
+    :param verify: Whether to verify SSL certificates. See:
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
         
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
@@ -77,6 +80,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         "parameters",
         "statement_name",
         "workgroup_name",
+        "session_id",
     )
     template_ext = (".sql",)
     template_fields_renderers = {"sql": "sql"}
@@ -84,8 +88,8 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
 
     def __init__(
         self,
-        database: str,
         sql: str | list,
+        database: str | None = None,
         cluster_identifier: str | None = None,
         db_user: str | None = None,
         parameters: list | None = None,
@@ -97,6 +101,8 @@ class 
RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         return_sql_result: bool = False,
         workgroup_name: str | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        session_id: str | None = None,
+        session_keep_alive_seconds: int | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -120,6 +126,8 @@ class 
RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         self.return_sql_result = return_sql_result
         self.statement_id: str | None = None
         self.deferrable = deferrable
+        self.session_id = session_id
+        self.session_keep_alive_seconds = session_keep_alive_seconds
 
     def execute(self, context: Context) -> GetStatementResultResponseTypeDef | 
str:
         """Execute a statement against Amazon Redshift."""
@@ -130,7 +138,7 @@ class 
RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         if self.deferrable:
             wait_for_completion = False
 
-        self.statement_id = self.hook.execute_query(
+        query_execution_output = self.hook.execute_query(
             database=self.database,
             sql=self.sql,
             cluster_identifier=self.cluster_identifier,
@@ -142,8 +150,15 @@ class 
RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
             with_event=self.with_event,
             wait_for_completion=wait_for_completion,
             poll_interval=self.poll_interval,
+            session_id=self.session_id,
+            session_keep_alive_seconds=self.session_keep_alive_seconds,
         )
 
+        self.statement_id = query_execution_output.statement_id
+
+        if query_execution_output.session_id:
+            self.xcom_push(context, key="session_id", 
value=query_execution_output.session_id)
+
         if self.deferrable and self.wait_for_completion:
             is_finished = self.hook.check_query_is_finished(self.statement_id)
             if not is_finished:
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py 
b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index ef3cebdae9..8538b1dfc3 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -45,7 +45,8 @@ class RedshiftToS3Operator(BaseOperator):
     :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` 
is set
         to False, this param must include the desired file name
     :param schema: reference to a specific schema in redshift database,
-        used when ``table`` param provided and ``select_query`` param not 
provided
+        used when ``table`` param provided and ``select_query`` param not 
provided.
+        Do not provide when unloading a temporary table
     :param table: reference to a specific table in redshift database,
         used when ``schema`` param provided and ``select_query`` param not 
provided
     :param select_query: custom select query to fetch data from redshift 
database,
@@ -55,8 +56,8 @@ class RedshiftToS3Operator(BaseOperator):
         If the AWS connection contains 'aws_iam_role' in ``extras``
         the operator will use AWS STS credentials with a token
         
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
-    :param verify: Whether or not to verify SSL certificates for S3 connection.
-        By default SSL certificates are verified.
+    :param verify: Whether to verify SSL certificates for S3 connection.
+        By default, SSL certificates are verified.
         You can provide the following values:
 
         - ``False``: do not validate SSL certificates. SSL will still be used
@@ -67,7 +68,7 @@ class RedshiftToS3Operator(BaseOperator):
                  CA cert bundle than the one used by botocore.
     :param unload_options: reference to a list of UNLOAD options
     :param autocommit: If set to True it will automatically commit the UNLOAD 
statement.
-        Otherwise it will be committed right before the redshift connection 
gets closed.
+        Otherwise, it will be committed right before the redshift connection 
gets closed.
     :param include_header: If set to True the s3 file contains the header 
columns.
     :param parameters: (optional) the parameters to render the SQL query with.
     :param table_as_file_name: If set to True, the s3 file will be named as 
the table.
@@ -141,9 +142,15 @@ class RedshiftToS3Operator(BaseOperator):
 
     @property
     def default_select_query(self) -> str | None:
-        if self.schema and self.table:
-            return f"SELECT * FROM {self.schema}.{self.table}"
-        return None
+        if not self.table:
+            return None
+
+        if self.schema:
+            table = f"{self.schema}.{self.table}"
+        else:
+            # Relevant when unloading a temporary table
+            table = self.table
+        return f"SELECT * FROM {table}"
 
     def execute(self, context: Context) -> None:
         if self.table and self.table_as_file_name:
@@ -152,9 +159,7 @@ class RedshiftToS3Operator(BaseOperator):
         self.select_query = self.select_query or self.default_select_query
 
         if self.select_query is None:
-            raise ValueError(
-                "Please provide both `schema` and `table` params or 
`select_query` to fetch the data."
-            )
+            raise ValueError("Please specify either a table or `select_query` 
to fetch the data.")
 
         if self.include_header and "HEADER" not in [uo.upper().strip() for uo 
in self.unload_options]:
             self.unload_options = [*self.unload_options, "HEADER"]
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py 
b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 127ee07a60..792119bfeb 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -28,7 +28,6 @@ from airflow.providers.amazon.aws.utils.redshift import 
build_credentials_block
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-
 AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"]
 
 
@@ -40,17 +39,18 @@ class S3ToRedshiftOperator(BaseOperator):
         For more information on how to use this operator, take a look at the 
guide:
         :ref:`howto/operator:S3ToRedshiftOperator`
 
-    :param schema: reference to a specific schema in redshift database
     :param table: reference to a specific table in redshift database
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: key prefix that selects single or multiple objects from S3
+    :param schema: reference to a specific schema in redshift database.
+        Do not provide when copying into a temporary table
     :param redshift_conn_id: reference to a specific redshift database OR a 
redshift data-api connection
     :param aws_conn_id: reference to a specific S3 connection
         If the AWS connection contains 'aws_iam_role' in ``extras``
         the operator will use AWS STS credentials with a token
         
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
-    :param verify: Whether or not to verify SSL certificates for S3 connection.
-        By default SSL certificates are verified.
+    :param verify: Whether to verify SSL certificates for S3 connection.
+        By default, SSL certificates are verified.
         You can provide the following values:
 
         - ``False``: do not validate SSL certificates. SSL will still be used
@@ -87,10 +87,10 @@ class S3ToRedshiftOperator(BaseOperator):
     def __init__(
         self,
         *,
-        schema: str,
         table: str,
         s3_bucket: str,
         s3_key: str,
+        schema: str | None = None,
         redshift_conn_id: str = "redshift_default",
         aws_conn_id: str | None = "aws_default",
         verify: bool | str | None = None,
@@ -160,7 +160,7 @@ class S3ToRedshiftOperator(BaseOperator):
             credentials_block = build_credentials_block(credentials)
 
         copy_options = "\n\t\t\t".join(self.copy_options)
-        destination = f"{self.schema}.{self.table}"
+        destination = f"{self.schema}.{self.table}" if self.schema else 
self.table
         copy_destination = f"#{self.table}" if self.method == "UPSERT" else 
destination
 
         copy_statement = self._build_copy_query(
diff --git a/airflow/providers/amazon/aws/utils/openlineage.py 
b/airflow/providers/amazon/aws/utils/openlineage.py
index db472a3e46..be5703e2f6 100644
--- a/airflow/providers/amazon/aws/utils/openlineage.py
+++ b/airflow/providers/amazon/aws/utils/openlineage.py
@@ -86,7 +86,9 @@ def get_facets_from_redshift_table(
             ]
         )
     else:
-        statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, 
**redshift_data_api_kwargs)
+        statement_id = redshift_hook.execute_query(
+            sql=sql, poll_interval=1, **redshift_data_api_kwargs
+        ).statement_id
         response = redshift_hook.conn.get_statement_result(Id=statement_id)
 
         table_schema = SchemaDatasetFacet(
diff --git 
a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst 
b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
index 0b314d34f3..2638e1732c 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
@@ -54,6 +54,18 @@ the necessity of a Postgres connection.
     :start-after: [START howto_operator_redshift_data]
     :end-before: [END howto_operator_redshift_data]
 
+Reuse a session when executing multiple statements
+==================================================
+
+Specify the ``session_keep_alive_seconds`` parameter on an upstream task. In a downstream task, pull the session ID from
+the upstream task's ``session_id`` XCom and pass it to the ``session_id`` parameter. This is useful when you work with temporary tables, which only exist within the session that created them.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_redshift_data_session_reuse]
+    :end-before: [END howto_operator_redshift_data_session_reuse]
+
 Reference
 ---------
 
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py 
b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index a0952e5ba7..d548086449 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import logging
 from unittest import mock
+from uuid import uuid4
 
 import pytest
 
@@ -63,15 +64,18 @@ class TestRedshiftDataHook:
         mock_conn.describe_statement.assert_not_called()
 
     @pytest.mark.parametrize(
-        "cluster_identifier, workgroup_name",
+        "cluster_identifier, workgroup_name, session_id",
         [
-            (None, None),
-            ("some_cluster", "some_workgroup"),
+            (None, None, None),
+            ("some_cluster", "some_workgroup", None),
+            (None, "some_workgroup", None),
+            ("some_cluster", None, None),
+            (None, None, "some_session_id"),
         ],
     )
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    def test_execute_requires_either_cluster_identifier_or_workgroup_name(
-        self, mock_conn, cluster_identifier, workgroup_name
+    def 
test_execute_requires_one_of_cluster_identifier_or_workgroup_name_or_session_id(
+        self, mock_conn, cluster_identifier, workgroup_name, session_id
     ):
         mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
         cluster_identifier = "cluster_identifier"
@@ -84,6 +88,51 @@ class TestRedshiftDataHook:
                 workgroup_name=workgroup_name,
                 sql=SQL,
                 wait_for_completion=False,
+                session_id=session_id,
+            )
+
+    @pytest.mark.parametrize(
+        "cluster_identifier, workgroup_name, session_id",
+        [
+            (None, None, None),
+            ("some_cluster", "some_workgroup", None),
+            (None, "some_workgroup", None),
+            ("some_cluster", None, None),
+            (None, None, "some_session_id"),
+        ],
+    )
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_session_keep_alive_seconds_valid(
+        self, mock_conn, cluster_identifier, workgroup_name, session_id
+    ):
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        cluster_identifier = "cluster_identifier"
+        workgroup_name = "workgroup_name"
+        hook = RedshiftDataHook()
+        with pytest.raises(ValueError):
+            hook.execute_query(
+                database=DATABASE,
+                cluster_identifier=cluster_identifier,
+                workgroup_name=workgroup_name,
+                sql=SQL,
+                wait_for_completion=False,
+                session_id=session_id,
+            )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_session_id_valid(self, mock_conn):
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        cluster_identifier = "cluster_identifier"
+        workgroup_name = "workgroup_name"
+        hook = RedshiftDataHook()
+        with pytest.raises(ValueError):
+            hook.execute_query(
+                database=DATABASE,
+                cluster_identifier=cluster_identifier,
+                workgroup_name=workgroup_name,
+                sql=SQL,
+                wait_for_completion=False,
+                session_id="not_a_uuid",
             )
 
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@@ -156,6 +205,74 @@ class TestRedshiftDataHook:
             Id=STATEMENT_ID,
         )
 
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_with_new_session(self, mock_conn):
+        cluster_identifier = "cluster_identifier"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, 
"SessionId": "session_id"}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+
+        hook = RedshiftDataHook()
+        output = hook.execute_query(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            session_keep_alive_seconds=123,
+        )
+        assert output.statement_id == STATEMENT_ID
+        assert output.session_id == "session_id"
+
+        mock_conn.execute_statement.assert_called_once_with(
+            Database=DATABASE,
+            Sql=SQL,
+            ClusterIdentifier=cluster_identifier,
+            DbUser=db_user,
+            SecretArn=secret_arn,
+            StatementName=statement_name,
+            Parameters=parameters,
+            WithEvent=False,
+            SessionKeepAliveSeconds=123,
+        )
+        mock_conn.describe_statement.assert_called_once_with(
+            Id=STATEMENT_ID,
+        )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    def test_execute_reuse_session(self, mock_conn):
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, 
"SessionId": "session_id"}
+        mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
+        hook = RedshiftDataHook()
+        session_id = str(uuid4())
+        output = hook.execute_query(
+            database=None,
+            sql=SQL,
+            statement_name=statement_name,
+            parameters=parameters,
+            session_id=session_id,
+        )
+        assert output.statement_id == STATEMENT_ID
+        assert output.session_id == "session_id"
+
+        mock_conn.execute_statement.assert_called_once_with(
+            Sql=SQL,
+            StatementName=statement_name,
+            Parameters=parameters,
+            WithEvent=False,
+            SessionId=session_id,
+        )
+        mock_conn.describe_statement.assert_called_once_with(
+            Id=STATEMENT_ID,
+        )
+
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
     def test_batch_execute(self, mock_conn):
         mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py 
b/tests/providers/amazon/aws/operators/test_redshift_data.py
index abfa2b038b..c22d776a94 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -22,6 +22,7 @@ from unittest import mock
 import pytest
 
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, TaskDeferred
+from airflow.providers.amazon.aws.hooks.redshift_data import 
QueryExecutionOutput
 from airflow.providers.amazon.aws.operators.redshift_data import 
RedshiftDataOperator
 from airflow.providers.amazon.aws.triggers.redshift_data import 
RedshiftDataTrigger
 from tests.providers.amazon.aws.utils.test_template_fields import 
validate_template_fields
@@ -31,6 +32,7 @@ TASK_ID = "task_id"
 SQL = "sql"
 DATABASE = "database"
 STATEMENT_ID = "statement_id"
+SESSION_ID = "session_id"
 
 
 @pytest.fixture
@@ -98,6 +100,8 @@ class TestRedshiftDataOperator:
         poll_interval = 5
         wait_for_completion = True
 
+        mock_exec_query.return_value = 
QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None)
+
         operator = RedshiftDataOperator(
             aws_conn_id=CONN_ID,
             task_id=TASK_ID,
@@ -111,7 +115,8 @@ class TestRedshiftDataOperator:
             wait_for_completion=True,
             poll_interval=poll_interval,
         )
-        operator.execute(None)
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        operator.execute({"ti": mock_ti})
         mock_exec_query.assert_called_once_with(
             sql=SQL,
             database=DATABASE,
@@ -124,8 +129,12 @@ class TestRedshiftDataOperator:
             with_event=False,
             wait_for_completion=wait_for_completion,
             poll_interval=poll_interval,
+            session_id=None,
+            session_keep_alive_seconds=None,
         )
 
+        mock_ti.xcom_push.assert_not_called()
+
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
     def test_execute_with_workgroup_name(self, mock_exec_query):
         cluster_identifier = None
@@ -150,7 +159,54 @@ class TestRedshiftDataOperator:
             wait_for_completion=True,
             poll_interval=poll_interval,
         )
-        operator.execute(None)
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        operator.execute({"ti": mock_ti})
+        mock_exec_query.assert_called_once_with(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            with_event=False,
+            wait_for_completion=wait_for_completion,
+            poll_interval=poll_interval,
+            session_id=None,
+            session_keep_alive_seconds=None,
+        )
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+    def test_execute_new_session(self, mock_exec_query):
+        cluster_identifier = "cluster_identifier"
+        workgroup_name = None
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        poll_interval = 5
+        wait_for_completion = True
+
+        mock_exec_query.return_value = 
QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID)
+
+        operator = RedshiftDataOperator(
+            aws_conn_id=CONN_ID,
+            task_id=TASK_ID,
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            wait_for_completion=True,
+            poll_interval=poll_interval,
+            session_keep_alive_seconds=123,
+        )
+
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        operator.execute({"ti": mock_ti})
         mock_exec_query.assert_called_once_with(
             sql=SQL,
             database=DATABASE,
@@ -163,7 +219,11 @@ class TestRedshiftDataOperator:
             with_event=False,
             wait_for_completion=wait_for_completion,
             poll_interval=poll_interval,
+            session_id=None,
+            session_keep_alive_seconds=123,
         )
+        assert mock_ti.xcom_push.call_args.kwargs["key"] == "session_id"
+        assert mock_ti.xcom_push.call_args.kwargs["value"] == SESSION_ID
 
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
     def test_on_kill_without_query(self, mock_conn):
@@ -180,7 +240,7 @@ class TestRedshiftDataOperator:
 
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
     def test_on_kill_with_query(self, mock_conn):
-        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, 
"SessionId": SESSION_ID}
         operator = RedshiftDataOperator(
             aws_conn_id=CONN_ID,
             task_id=TASK_ID,
@@ -189,7 +249,8 @@ class TestRedshiftDataOperator:
             database=DATABASE,
             wait_for_completion=False,
         )
-        operator.execute(None)
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        operator.execute({"ti": mock_ti})
         operator.on_kill()
         mock_conn.cancel_statement.assert_called_once_with(
             Id=STATEMENT_ID,
@@ -198,7 +259,7 @@ class TestRedshiftDataOperator:
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
     def test_return_sql_result(self, mock_conn):
         expected_result = {"Result": True}
-        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
+        mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, 
"SessionId": SESSION_ID}
         mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
         mock_conn.get_statement_result.return_value = expected_result
         cluster_identifier = "cluster_identifier"
@@ -216,7 +277,8 @@ class TestRedshiftDataOperator:
             aws_conn_id=CONN_ID,
             return_sql_result=True,
         )
-        actual_result = operator.execute(None)
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        actual_result = operator.execute({"ti": mock_ti})
         assert actual_result == expected_result
         mock_conn.execute_statement.assert_called_once_with(
             Database=DATABASE,
@@ -260,7 +322,9 @@ class TestRedshiftDataOperator:
             poll_interval=poll_interval,
             deferrable=True,
         )
-        operator.execute(None)
+
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        operator.execute({"ti": mock_ti})
 
         assert not mock_defer.called
         mock_exec_query.assert_called_once_with(
@@ -275,6 +339,8 @@ class TestRedshiftDataOperator:
             with_event=False,
             wait_for_completion=False,
             poll_interval=poll_interval,
+            session_id=None,
+            session_keep_alive_seconds=None,
         )
 
     @mock.patch(
@@ -283,8 +349,9 @@ class TestRedshiftDataOperator:
     )
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
     def test_execute_defer(self, mock_exec_query, check_query_is_finished, 
deferrable_operator):
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
         with pytest.raises(TaskDeferred) as exc:
-            deferrable_operator.execute(None)
+            deferrable_operator.execute({"ti": mock_ti})
 
         assert isinstance(exc.value.trigger, RedshiftDataTrigger)
 
@@ -346,7 +413,8 @@ class TestRedshiftDataOperator:
                 poll_interval=poll_interval,
                 deferrable=deferrable,
             )
-            operator.execute(None)
+            mock_ti = mock.MagicMock(name="MockedTaskInstance")
+            operator.execute({"ti": mock_ti})
 
             assert not mock_check_query_is_finished.called
             assert not mock_defer.called
diff --git a/tests/providers/amazon/aws/utils/test_openlineage.py 
b/tests/providers/amazon/aws/utils/test_openlineage.py
index b3e820b581..195db068d3 100644
--- a/tests/providers/amazon/aws/utils/test_openlineage.py
+++ b/tests/providers/amazon/aws/utils/test_openlineage.py
@@ -21,7 +21,7 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import 
QueryExecutionOutput, RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.utils.openlineage import (
     get_facets_from_redshift_table,
@@ -58,7 +58,7 @@ def 
test_get_facets_from_redshift_table_sql_hook(mock_get_records):
 
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
 
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
 def test_get_facets_from_redshift_table_data_hook(mock_connection, 
mock_execute_query):
-    mock_execute_query.return_value = "statement_id"
+    mock_execute_query.return_value = 
QueryExecutionOutput(statement_id="statement_id", session_id=None)
     mock_connection.get_statement_result.return_value = {
         "Records": [
             [
diff --git a/tests/system/providers/amazon/aws/example_redshift.py 
b/tests/system/providers/amazon/aws/example_redshift.py
index cc92076dcb..67b822d41e 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -50,7 +50,6 @@ DB_PASS = "MyAmazonPassword1"
 DB_NAME = "dev"
 POLL_INTERVAL = 10
 
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -175,6 +174,37 @@ with DAG(
         wait_for_completion=True,
     )
 
+    # [START howto_operator_redshift_data_session_reuse]
+    create_tmp_table_data_api = RedshiftDataOperator(
+        task_id="create_tmp_table_data_api",
+        cluster_identifier=redshift_cluster_identifier,
+        database=DB_NAME,
+        db_user=DB_LOGIN,
+        sql="""
+            CREATE TEMPORARY TABLE tmp_people (
+            id INTEGER,
+            first_name VARCHAR(100),
+            age INTEGER
+            );
+        """,
+        poll_interval=POLL_INTERVAL,
+        wait_for_completion=True,
+        session_keep_alive_seconds=600,
+    )
+
+    insert_data_reuse_session = RedshiftDataOperator(
+        task_id="insert_data_reuse_session",
+        sql="""
+            INSERT INTO tmp_people VALUES ( 1, 'Bob', 30);
+            INSERT INTO tmp_people VALUES ( 2, 'Alice', 35);
+            INSERT INTO tmp_people VALUES ( 3, 'Charlie', 40);
+        """,
+        poll_interval=POLL_INTERVAL,
+        wait_for_completion=True,
+        session_id="{{ 
task_instance.xcom_pull(task_ids='create_tmp_table_data_api', key='session_id') 
}}",
+    )
+    # [END howto_operator_redshift_data_session_reuse]
+
     # [START howto_operator_redshift_delete_cluster]
     delete_cluster = RedshiftDeleteClusterOperator(
         task_id="delete_cluster",
@@ -209,13 +239,20 @@ with DAG(
         delete_cluster,
     )
 
+    # Test session reuse in parallel
+    chain(
+        wait_cluster_available_after_resume,
+        create_tmp_table_data_api,
+        insert_data_reuse_session,
+        delete_cluster_snapshot,
+    )
+
     from tests.system.utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure
     # when "tearDown" task with trigger rule is part of the DAG
     list(dag.tasks) >> watcher()
 
-
 from tests.system.utils import get_test_run  # noqa: E402
 
 # Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
diff --git a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py 
b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
index 0691046190..9fb989ec53 100644
--- a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
+++ b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
@@ -53,22 +53,34 @@ DB_NAME = "dev"
 
 S3_KEY = "s3_output_"
 S3_KEY_2 = "s3_key_2"
+S3_KEY_3 = "s3_output_tmp_table_"
 S3_KEY_PREFIX = "s3_k"
 REDSHIFT_TABLE = "test_table"
+REDSHIFT_TMP_TABLE = "tmp_table"
 
-SQL_CREATE_TABLE = f"""
-    CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE} (
-    fruit_id INTEGER,
-    name VARCHAR NOT NULL,
-    color VARCHAR NOT NULL
-    );
-"""
+DATA = "0, 'Airflow', 'testing'"
 
-SQL_INSERT_DATA = f"INSERT INTO {REDSHIFT_TABLE} VALUES ( 1, 'Banana', 
'Yellow');"
 
-SQL_DROP_TABLE = f"DROP TABLE IF EXISTS {REDSHIFT_TABLE};"
+def _drop_table(table_name: str) -> str:
+    return f"DROP TABLE IF EXISTS {table_name};"
 
-DATA = "0, 'Airflow', 'testing'"
+
+def _create_table(table_name: str, is_temp: bool = False) -> str:
+    temp_keyword = "TEMPORARY" if is_temp else ""
+    return (
+        _drop_table(table_name)
+        + f"""
+        CREATE {temp_keyword} TABLE {table_name} (
+            fruit_id INTEGER,
+            name VARCHAR NOT NULL,
+            color VARCHAR NOT NULL
+        );
+    """
+    )
+
+
+def _insert_data(table_name: str) -> str:
+    return f"INSERT INTO {table_name} VALUES ( 1, 'Banana', 'Yellow');"
 
 
 with DAG(
@@ -124,7 +136,7 @@ with DAG(
         cluster_identifier=redshift_cluster_identifier,
         database=DB_NAME,
         db_user=DB_LOGIN,
-        sql=SQL_CREATE_TABLE,
+        sql=_create_table(REDSHIFT_TABLE),
         wait_for_completion=True,
     )
 
@@ -133,7 +145,7 @@ with DAG(
         cluster_identifier=redshift_cluster_identifier,
         database=DB_NAME,
         db_user=DB_LOGIN,
-        sql=SQL_INSERT_DATA,
+        sql=_insert_data(REDSHIFT_TABLE),
         wait_for_completion=True,
     )
 
@@ -159,6 +171,33 @@ with DAG(
         bucket_key=f"{S3_KEY}/{REDSHIFT_TABLE}_0000_part_00",
     )
 
+    create_tmp_table = RedshiftDataOperator(
+        task_id="create_tmp_table",
+        cluster_identifier=redshift_cluster_identifier,
+        database=DB_NAME,
+        db_user=DB_LOGIN,
+        sql=_create_table(REDSHIFT_TMP_TABLE, is_temp=True) + 
_insert_data(REDSHIFT_TMP_TABLE),
+        wait_for_completion=True,
+        session_keep_alive_seconds=600,
+    )
+
+    transfer_redshift_to_s3_reuse_session = RedshiftToS3Operator(
+        task_id="transfer_redshift_to_s3_reuse_session",
+        redshift_data_api_kwargs={
+            "wait_for_completion": True,
+            "session_id": "{{ 
task_instance.xcom_pull(task_ids='create_tmp_table', key='session_id') }}",
+        },
+        s3_bucket=bucket_name,
+        s3_key=S3_KEY_3,
+        table=REDSHIFT_TMP_TABLE,
+    )
+
+    check_if_tmp_table_key_exists = S3KeySensor(
+        task_id="check_if_tmp_table_key_exists",
+        bucket_name=bucket_name,
+        bucket_key=f"{S3_KEY_3}/{REDSHIFT_TMP_TABLE}_0000_part_00",
+    )
+
     # [START howto_transfer_s3_to_redshift]
     transfer_s3_to_redshift = S3ToRedshiftOperator(
         task_id="transfer_s3_to_redshift",
@@ -176,6 +215,28 @@ with DAG(
     )
     # [END howto_transfer_s3_to_redshift]
 
+    create_dest_tmp_table = RedshiftDataOperator(
+        task_id="create_dest_tmp_table",
+        cluster_identifier=redshift_cluster_identifier,
+        database=DB_NAME,
+        db_user=DB_LOGIN,
+        sql=_create_table(REDSHIFT_TMP_TABLE, is_temp=True),
+        wait_for_completion=True,
+        session_keep_alive_seconds=600,
+    )
+
+    transfer_s3_to_redshift_tmp_table = S3ToRedshiftOperator(
+        task_id="transfer_s3_to_redshift_tmp_table",
+        redshift_data_api_kwargs={
+            "session_id": "{{ 
task_instance.xcom_pull(task_ids='create_dest_tmp_table', key='session_id') }}",
+            "wait_for_completion": True,
+        },
+        s3_bucket=bucket_name,
+        s3_key=S3_KEY_2,
+        table=REDSHIFT_TMP_TABLE,
+        copy_options=["csv"],
+    )
+
     # [START howto_transfer_s3_to_redshift_multiple_keys]
     transfer_s3_to_redshift_multiple = S3ToRedshiftOperator(
         task_id="transfer_s3_to_redshift_multiple",
@@ -198,7 +259,7 @@ with DAG(
         cluster_identifier=redshift_cluster_identifier,
         database=DB_NAME,
         db_user=DB_LOGIN,
-        sql=SQL_DROP_TABLE,
+        sql=_drop_table(REDSHIFT_TABLE),
         wait_for_completion=True,
         trigger_rule=TriggerRule.ALL_DONE,
     )
@@ -235,13 +296,33 @@ with DAG(
         delete_bucket,
     )
 
+    chain(
+        # TEST SETUP
+        wait_cluster_available,
+        create_tmp_table,
+        # TEST BODY
+        transfer_redshift_to_s3_reuse_session,
+        check_if_tmp_table_key_exists,
+        # TEST TEARDOWN
+        delete_cluster,
+    )
+
+    chain(
+        # TEST SETUP
+        wait_cluster_available,
+        create_dest_tmp_table,
+        # TEST BODY
+        transfer_s3_to_redshift_tmp_table,
+        # TEST TEARDOWN
+        delete_cluster,
+    )
+
     from tests.system.utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure
     # when "tearDown" task with trigger rule is part of the DAG
     list(dag.tasks) >> watcher()
 
-
 from tests.system.utils import get_test_run  # noqa: E402
 
 # Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)


Reply via email to