This is an automated email from the ASF dual-hosted git repository.

feluelle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new db121f7  Add truncate table (before copy) option to S3ToRedshiftOperator (#9246)
db121f7 is described below

commit db121f726b3c7a37aca1ea05eb4714f884456005
Author: JavierLopezT <[email protected]>
AuthorDate: Wed Oct 28 09:00:28 2020 +0100

    Add truncate table (before copy) option to S3ToRedshiftOperator (#9246)
    
    - add table arg to jinja template fields
    - change ui_color
    
    Co-authored-by: javier.lopez <[email protected]>
---
 .../amazon/aws/transfers/s3_to_redshift.py         | 42 +++++++++++--------
 .../amazon/aws/transfers/test_s3_to_redshift.py    | 48 +++++++++++++++++++---
 2 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 9abbe0a..35faad4 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -56,11 +56,16 @@ class S3ToRedshiftOperator(BaseOperator):
     :type verify: bool or str
     :param copy_options: reference to a list of COPY options
     :type copy_options: list
+    :param truncate_table: whether or not to truncate the destination table before the copy
+    :type truncate_table: bool
     """
 
-    template_fields = ('s3_key',)
+    template_fields = (
+        's3_key',
+        'table',
+    )
     template_ext = ()
-    ui_color = '#ededed'
+    ui_color = '#99e699'
 
     @apply_defaults
     def __init__(
@@ -75,6 +80,7 @@ class S3ToRedshiftOperator(BaseOperator):
         verify: Optional[Union[bool, str]] = None,
         copy_options: Optional[List] = None,
         autocommit: bool = False,
+        truncate_table: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -87,6 +93,7 @@ class S3ToRedshiftOperator(BaseOperator):
         self.verify = verify
         self.copy_options = copy_options or []
         self.autocommit = autocommit
+        self.truncate_table = truncate_table
 
     def execute(self, context) -> None:
         postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
@@ -94,22 +101,25 @@ class S3ToRedshiftOperator(BaseOperator):
         credentials = s3_hook.get_credentials()
         copy_options = '\n\t\t\t'.join(self.copy_options)
 
-        copy_query = """
-            COPY {schema}.{table}
-            FROM 's3://{s3_bucket}/{s3_key}'
+        copy_statement = f"""
+            COPY {self.schema}.{self.table}
+            FROM 's3://{self.s3_bucket}/{self.s3_key}'
             with credentials
-            'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
+            'aws_access_key_id={credentials.access_key};aws_secret_access_key={credentials.secret_key}'
             {copy_options};
-        """.format(
-            schema=self.schema,
-            table=self.table,
-            s3_bucket=self.s3_bucket,
-            s3_key=self.s3_key,
-            access_key=credentials.access_key,
-            secret_key=credentials.secret_key,
-            copy_options=copy_options,
-        )
+        """
+
+        if self.truncate_table:
+            truncate_statement = f'TRUNCATE TABLE {self.schema}.{self.table};'
+            sql = f"""
+            BEGIN;
+            {truncate_statement}
+            {copy_statement}
+            COMMIT
+            """
+        else:
+            sql = copy_statement
 
         self.log.info('Executing COPY command...')
-        postgres_hook.run(copy_query, self.autocommit)
+        postgres_hook.run(sql, self.autocommit)
         self.log.info("COPY command complete...")
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index 1f2d1b2..74aaafb 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -53,21 +53,59 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         )
         op.execute(None)
 
-        copy_query = """
+        copy_query = f"""
             COPY {schema}.{table}
             FROM 's3://{s3_bucket}/{s3_key}'
             with credentials
             'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
             {copy_options};
-        """.format(
+        """
+
+        assert mock_run.call_count == 1
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
+
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+    def test_truncate(self, mock_run, mock_session):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
             schema=schema,
             table=table,
             s3_bucket=s3_bucket,
             s3_key=s3_key,
-            access_key=access_key,
-            secret_key=secret_key,
             copy_options=copy_options,
+            truncate_table=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
         )
+        op.execute(None)
+
+        copy_statement = f"""
+                    COPY {schema}.{table}
+                    FROM 's3://{s3_bucket}/{s3_key}'
+                    with credentials
+                    'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
+                    {copy_options};
+                """
+
+        truncate_statement = f'TRUNCATE TABLE {schema}.{table};'
+        transaction = f"""
+                    BEGIN;
+                    {truncate_statement}
+                    {copy_statement}
+                    COMMIT
+                    """
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
 
         assert mock_run.call_count == 1
-        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)

Reply via email to