This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 81ebd78  Added upsert method on S3ToRedshift operator (#18027)
81ebd78 is described below

commit 81ebd78db48a4876377dc20d361a7938be11373a
Author: Mario Taddeucci <[email protected]>
AuthorDate: Sun Sep 12 17:39:20 2021 -0300

    Added upsert method on S3ToRedshift operator (#18027)
---
 .../amazon/aws/transfers/s3_to_redshift.py         |  82 +++++++++++++--
 airflow/providers/postgres/hooks/postgres.py       |  27 ++++-
 .../amazon/aws/transfers/test_s3_to_redshift.py    | 116 ++++++++++++++++++++-
 3 files changed, 213 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 2fadd3f..48d80ea 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -15,13 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import warnings
 from typing import List, Optional, Union
 
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 
+AVAILABLE_METHODS = ['APPEND', 'REPLACE', 'UPSERT']
+
 
 class S3ToRedshiftOperator(BaseOperator):
     """
@@ -61,8 +65,10 @@ class S3ToRedshiftOperator(BaseOperator):
     :type column_list: List[str]
     :param copy_options: reference to a list of COPY options
     :type copy_options: list
-    :param truncate_table: whether or not to truncate the destination table before the copy
-    :type truncate_table: bool
+    :param method: Action to be performed on execution. Available ``APPEND``, ``UPSERT`` and ``REPLACE``.
+    :type method: str
+    :param upsert_keys: List of fields to use as key on upsert action
+    :type upsert_keys: List[str]
     """
 
     template_fields = ('s3_bucket', 's3_key', 'schema', 'table', 'column_list', 'copy_options')
@@ -82,9 +88,21 @@ class S3ToRedshiftOperator(BaseOperator):
         column_list: Optional[List[str]] = None,
         copy_options: Optional[List] = None,
         autocommit: bool = False,
-        truncate_table: bool = False,
+        method: str = 'APPEND',
+        upsert_keys: Optional[List[str]] = None,
         **kwargs,
     ) -> None:
+
+        if 'truncate_table' in kwargs:
+            warnings.warn(
+                """`truncate_table` is deprecated. Please use `REPLACE` 
method.""",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if kwargs['truncate_table']:
+                method = 'REPLACE'
+            kwargs.pop('truncate_table', None)
+
         super().__init__(**kwargs)
         self.schema = schema
         self.table = table
@@ -96,33 +114,77 @@ class S3ToRedshiftOperator(BaseOperator):
         self.column_list = column_list
         self.copy_options = copy_options or []
         self.autocommit = autocommit
-        self.truncate_table = truncate_table
+        self.method = method
+        self.upsert_keys = upsert_keys
+
+        if self.method not in AVAILABLE_METHODS:
+            raise AirflowException(f'Method not found! Available methods: {AVAILABLE_METHODS}')
 
-    def _build_copy_query(self, credentials_block: str, copy_options: str) -> str:
+    def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_options: str) -> str:
         column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else ''
         return f"""
-                    COPY {self.schema}.{self.table} {column_names}
+                    COPY {copy_destination} {column_names}
                     FROM 's3://{self.s3_bucket}/{self.s3_key}'
                     with credentials
                     '{credentials_block}'
                     {copy_options};
         """
 
+    def _get_table_primary_key(self, postgres_hook):
+        sql = """
+            select kcu.column_name
+            from information_schema.table_constraints tco
+                    join information_schema.key_column_usage kcu
+                        on kcu.constraint_name = tco.constraint_name
+                            and kcu.constraint_schema = tco.constraint_schema
+                            and kcu.constraint_name = tco.constraint_name
+            where tco.constraint_type = 'PRIMARY KEY'
+            and kcu.table_schema = %s
+            and kcu.table_name = %s
+        """
+
+        result = postgres_hook.get_records(sql, (self.schema, self.table))
+
+        if len(result) == 0:
+            raise AirflowException(
+                f"""
+                No primary key on {self.schema}.{self.table}.
+                Please provide keys on 'upsert_keys' parameter.
+                """
+            )
+        return [row[0] for row in result]
+
     def execute(self, context) -> None:
         postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
         credentials = s3_hook.get_credentials()
         credentials_block = build_credentials_block(credentials)
         copy_options = '\n\t\t\t'.join(self.copy_options)
+        destination = f'{self.schema}.{self.table}'
+        copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination
 
-        copy_statement = self._build_copy_query(credentials_block, copy_options)
+        copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)
 
-        if self.truncate_table:
-            delete_statement = f'DELETE FROM {self.schema}.{self.table};'
+        if self.method == 'REPLACE':
             sql = f"""
             BEGIN;
-            {delete_statement}
+            DELETE FROM {destination};
+            {copy_statement}
+            COMMIT
+            """
+        elif self.method == 'UPSERT':
+            keys = self.upsert_keys or postgres_hook.get_table_primary_key(self.table, self.schema)
+            if not keys:
+                raise AirflowException(
+                    f"No primary key on {self.schema}.{self.table}. Please 
provide keys on 'upsert_keys'"
+                )
+            where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])
+            sql = f"""
+            CREATE TABLE {copy_destination} (LIKE {destination});
             {copy_statement}
+            BEGIN;
+            DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};
+            INSERT INTO {destination} SELECT * FROM {copy_destination};
             COMMIT
             """
         else:
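
A minimal usage sketch of the new `method` parameter added above; the task id, connection ids, schema, table, bucket and key below are illustrative placeholders, not part of this commit:

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    # Hypothetical task: copy staged S3 data into Redshift, merging rows on the `id` column.
    upsert_task = S3ToRedshiftOperator(
        task_id="s3_to_redshift_upsert",   # illustrative task id
        schema="public",                   # illustrative schema and table
        table="my_table",
        s3_bucket="my-bucket",             # illustrative bucket and key
        s3_key="exports/my_table.csv",
        copy_options=["CSV"],
        method="UPSERT",                   # new: one of APPEND (default), REPLACE, UPSERT
        upsert_keys=["id"],                # omit to fall back to the table's primary key
        redshift_conn_id="redshift_default",
        aws_conn_id="aws_default",
    )
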
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index 446b6f3..67cc8b3 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -19,7 +19,7 @@
 import os
 from contextlib import closing
 from copy import deepcopy
-from typing import Iterable, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union
 
 import psycopg2
 import psycopg2.extensions
@@ -197,6 +197,31 @@ class PostgresHook(DbApiHook):
             token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login)
         return login, token, port
 
+    def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> List[str]:
+        """
+        Helper method that returns the table primary key
+
+        :param table: Name of the target table
+        :type table: str
+        :param schema: Name of the target schema, public by default
+        :type schema: str
+        :return: Primary key columns list
+        :rtype: List[str]
+        """
+        sql = """
+            select kcu.column_name
+            from information_schema.table_constraints tco
+                    join information_schema.key_column_usage kcu
+                        on kcu.constraint_name = tco.constraint_name
+                            and kcu.constraint_schema = tco.constraint_schema
+                            and kcu.constraint_name = tco.constraint_name
+            where tco.constraint_type = 'PRIMARY KEY'
+            and kcu.table_schema = %s
+            and kcu.table_name = %s
+        """
+        pk_columns = [row[0] for row in self.get_records(sql, (schema, table))]
+        return pk_columns or None
+
     @staticmethod
     def _generate_insert_sql(
         table: str, values: Tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
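
The hook change above also stands on its own; a small sketch of calling the new helper directly, assuming a `postgres_default` connection and placeholder table names:

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    # Hypothetical standalone lookup of a table's primary key columns.
    hook = PostgresHook(postgres_conn_id="postgres_default")  # illustrative connection id
    pk_columns = hook.get_table_primary_key(table="my_table", schema="public")
    # Returns e.g. ['id'] when a primary key exists, or None when the table has none.
    print(pk_columns)
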
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index cd18165..1ee139e 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -20,8 +20,10 @@
 import unittest
 from unittest import mock
 
+import pytest
 from boto3.session import Session
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
 from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
 
@@ -111,7 +113,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
 
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_truncate(self, mock_run, mock_session):
+    def test_deprecated_truncate(self, mock_run, mock_session):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
         mock_session.return_value = Session(access_key, secret_key)
@@ -158,6 +160,103 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
 
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+    def test_replace(self, mock_run, mock_session):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            method='REPLACE',
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+        copy_statement = '''
+                        COPY schema.table
+                        FROM 's3://bucket/key'
+                        with credentials
+                        'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+                        ;
+                     '''
+        delete_statement = f'DELETE FROM {schema}.{table};'
+        transaction = f"""
+                    BEGIN;
+                    {delete_statement}
+                    {copy_statement}
+                    COMMIT
+                    """
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
+
+        assert mock_run.call_count == 1
+
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+    def test_upsert(self, mock_run, mock_session):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            method='UPSERT',
+            upsert_keys=['id'],
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+
+        copy_statement = f'''
+                        COPY #{table}
+                        FROM 's3://bucket/key'
+                        with credentials
+                        'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+                        ;
+                     '''
+        transaction = f"""
+                    CREATE TABLE #{table} (LIKE {schema}.{table});
+                    {copy_statement}
+                    BEGIN;
+                    DELETE FROM {schema}.{table} USING #{table} WHERE {table}.id = #{table}.id;
+                    INSERT INTO {schema}.{table} SELECT * FROM #{table};
+                    COMMIT
+                    """
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
+
+        assert mock_run.call_count == 1
+
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
     def test_execute_sts_token(self, mock_run, mock_session):
         access_key = "ASIA_aws_access_key_id"
         secret_key = "aws_secret_access_key"
@@ -207,3 +306,18 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
             'column_list',
             'copy_options',
         )
+
+    def test_execute_unavailable_method(self):
+        """
+        Test execute unavailable method
+        """
+        with pytest.raises(AirflowException):
+            S3ToRedshiftOperator(
+                schema="schema",
+                table="table",
+                s3_bucket="bucket",
+                s3_key="key",
+                method="unavailable_method",
+                task_id="task_id",
+                dag=None,
+            )
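
Finally, as the deprecation warning in the operator suggests, DAGs that previously passed `truncate_table=True` can move to the new parameter. A before/after sketch with illustrative arguments:

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    # Before this commit (now emits a DeprecationWarning and is mapped to REPLACE):
    legacy_task = S3ToRedshiftOperator(
        task_id="load_table_legacy",   # illustrative
        schema="public",
        table="my_table",
        s3_bucket="my-bucket",
        s3_key="exports/my_table.csv",
        truncate_table=True,
    )

    # Equivalent going forward:
    replace_task = S3ToRedshiftOperator(
        task_id="load_table",          # illustrative
        schema="public",
        table="my_table",
        s3_bucket="my-bucket",
        s3_key="exports/my_table.csv",
        method="REPLACE",
    )
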
