This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 81ebd78 Added upsert method on S3ToRedshift operator (#18027)
81ebd78 is described below
commit 81ebd78db48a4876377dc20d361a7938be11373a
Author: Mario Taddeucci <[email protected]>
AuthorDate: Sun Sep 12 17:39:20 2021 -0300
Added upsert method on S3ToRedshift operator (#18027)
---
.../amazon/aws/transfers/s3_to_redshift.py | 82 +++++++++++++--
airflow/providers/postgres/hooks/postgres.py | 27 ++++-
.../amazon/aws/transfers/test_s3_to_redshift.py | 116 ++++++++++++++++++++-
3 files changed, 213 insertions(+), 12 deletions(-)
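For context, a minimal usage sketch of the feature this commit adds. The DAG id, task id, bucket, table and connection ids below are illustrative placeholders, not part of the commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    with DAG(
        dag_id="example_s3_to_redshift_upsert",  # illustrative DAG id
        start_date=datetime(2021, 9, 1),
        schedule_interval=None,
    ) as dag:
        # Copy rows from S3 into a staging table, then upsert them into
        # public.orders, matching existing rows on the 'id' column.
        load_orders = S3ToRedshiftOperator(
            task_id="s3_to_redshift_upsert",
            schema="public",
            table="orders",
            s3_bucket="my-bucket",
            s3_key="orders/latest.csv",
            copy_options=["CSV"],
            method="UPSERT",          # APPEND (default), REPLACE, or UPSERT
            upsert_keys=["id"],       # optional; falls back to the table's primary key
            redshift_conn_id="redshift_default",
            aws_conn_id="aws_default",
        )

The APPEND and REPLACE paths take the same parameters; upsert_keys is only consulted when method is UPSERT.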
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 2fadd3f..48d80ea 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -15,13 +15,17 @@
# specific language governing permissions and limitations
# under the License.
+import warnings
from typing import List, Optional, Union
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
from airflow.providers.postgres.hooks.postgres import PostgresHook
+AVAILABLE_METHODS = ['APPEND', 'REPLACE', 'UPSERT']
+
class S3ToRedshiftOperator(BaseOperator):
"""
@@ -61,8 +65,10 @@ class S3ToRedshiftOperator(BaseOperator):
:type column_list: List[str]
:param copy_options: reference to a list of COPY options
:type copy_options: list
- :param truncate_table: whether or not to truncate the destination table before the copy
- :type truncate_table: bool
+ :param method: Action to be performed on execution. Available methods: ``APPEND``, ``UPSERT`` and ``REPLACE``.
+ :type method: str
+ :param upsert_keys: List of fields to use as the key for the upsert action
+ :type upsert_keys: List[str]
"""
template_fields = ('s3_bucket', 's3_key', 'schema', 'table', 'column_list', 'copy_options')
@@ -82,9 +88,21 @@ class S3ToRedshiftOperator(BaseOperator):
column_list: Optional[List[str]] = None,
copy_options: Optional[List] = None,
autocommit: bool = False,
- truncate_table: bool = False,
+ method: str = 'APPEND',
+ upsert_keys: Optional[List[str]] = None,
**kwargs,
) -> None:
+
+ if 'truncate_table' in kwargs:
+ warnings.warn(
+ """`truncate_table` is deprecated. Please use the `REPLACE` method.""",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if kwargs['truncate_table']:
+ method = 'REPLACE'
+ kwargs.pop('truncate_table', None)
+
super().__init__(**kwargs)
self.schema = schema
self.table = table
@@ -96,33 +114,77 @@ class S3ToRedshiftOperator(BaseOperator):
self.column_list = column_list
self.copy_options = copy_options or []
self.autocommit = autocommit
- self.truncate_table = truncate_table
+ self.method = method
+ self.upsert_keys = upsert_keys
+
+ if self.method not in AVAILABLE_METHODS:
+ raise AirflowException(f'Method not found! Available methods: {AVAILABLE_METHODS}')
- def _build_copy_query(self, credentials_block: str, copy_options: str) -> str:
+ def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_options: str) -> str:
column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else ''
return f"""
- COPY {self.schema}.{self.table} {column_names}
+ COPY {copy_destination} {column_names}
FROM 's3://{self.s3_bucket}/{self.s3_key}'
with credentials
'{credentials_block}'
{copy_options};
"""
+ def _get_table_primary_key(self, postgres_hook):
+ sql = """
+ select kcu.column_name
+ from information_schema.table_constraints tco
+ join information_schema.key_column_usage kcu
+ on kcu.constraint_name = tco.constraint_name
+ and kcu.constraint_schema = tco.constraint_schema
+ and kcu.constraint_name = tco.constraint_name
+ where tco.constraint_type = 'PRIMARY KEY'
+ and kcu.table_schema = %s
+ and kcu.table_name = %s
+ """
+
+ result = postgres_hook.get_records(sql, (self.schema, self.table))
+
+ if len(result) == 0:
+ raise AirflowException(
+ f"""
+ No primary key on {self.schema}.{self.table}.
+ Please provide keys on 'upsert_keys' parameter.
+ """
+ )
+ return [row[0] for row in result]
+
def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
credentials = s3_hook.get_credentials()
credentials_block = build_credentials_block(credentials)
copy_options = '\n\t\t\t'.join(self.copy_options)
+ destination = f'{self.schema}.{self.table}'
+ copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination
- copy_statement = self._build_copy_query(credentials_block, copy_options)
+ copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)
- if self.truncate_table:
- delete_statement = f'DELETE FROM {self.schema}.{self.table};'
+ if self.method == 'REPLACE':
sql = f"""
BEGIN;
- {delete_statement}
+ DELETE FROM {destination};
+ {copy_statement}
+ COMMIT
+ """
+ elif self.method == 'UPSERT':
+ keys = self.upsert_keys or postgres_hook.get_table_primary_key(self.table, self.schema)
+ if not keys:
+ raise AirflowException(
+ f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
+ )
+ where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])
+ sql = f"""
+ CREATE TABLE {copy_destination} (LIKE {destination});
{copy_statement}
+ BEGIN;
+ DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};
+ INSERT INTO {destination} SELECT * FROM {copy_destination};
COMMIT
"""
else:
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index 446b6f3..67cc8b3 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -19,7 +19,7 @@
import os
from contextlib import closing
from copy import deepcopy
-from typing import Iterable, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union
import psycopg2
import psycopg2.extensions
@@ -197,6 +197,31 @@ class PostgresHook(DbApiHook):
token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port
+ def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]:
+ """
+ Helper method that returns the table's primary key columns
+
+ :param table: Name of the target table
+ :type table: str
+ :param schema: Name of the target schema, public by default
+ :type schema: str
+ :return: Primary key columns list, or None if the table has no primary key
+ :rtype: Optional[List[str]]
+ """
+ sql = """
+ select kcu.column_name
+ from information_schema.table_constraints tco
+ join information_schema.key_column_usage kcu
+ on kcu.constraint_name = tco.constraint_name
+ and kcu.constraint_schema = tco.constraint_schema
+ and kcu.constraint_name = tco.constraint_name
+ where tco.constraint_type = 'PRIMARY KEY'
+ and kcu.table_schema = %s
+ and kcu.table_name = %s
+ """
+ pk_columns = [row[0] for row in self.get_records(sql, (schema, table))]
+ return pk_columns or None
+
@staticmethod
def _generate_insert_sql(
table: str, values: Tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
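The new hook helper can also be used on its own. A short sketch, assuming a 'postgres_default' connection and a public.orders table (both placeholders):

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    hook = PostgresHook(postgres_conn_id="postgres_default")

    # Returns the primary key column names of public.orders,
    # or None if the table has no PRIMARY KEY constraint.
    pk_columns = hook.get_table_primary_key(table="orders", schema="public")
    if pk_columns is None:
        raise ValueError("orders has no primary key; pass upsert_keys explicitly")

This is the same lookup the operator falls back to when upsert_keys is not supplied.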
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index cd18165..1ee139e 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -20,8 +20,10 @@
import unittest
from unittest import mock
+import pytest
from boto3.session import Session
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
@@ -111,7 +113,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_truncate(self, mock_run, mock_session):
+ def test_deprecated_truncate(self, mock_run, mock_session):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
@@ -158,6 +160,103 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+ def test_replace(self, mock_run, mock_session):
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ mock_session.return_value = Session(access_key, secret_key)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = None
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ copy_options = ""
+
+ op = S3ToRedshiftOperator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ copy_options=copy_options,
+ method='REPLACE',
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ dag=None,
+ )
+ op.execute(None)
+ copy_statement = '''
+ COPY schema.table
+ FROM 's3://bucket/key'
+ with credentials
+ 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+ ;
+ '''
+ delete_statement = f'DELETE FROM {schema}.{table};'
+ transaction = f"""
+ BEGIN;
+ {delete_statement}
+ {copy_statement}
+ COMMIT
+ """
+ assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
+
+ assert mock_run.call_count == 1
+
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+ def test_upsert(self, mock_run, mock_session):
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ mock_session.return_value = Session(access_key, secret_key)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = None
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ copy_options = ""
+
+ op = S3ToRedshiftOperator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ copy_options=copy_options,
+ method='UPSERT',
+ upsert_keys=['id'],
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ dag=None,
+ )
+ op.execute(None)
+
+ copy_statement = f'''
+ COPY #{table}
+ FROM 's3://bucket/key'
+ with credentials
+ 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+ ;
+ '''
+ transaction = f"""
+ CREATE TABLE #{table} (LIKE {schema}.{table});
+ {copy_statement}
+ BEGIN;
+ DELETE FROM {schema}.{table} USING #{table} WHERE {table}.id = #{table}.id;
+ INSERT INTO {schema}.{table} SELECT * FROM #{table};
+ COMMIT
+ """
+ assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
+
+ assert mock_run.call_count == 1
+
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_execute_sts_token(self, mock_run, mock_session):
access_key = "ASIA_aws_access_key_id"
secret_key = "aws_secret_access_key"
@@ -207,3 +306,18 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
'column_list',
'copy_options',
)
+
+ def test_execute_unavailable_method(self):
+ """
+ Test execute unavailable method
+ """
+ with pytest.raises(AirflowException):
+ S3ToRedshiftOperator(
+ schema="schema",
+ table="table",
+ s3_bucket="bucket",
+ s3_key="key",
+ method="unavailable_method",
+ task_id="task_id",
+ dag=None,
+ )