This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 410b57795b fix: select_query should have precedence over default query in RedshiftToS3Operator (#41634)
410b57795b is described below

commit 410b57795b37f3e22e9920499feec22709f49427
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Aug 21 16:18:06 2024 +0200

    fix: select_query should have precedence over default query in RedshiftToS3Operator (#41634)
---
 .../amazon/aws/transfers/redshift_to_s3.py         |  20 ++--
 .../amazon/aws/transfers/test_redshift_to_s3.py    | 117 +++++++++++++++++++++
 2 files changed, 130 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index 73578ea539..ef3cebdae9 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -44,11 +44,12 @@ class RedshiftToS3Operator(BaseOperator):
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
         to False, this param must include the desired file name
-    :param schema: reference to a specific schema in redshift database
-        Applicable when ``table`` param provided.
-    :param table: reference to a specific table in redshift database
-        Used when ``select_query`` param not provided.
-    :param select_query: custom select query to fetch data from redshift database
+    :param schema: reference to a specific schema in redshift database,
+        used when ``table`` param provided and ``select_query`` param not provided
+    :param table: reference to a specific table in redshift database,
+        used when ``schema`` param provided and ``select_query`` param not provided
+    :param select_query: custom select query to fetch data from redshift database,
+        has precedence over default query `SELECT * FROM ``schema``.``table``
     :param redshift_conn_id: reference to a specific redshift database
     :param aws_conn_id: reference to a specific S3 connection
         If the AWS connection contains 'aws_iam_role' in ``extras``
@@ -138,12 +139,17 @@ class RedshiftToS3Operator(BaseOperator):
                     {unload_options};
         """
 
+    @property
+    def default_select_query(self) -> str | None:
+        if self.schema and self.table:
+            return f"SELECT * FROM {self.schema}.{self.table}"
+        return None
+
     def execute(self, context: Context) -> None:
         if self.table and self.table_as_file_name:
             self.s3_key = f"{self.s3_key}/{self.table}_"
 
-        if self.schema and self.table:
-            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
+        self.select_query = self.select_query or self.default_select_query
 
         if self.select_query is None:
             raise ValueError(
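
The change above reduces to a simple precedence rule: an explicit ``select_query`` is used as-is, otherwise the operator falls back to ``SELECT * FROM schema.table`` built by the new ``default_select_query`` property, and ``execute()`` raises ``ValueError`` when neither is available. A minimal standalone sketch of that resolution order follows; the ``resolve_query`` helper and its error message are illustrative only and not part of the operator:

    from __future__ import annotations

    def resolve_query(select_query: str | None, schema: str | None, table: str | None) -> str:
        # Mirrors the new behavior: the explicit query wins, the default is only a fallback.
        default_query = f"SELECT * FROM {schema}.{table}" if schema and table else None
        query = select_query or default_query
        if query is None:
            # The operator raises ValueError in the same situation; this message is illustrative.
            raise ValueError("Provide `select_query`, or both `schema` and `table`.")
        return query

    assert resolve_query("select column from table", "schema", "table") == "select column from table"
    assert resolve_query(None, "schema", "table") == "SELECT * FROM schema.table"
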
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index d2af90a445..2d28acd22e 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -305,6 +305,123 @@ class TestRedshiftToS3Transfer:
         assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
         assert f"UNLOAD ($${expected_query}$$)" in unload_query
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_custom_select_query_has_precedence_over_table_and_schema(
+        self,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        select_query = "select column from table"
+
+        op = RedshiftToS3Operator(
+            select_query=select_query,
+            table="table",
+            schema="schema",
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(credentials_block, select_query, "key/table_", unload_options)
+
+        assert mock_run.call_count == 1
+        assert access_key in unload_query
+        assert secret_key in unload_query
+        assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_default_select_query_used_when_table_and_schema_missing(
+        self,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        default_query = "SELECT * FROM schema.table"
+
+        op = RedshiftToS3Operator(
+            table="table",
+            schema="schema",
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(credentials_block, default_query, "key/table_", unload_options)
+
+        assert mock_run.call_count == 1
+        assert access_key in unload_query
+        assert secret_key in unload_query
+        assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
+
+    def test_lack_of_select_query_and_schema_and_table_raises_error(self):
+        op = RedshiftToS3Operator(
+            s3_bucket="bucket",
+            s3_key="key",
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        with pytest.raises(ValueError):
+            op.execute(None)
+
     @pytest.mark.parametrize("table_as_file_name, expected_s3_key", [[True, "key/table_"], [False, "key"]])
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
     @mock.patch("airflow.models.connection.Connection")
