This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9cd7930  Update copy command for s3 to redshift (#16241)
9cd7930 is described below

commit 9cd7930c34c2842bef0cdc1748d42e7caa722301
Author: sunki-hong <[email protected]>
AuthorDate: Sun Jun 13 18:09:21 2021 +0900

    Update copy command for s3 to redshift (#16241)
---
 .../amazon/aws/transfers/s3_to_redshift.py         |  9 ++-
 .../amazon/aws/transfers/test_s3_to_redshift.py    | 73 ++++++++++++++++++----
 2 files changed, 69 insertions(+), 13 deletions(-)
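For context, a minimal usage sketch of the new column_list argument (the schema, table, bucket, key, and connection ids below are illustrative placeholders, not taken from this commit):

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    # Illustrative task definition; only column_list is new in this commit.
    transfer = S3ToRedshiftOperator(
        task_id="s3_to_redshift",              # placeholder task id
        schema="public",                       # placeholder schema
        table="my_table",                      # placeholder table
        s3_bucket="my-bucket",                 # placeholder bucket
        s3_key="data/my_file.csv",             # placeholder key
        column_list=["column_1", "column_2"],  # columns named in the generated COPY
        copy_options=["csv"],
        aws_conn_id="aws_default",
        redshift_conn_id="redshift_default",
    )

When column_list is set, the generated statement becomes COPY schema.table (column_1, column_2) FROM 's3://...'; when it is omitted, the COPY statement is unchanged.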

diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 14be612..9efd8a7 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -57,13 +57,15 @@ class S3ToRedshiftOperator(BaseOperator):
                  You can specify this argument if you want to use a different
                  CA cert bundle than the one used by botocore.
     :type verify: bool or str
+    :param column_list: list of column names to load
+    :type column_list: List[str]
     :param copy_options: reference to a list of COPY options
     :type copy_options: list
     :param truncate_table: whether or not to truncate the destination table before the copy
     :type truncate_table: bool
     """
 
-    template_fields = ('s3_bucket', 's3_key', 'schema', 'table', 'copy_options')
+    template_fields = ('s3_bucket', 's3_key', 'schema', 'table', 'column_list', 'copy_options')
     template_ext = ()
     ui_color = '#99e699'
 
@@ -77,6 +79,7 @@ class S3ToRedshiftOperator(BaseOperator):
         redshift_conn_id: str = 'redshift_default',
         aws_conn_id: str = 'aws_default',
         verify: Optional[Union[bool, str]] = None,
+        column_list: Optional[List[str]] = None,
         copy_options: Optional[List] = None,
         autocommit: bool = False,
         truncate_table: bool = False,
@@ -90,13 +93,15 @@ class S3ToRedshiftOperator(BaseOperator):
         self.redshift_conn_id = redshift_conn_id
         self.aws_conn_id = aws_conn_id
         self.verify = verify
+        self.column_list = column_list
         self.copy_options = copy_options or []
         self.autocommit = autocommit
         self.truncate_table = truncate_table
 
     def _build_copy_query(self, credentials_block: str, copy_options: str) -> str:
+        column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else ''
         return f"""
-                    COPY {self.schema}.{self.table}
+                    COPY {self.schema}.{self.table} {column_names}
                     FROM 's3://{self.s3_bucket}/{self.s3_key}'
                     with credentials
                     '{credentials_block}'
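
For clarity, a standalone sketch of the column-name formatting introduced above (values are illustrative; the expression mirrors the line added to _build_copy_query):

    # Mirrors the ternary added above: parenthesized, comma-separated names,
    # or an empty string when no column_list is given.
    column_list = ["column_1", "column_2"]
    column_names = "(" + ", ".join(column_list) + ")" if column_list else ''
    print(column_names)        # -> (column_1, column_2)

    column_list = None
    column_names = "(" + ", ".join(column_list) + ")" if column_list else ''
    print(repr(column_names))  # -> ''
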
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index 51f80ad..7702b0e 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -23,7 +23,6 @@ from unittest import mock
 from boto3.session import Session
 
 from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
-from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
 from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
 
 
@@ -56,10 +55,55 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
             dag=None,
         )
         op.execute(None)
+        copy_query = '''
+                        COPY schema.table
+                        FROM 's3://bucket/key'
+                        with credentials
+                        'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+                        ;
+                     '''
+        assert mock_run.call_count == 1
+        assert access_key in copy_query
+        assert secret_key in copy_query
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
 
-        credentials_block = build_credentials_block(mock_session.return_value)
-        copy_query = op._build_copy_query(credentials_block, copy_options)
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+    def test_execute_with_column_list(self, mock_run, mock_session):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        column_list = ["column_1", "column_2"]
+        copy_options = ""
 
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            column_list=column_list,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+        copy_query = '''
+                        COPY schema.table (column_1, column_2)
+                        FROM 's3://bucket/key'
+                        with credentials
+                        'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+                        ;
+                     '''
         assert mock_run.call_count == 1
         assert access_key in copy_query
         assert secret_key in copy_query
@@ -94,10 +138,13 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
             dag=None,
         )
         op.execute(None)
-
-        credentials_block = build_credentials_block(mock_session.return_value)
-        copy_statement = op._build_copy_query(credentials_block, copy_options)
-
+        copy_statement = '''
+                        COPY schema.table
+                        FROM 's3://bucket/key'
+                        with credentials
+                        'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+                        ;
+                     '''
         truncate_statement = f'TRUNCATE TABLE {schema}.{table};'
         transaction = f"""
                     BEGIN;
@@ -137,11 +184,14 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
             task_id="task_id",
             dag=None,
         )
-
-        credentials_block = build_credentials_block(mock_session.return_value)
-        copy_statement = op._build_copy_query(credentials_block, copy_options)
         op.execute(None)
-
+        copy_statement = '''
+                            COPY schema.table
+                            FROM 's3://bucket/key'
+                            with credentials
+                            'aws_access_key_id=ASIA_aws_access_key_id;aws_secret_access_key=aws_secret_access_key;token=aws_secret_token'
+                            ;
+                         '''
         assert access_key in copy_statement
         assert secret_key in copy_statement
         assert token in copy_statement
@@ -154,5 +204,6 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
             's3_key',
             'schema',
             'table',
+            'column_list',
             'copy_options',
         )
