This is an automated email from the ASF dual-hosted git repository.
feluelle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new db121f7 Add truncate table (before copy) option to S3ToRedshiftOperator (#9246)
db121f7 is described below
commit db121f726b3c7a37aca1ea05eb4714f884456005
Author: JavierLopezT <[email protected]>
AuthorDate: Wed Oct 28 09:00:28 2020 +0100
Add truncate table (before copy) option to S3ToRedshiftOperator (#9246)
- add table arg to jinja template fields
- change ui_color
Co-authored-by: javier.lopez <[email protected]>
---
.../amazon/aws/transfers/s3_to_redshift.py | 42 +++++++++++--------
.../amazon/aws/transfers/test_s3_to_redshift.py | 48 +++++++++++++++++++---
2 files changed, 69 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 9abbe0a..35faad4 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -56,11 +56,16 @@ class S3ToRedshiftOperator(BaseOperator):
:type verify: bool or str
:param copy_options: reference to a list of COPY options
:type copy_options: list
+ :param truncate_table: whether or not to truncate the destination table before the copy
+ :type truncate_table: bool
"""
- template_fields = ('s3_key',)
+ template_fields = (
+ 's3_key',
+ 'table',
+ )
template_ext = ()
- ui_color = '#ededed'
+ ui_color = '#99e699'
@apply_defaults
def __init__(
@@ -75,6 +80,7 @@ class S3ToRedshiftOperator(BaseOperator):
verify: Optional[Union[bool, str]] = None,
copy_options: Optional[List] = None,
autocommit: bool = False,
+ truncate_table: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -87,6 +93,7 @@ class S3ToRedshiftOperator(BaseOperator):
self.verify = verify
self.copy_options = copy_options or []
self.autocommit = autocommit
+ self.truncate_table = truncate_table
def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
@@ -94,22 +101,25 @@ class S3ToRedshiftOperator(BaseOperator):
credentials = s3_hook.get_credentials()
copy_options = '\n\t\t\t'.join(self.copy_options)
- copy_query = """
- COPY {schema}.{table}
- FROM 's3://{s3_bucket}/{s3_key}'
+ copy_statement = f"""
+ COPY {self.schema}.{self.table}
+ FROM 's3://{self.s3_bucket}/{self.s3_key}'
with credentials
- 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
+ 'aws_access_key_id={credentials.access_key};aws_secret_access_key={credentials.secret_key}'
{copy_options};
- """.format(
- schema=self.schema,
- table=self.table,
- s3_bucket=self.s3_bucket,
- s3_key=self.s3_key,
- access_key=credentials.access_key,
- secret_key=credentials.secret_key,
- copy_options=copy_options,
- )
+ """
+
+ if self.truncate_table:
+ truncate_statement = f'TRUNCATE TABLE {self.schema}.{self.table};'
+ sql = f"""
+ BEGIN;
+ {truncate_statement}
+ {copy_statement}
+ COMMIT
+ """
+ else:
+ sql = copy_statement
self.log.info('Executing COPY command...')
- postgres_hook.run(copy_query, self.autocommit)
+ postgres_hook.run(sql, self.autocommit)
self.log.info("COPY command complete...")
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index 1f2d1b2..74aaafb 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -53,21 +53,59 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
)
op.execute(None)
- copy_query = """
+ copy_query = f"""
COPY {schema}.{table}
FROM 's3://{s3_bucket}/{s3_key}'
with credentials
'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
{copy_options};
- """.format(
+ """
+
+ assert mock_run.call_count == 1
+ assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
+
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+ def test_truncate(self, mock_run, mock_session):
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ mock_session.return_value = Session(access_key, secret_key)
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ copy_options = ""
+
+ op = S3ToRedshiftOperator(
schema=schema,
table=table,
s3_bucket=s3_bucket,
s3_key=s3_key,
- access_key=access_key,
- secret_key=secret_key,
copy_options=copy_options,
+ truncate_table=True,
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ dag=None,
)
+ op.execute(None)
+
+ copy_statement = f"""
+ COPY {schema}.{table}
+ FROM 's3://{s3_bucket}/{s3_key}'
+ with credentials
+ 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
+ {copy_options};
+ """
+
+ truncate_statement = f'TRUNCATE TABLE {schema}.{table};'
+ transaction = f"""
+ BEGIN;
+ {truncate_statement}
+ {copy_statement}
+ COMMIT
+ """
+ assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
assert mock_run.call_count == 1
- assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)