This is an automated email from the ASF dual-hosted git repository.

mattyb149 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new ad95287  NIFI-6934 In PutDatabaseRecord added DatabaseAdapter-based UPSERT support for Postgres (9.5+)
ad95287 is described below

commit ad95287e782d71a778dfa89b3cb2a17637b32bf5
Author: Tamas Palfy <[email protected]>
AuthorDate: Thu Jun 18 19:42:35 2020 +0200

    NIFI-6934 In PutDatabaseRecord added DatabaseAdapter-based UPSERT support for Postgres (9.5+)
    
    NIFI-6934 Added more documentation and unit tests.
    
    NIFI-6934 Added missing license for new test class.
    
    Signed-off-by: Matthew Burgess <[email protected]>
    
    This closes #4350
---
 .../processors/standard/PutDatabaseRecord.java     | 214 ++++++++++++++++-----
 .../processors/standard/db/DatabaseAdapter.java    |  27 +++
 .../db/impl/PostgreSQLDatabaseAdapter.java         |  75 ++++++++
 ...che.nifi.processors.standard.db.DatabaseAdapter |   3 +-
 .../db/impl/TestPostgreSQLDatabaseAdapter.java     | 108 +++++++++++
 5 files changed, 377 insertions(+), 50 deletions(-)

diff --git 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java
 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java
index 926c5cd..d6384d2 100644
--- 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java
+++ 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java
@@ -29,6 +29,8 @@ import org.apache.nifi.annotation.documentation.Tags;
 import org.apache.nifi.annotation.lifecycle.OnScheduled;
 import org.apache.nifi.components.AllowableValue;
 import org.apache.nifi.components.PropertyDescriptor;
+import org.apache.nifi.components.ValidationContext;
+import org.apache.nifi.components.ValidationResult;
 import org.apache.nifi.dbcp.DBCPService;
 import org.apache.nifi.expression.AttributeExpression;
 import org.apache.nifi.expression.ExpressionLanguageScope;
@@ -47,6 +49,7 @@ import 
org.apache.nifi.processor.util.pattern.PartialFunctions;
 import org.apache.nifi.processor.util.pattern.Put;
 import org.apache.nifi.processor.util.pattern.RollbackOnFailure;
 import org.apache.nifi.processor.util.pattern.RoutingResult;
+import org.apache.nifi.processors.standard.db.DatabaseAdapter;
 import org.apache.nifi.serialization.MalformedRecordException;
 import org.apache.nifi.serialization.RecordReader;
 import org.apache.nifi.serialization.RecordReaderFactory;
@@ -70,11 +73,13 @@ import java.sql.SQLIntegrityConstraintViolationException;
 import java.sql.SQLNonTransientException;
 import java.sql.Statement;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.ServiceLoader;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -101,6 +106,7 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
     static final String UPDATE_TYPE = "UPDATE";
     static final String INSERT_TYPE = "INSERT";
     static final String DELETE_TYPE = "DELETE";
+    static final String UPSERT_TYPE = "UPSERT";
     static final String SQL_TYPE = "SQL";   // Not an allowable value in the 
Statement Type property, must be set by attribute
     static final String USE_ATTR_TYPE = "Use statement.type Attribute";
 
@@ -152,11 +158,14 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
     static final PropertyDescriptor STATEMENT_TYPE = new 
PropertyDescriptor.Builder()
             .name("put-db-record-statement-type")
             .displayName("Statement Type")
-            .description("Specifies the type of SQL Statement to generate. If 
'Use statement.type Attribute' is chosen, then the value is taken from the 
statement.type attribute in the "
+            .description("Specifies the type of SQL Statement to generate. "
+                    + "Please refer to the database documentation for a 
description of the behavior of each operation. "
+                    + "Please note that some Database Types may not support 
certain Statement Types. "
+                    + "If 'Use statement.type Attribute' is chosen, then the 
value is taken from the statement.type attribute in the "
                     + "FlowFile. The 'Use statement.type Attribute' option is 
the only one that allows the 'SQL' statement type. If 'SQL' is specified, the 
value of the field specified by the "
                     + "'Field Containing SQL' property is expected to be a 
valid SQL statement on the target database, and will be executed as-is.")
             .required(true)
-            .allowableValues(UPDATE_TYPE, INSERT_TYPE, DELETE_TYPE, 
USE_ATTR_TYPE)
+            .allowableValues(UPDATE_TYPE, INSERT_TYPE, UPSERT_TYPE, 
DELETE_TYPE, USE_ATTR_TYPE)
             .build();
 
     static final PropertyDescriptor DBCP_SERVICE = new 
PropertyDescriptor.Builder()
@@ -299,11 +308,34 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
             
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
             .build();
 
+    static final PropertyDescriptor DB_TYPE;
+
+    protected static final Map<String, DatabaseAdapter> dbAdapters;
+
     protected static List<PropertyDescriptor> propDescriptors;
 
     private Cache<SchemaKey, TableSchema> schemaCache;
 
     static {
+        dbAdapters = new HashMap<>();
+        ArrayList<AllowableValue> dbAdapterValues = new ArrayList<>();
+
+        ServiceLoader<DatabaseAdapter> dbAdapterLoader = 
ServiceLoader.load(DatabaseAdapter.class);
+        dbAdapterLoader.forEach(databaseAdapter -> {
+            dbAdapters.put(databaseAdapter.getName(), databaseAdapter);
+            dbAdapterValues.add(new AllowableValue(databaseAdapter.getName(), 
databaseAdapter.getName(), databaseAdapter.getDescription()));
+        });
+
+        DB_TYPE = new PropertyDescriptor.Builder()
+            .name("db-type")
+            .displayName("Database Type")
+            .description("The type/flavor of database, used for generating 
database-specific code. In many cases the Generic type "
+                + "should suffice, but some databases (such as Oracle) require 
custom SQL clauses. ")
+            .allowableValues(dbAdapterValues.toArray(new 
AllowableValue[dbAdapterValues.size()]))
+            .defaultValue("Generic")
+            .required(false)
+            .build();
+
         final Set<Relationship> r = new HashSet<>();
         r.add(REL_SUCCESS);
         r.add(REL_FAILURE);
@@ -312,6 +344,7 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
 
         final List<PropertyDescriptor> pds = new ArrayList<>();
         pds.add(RECORD_READER_FACTORY);
+        pds.add(DB_TYPE);
         pds.add(STATEMENT_TYPE);
         pds.add(DBCP_SERVICE);
         pds.add(CATALOG_NAME);
@@ -335,6 +368,7 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
 
     private Put<FunctionContext, Connection> process;
     private ExceptionHandler<FunctionContext> exceptionHandler;
+    private DatabaseAdapter databaseAdapter;
 
     @Override
     public Set<Relationship> getRelationships() {
@@ -444,9 +478,29 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
         });
     };
 
+    @Override
+    protected Collection<ValidationResult> customValidate(ValidationContext 
validationContext) {
+        Collection<ValidationResult> validationResults = new 
ArrayList<>(super.customValidate(validationContext));
+
+        DatabaseAdapter databaseAdapter = 
dbAdapters.get(validationContext.getProperty(DB_TYPE).getValue());
+        String statementType = 
validationContext.getProperty(STATEMENT_TYPE).getValue();
+
+        if (UPSERT_TYPE.equals(statementType) && 
!databaseAdapter.supportsUpsert()) {
+            validationResults.add(new ValidationResult.Builder()
+                .subject(STATEMENT_TYPE.getDisplayName())
+                .valid(false)
+                .explanation(databaseAdapter.getName() + " does not support " 
+ statementType)
+                .build()
+            );
+        }
+
+        return validationResults;
+    }
 
     @OnScheduled
     public void onScheduled(final ProcessContext context) {
+        databaseAdapter = 
dbAdapters.get(context.getProperty(DB_TYPE).getValue());
+
         final int tableSchemaCacheSize = 
context.getProperty(TABLE_SCHEMA_CACHE_SIZE).asInteger();
         schemaCache = Caffeine.newBuilder()
                 .maximumSize(tableSchemaCacheSize)
@@ -657,6 +711,9 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
         } else if (DELETE_TYPE.equalsIgnoreCase(statementType)) {
             sqlHolder = generateDelete(recordSchema, fqTableName, tableSchema, 
settings);
 
+        } else if (UPSERT_TYPE.equalsIgnoreCase(statementType)) {
+            sqlHolder = generateUpsert(recordSchema, fqTableName, updateKeys, 
tableSchema, settings);
+
         } else {
             throw new IllegalArgumentException(format("Statement Type %s is 
not valid, FlowFile %s", statementType, flowFile));
         }
@@ -790,20 +847,7 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
     SqlAndIncludedColumns generateInsert(final RecordSchema recordSchema, 
final String tableName, final TableSchema tableSchema, final DMLSettings 
settings)
             throws IllegalArgumentException, SQLException {
 
-        final Set<String> normalizedFieldNames = 
getNormalizedColumnNames(recordSchema, settings.translateFieldNames);
-
-        for (final String requiredColName : 
tableSchema.getRequiredColumnNames()) {
-            final String normalizedColName = 
normalizeColumnName(requiredColName, settings.translateFieldNames);
-            if (!normalizedFieldNames.contains(normalizedColName)) {
-                String missingColMessage = "Record does not have a value for 
the Required column '" + requiredColName + "'";
-                if (settings.failUnmappedColumns) {
-                    getLogger().error(missingColMessage);
-                    throw new IllegalArgumentException(missingColMessage);
-                } else if (settings.warningUnmappedColumns) {
-                    getLogger().warn(missingColMessage);
-                }
-            }
-        }
+        checkValuesForRequiredColumns(recordSchema, tableSchema, settings);
 
         final StringBuilder sqlBuilder = new StringBuilder();
         sqlBuilder.append("INSERT INTO ");
@@ -854,47 +898,76 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor {
         return new SqlAndIncludedColumns(sqlBuilder.toString(), includedColumns);
     }
 
-    SqlAndIncludedColumns generateUpdate(final RecordSchema recordSchema, final String tableName, final String updateKeys,
+    SqlAndIncludedColumns generateUpsert(final RecordSchema recordSchema, final String tableName, final String updateKeys,
                                          final TableSchema tableSchema, final DMLSettings settings)
-            throws IllegalArgumentException, MalformedRecordException, SQLException {
+        throws IllegalArgumentException, SQLException, MalformedRecordException {
 
-        final Set<String> updateKeyNames;
-        if (updateKeys == null) {
-            updateKeyNames = tableSchema.getPrimaryKeyColumnNames();
-        } else {
-            updateKeyNames = new HashSet<>();
-            for (final String updateKey : updateKeys.split(",")) {
-                updateKeyNames.add(updateKey.trim());
-            }
-        }
+        // UPSERT can also be requested via the statement.type attribute, which customValidate() never sees
+        if (!databaseAdapter.supportsUpsert()) {
+            throw new IllegalArgumentException(databaseAdapter.getName() + " does not support " + UPSERT_TYPE);
+        }
+
+        checkValuesForRequiredColumns(recordSchema, tableSchema, settings);
 
-        if (updateKeyNames.isEmpty()) {
-            throw new SQLIntegrityConstraintViolationException("Table '" + tableName + "' does not have a Primary Key and no Update Keys were specified");
-        }
+        Set<String> keyColumnNames = getUpdateKeyColumnNames(tableName, updateKeys, tableSchema);
+        // Fails or warns (per settings) if the record has no value for one of the key columns
+        normalizeKeyColumnNamesAndCheckForValues(recordSchema, updateKeys, settings, keyColumnNames);
 
-        final StringBuilder sqlBuilder = new StringBuilder();
-        sqlBuilder.append("UPDATE ");
-        sqlBuilder.append(tableName);
+        List<String> usedColumnNames = new ArrayList<>();
+        List<Integer> usedColumnIndices = new ArrayList<>();
 
-        // Create a Set of all normalized Update Key names, and ensure that there is a field in the record
-        // for each of the Update Key fields.
-        final Set<String> normalizedFieldNames = getNormalizedColumnNames(recordSchema, settings.translateFieldNames);
-        final Set<String> normalizedUpdateNames = new HashSet<>();
-        for (final String uk : updateKeyNames) {
-            final String normalizedUK = normalizeColumnName(uk, settings.translateFieldNames);
-            normalizedUpdateNames.add(normalizedUK);
+        List<String> fieldNames = recordSchema.getFieldNames();
+        if (fieldNames != null) {
+            int fieldCount = fieldNames.size();
 
-            if (!normalizedFieldNames.contains(normalizedUK)) {
-                String missingColMessage = "Record does not have a value for the " + (updateKeys == null ? "Primary" : "Update") + "Key column '" + uk + "'";
-                if (settings.failUnmappedColumns) {
-                    getLogger().error(missingColMessage);
-                    throw new MalformedRecordException(missingColMessage);
-                } else if (settings.warningUnmappedColumns) {
-                    getLogger().warn(missingColMessage);
+            for (int i = 0; i < fieldCount; i++) {
+                RecordField field = recordSchema.getField(i);
+                String fieldName = field.getFieldName();
+
+                final ColumnDescription desc = tableSchema.getColumns().get(normalizeColumnName(fieldName, settings.translateFieldNames));
+                if (desc == null && !settings.ignoreUnmappedFields) {
+                    throw new SQLDataException("Cannot map field '" + fieldName + "' to any column in the database");
+                }
+
+                if (desc != null) {
+                    if (settings.escapeColumnNames) {
+                        usedColumnNames.add(tableSchema.getQuotedIdentifierString() + desc.getColumnName() + tableSchema.getQuotedIdentifierString());
+                    } else {
+                        usedColumnNames.add(desc.getColumnName());
+                    }
+                    usedColumnIndices.add(i);
                 }
             }
         }
 
+        // The ON CONFLICT target must name the real table columns: normalized key names may be upper-cased and
+        // stripped of '_' (translateFieldNames), and they are quoted the same way as the inserted columns.
+        List<String> conflictColumnNames = new ArrayList<>();
+        for (String keyColumnName : keyColumnNames) {
+            ColumnDescription keyDesc = tableSchema.getColumns().get(normalizeColumnName(keyColumnName, settings.translateFieldNames));
+            String columnName = keyDesc == null ? keyColumnName : keyDesc.getColumnName();
+            conflictColumnNames.add(settings.escapeColumnNames
+                ? tableSchema.getQuotedIdentifierString() + columnName + tableSchema.getQuotedIdentifierString()
+                : columnName);
+        }
+
+        String sql = databaseAdapter.getUpsertStatement(tableName, usedColumnNames, conflictColumnNames);
+
+        return new SqlAndIncludedColumns(sql, usedColumnIndices);
+    }
+
+    SqlAndIncludedColumns generateUpdate(final RecordSchema recordSchema, final String tableName, final String updateKeys,
+                                         final TableSchema tableSchema, final DMLSettings settings)
+            throws IllegalArgumentException, MalformedRecordException, SQLException {
+
+
+        final Set<String> keyColumnNames = getUpdateKeyColumnNames(tableName, updateKeys, tableSchema);
+        final Set<String> normalizedKeyColumnNames = normalizeKeyColumnNamesAndCheckForValues(recordSchema, updateKeys, settings, keyColumnNames);
+
+        final StringBuilder sqlBuilder = new StringBuilder();
+        sqlBuilder.append("UPDATE ");
+        sqlBuilder.append(tableName);
+
         // iterate over all of the fields in the record, building the SQL statement by adding the column names
         List<String> fieldNames = recordSchema.getFieldNames();
         final List<Integer> includedColumns = new ArrayList<>();
@@ -920,7 +976,7 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
 
                 // Check if this column is an Update Key. If so, skip it for 
now. We will come
                 // back to it after we finish the SET clause
-                if (!normalizedUpdateNames.contains(normalizedColName)) {
+                if (!normalizedKeyColumnNames.contains(normalizedColName)) {
                     if (fieldsFound.getAndIncrement() > 0) {
                         sqlBuilder.append(", ");
                     }
@@ -952,7 +1008,7 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
                 if (desc != null) {
 
                     // Check if this column is a Update Key. If so, add it to 
the WHERE clause
-                    if (normalizedUpdateNames.contains(normalizedColName)) {
+                    if (normalizedKeyColumnNames.contains(normalizedColName)) {
 
                         if (whereFieldCount.getAndIncrement() > 0) {
                             sqlBuilder.append(" AND ");
@@ -1045,6 +1101,66 @@ public class PutDatabaseRecord extends 
AbstractSessionFactoryProcessor {
         return new SqlAndIncludedColumns(sqlBuilder.toString(), 
includedColumns);
     }
 
+    private void checkValuesForRequiredColumns(RecordSchema recordSchema, 
TableSchema tableSchema, DMLSettings settings) {
+        final Set<String> normalizedFieldNames = 
getNormalizedColumnNames(recordSchema, settings.translateFieldNames);
+
+        for (final String requiredColName : 
tableSchema.getRequiredColumnNames()) {
+            final String normalizedColName = 
normalizeColumnName(requiredColName, settings.translateFieldNames);
+            if (!normalizedFieldNames.contains(normalizedColName)) {
+                String missingColMessage = "Record does not have a value for 
the Required column '" + requiredColName + "'";
+                if (settings.failUnmappedColumns) {
+                    getLogger().error(missingColMessage);
+                    throw new IllegalArgumentException(missingColMessage);
+                } else if (settings.warningUnmappedColumns) {
+                    getLogger().warn(missingColMessage);
+                }
+            }
+        }
+    }
+
+    private Set<String> getUpdateKeyColumnNames(String tableName, String 
updateKeys, TableSchema tableSchema) throws 
SQLIntegrityConstraintViolationException {
+        final Set<String> updateKeyColumnNames;
+
+        if (updateKeys == null) {
+            updateKeyColumnNames = tableSchema.getPrimaryKeyColumnNames();
+        } else {
+            updateKeyColumnNames = new HashSet<>();
+            for (final String updateKey : updateKeys.split(",")) {
+                updateKeyColumnNames.add(updateKey.trim());
+            }
+        }
+
+        if (updateKeyColumnNames.isEmpty()) {
+            throw new SQLIntegrityConstraintViolationException("Table '" + 
tableName + "' does not have a Primary Key and no Update Keys were specified");
+        }
+
+        return updateKeyColumnNames;
+    }
+
+    private Set<String> normalizeKeyColumnNamesAndCheckForValues(RecordSchema recordSchema, String updateKeys, DMLSettings settings, Set<String> updateKeyColumnNames) throws MalformedRecordException {
+        // Normalize every key column name and verify the record has a field for each; a missing key value fails the
+        // record when failUnmappedColumns is set, otherwise it is logged if warningUnmappedColumns is set.
+        final Set<String> normalizedRecordFieldNames = getNormalizedColumnNames(recordSchema, settings.translateFieldNames);
+
+        final Set<String> normalizedKeyColumnNames = new HashSet<>();
+        for (final String updateKeyColumnName : updateKeyColumnNames) {
+            final String normalizedKeyColumnName = normalizeColumnName(updateKeyColumnName, settings.translateFieldNames);
+            normalizedKeyColumnNames.add(normalizedKeyColumnName);
+
+            if (!normalizedRecordFieldNames.contains(normalizedKeyColumnName)) {
+                String missingColMessage = "Record does not have a value for the " + (updateKeys == null ? "Primary" : "Update") + " Key column '" + updateKeyColumnName + "'";
+                if (settings.failUnmappedColumns) {
+                    getLogger().error(missingColMessage);
+                    throw new MalformedRecordException(missingColMessage);
+                } else if (settings.warningUnmappedColumns) {
+                    getLogger().warn(missingColMessage);
+                }
+            }
+        }
+
+        return normalizedKeyColumnNames;
+    }
+
     private static String normalizeColumnName(final String colName, final 
boolean translateColumnNames) {
         return colName == null ? null : (translateColumnNames ? 
colName.toUpperCase().replace("_", "") : colName);
     }
diff --git 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java
 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java
index e1251c4..40de0b8 100644
--- 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java
+++ 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java
@@ -16,6 +16,9 @@
  */
 package org.apache.nifi.processors.standard.db;
 
+import java.util.Collection;
+import java.util.List;
+
 /**
  * Interface for RDBMS/JDBC-specific code.
  */
@@ -56,6 +59,30 @@ public interface DatabaseAdapter {
     }
 
     /**
+     * Tells whether this adapter supports UPSERT.
+     *
+     * @return true if UPSERT is supported, false otherwise
+     */
+    default boolean supportsUpsert() {
+        return false;
+    }
+
+    /**
+     * Returns an SQL UPSERT statement - i.e. UPDATE record or INSERT if id 
doesn't exist.
+     * <br /><br />
+     * There is no standard way of doing this so not all adapters support it - 
use together with {@link #supportsUpsert()}!
+     *
+     * @param table                     The name of the table in which to 
update/insert a record into.
+     * @param columnNames               The name of the columns in the table 
to add values to.
+     * @param uniqueKeyColumnNames      The name of the columns that form a 
unique key.
+     * @return                          A String containing the parameterized 
jdbc SQL statement.
+     *                                      The order and number of parameters 
are the same as that of the provided column list.
+     */
+    default String getUpsertStatement(String table, List<String> columnNames, 
Collection<String> uniqueKeyColumnNames) {
+        throw new UnsupportedOperationException("UPSERT is not supported for " 
+ getName());
+    }
+
+    /**
      * <p>Returns a bare identifier string by removing wrapping escape 
characters
      * from identifier strings such as table and column names.</p>
      * <p>The default implementation of this method removes double quotes.
diff --git 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java
 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java
new file mode 100644
index 0000000..03def34
--- /dev/null
+++ 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.processors.standard.db.impl;
+
+import com.google.common.base.Preconditions;
+import org.apache.nifi.util.StringUtils;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Database adapter for PostgreSQL. Adds UPSERT support via {@code INSERT ... ON CONFLICT ... DO UPDATE},
+ * which requires PostgreSQL 9.5 or later.
+ */
+public class PostgreSQLDatabaseAdapter extends GenericDatabaseAdapter {
+    @Override
+    public String getName() {
+        return "PostgreSQL";
+    }
+
+    @Override
+    public String getDescription() {
+        return "Generates PostgreSQL compatible SQL";
+    }
+
+    @Override
+    public boolean supportsUpsert() {
+        return true;
+    }
+
+    /**
+     * Builds a parameterized {@code INSERT ... ON CONFLICT (keys) DO UPDATE SET ...} statement.
+     * <p>
+     * Every column in {@code columnNames} is inserted; if the row conflicts on {@code uniqueKeyColumnNames},
+     * each column is overwritten with the value proposed for insertion ({@code EXCLUDED.<column>}).
+     * Names are used verbatim, so any identifier quoting must already have been applied by the caller.
+     *
+     * @param table                name of the target table
+     * @param columnNames          columns to insert/update; one {@code ?} placeholder is generated per column, in order
+     * @param uniqueKeyColumnNames columns forming the unique constraint used as the ON CONFLICT target
+     * @return the parameterized SQL statement
+     * @throws IllegalArgumentException if the table name is null/empty or either column collection is null/empty
+     */
+    @Override
+    public String getUpsertStatement(String table, List<String> columnNames, Collection<String> uniqueKeyColumnNames) {
+        Preconditions.checkArgument(!StringUtils.isEmpty(table), "Table name cannot be null or blank");
+        Preconditions.checkArgument(columnNames != null && !columnNames.isEmpty(), "Column names cannot be null or empty");
+        Preconditions.checkArgument(uniqueKeyColumnNames != null && !uniqueKeyColumnNames.isEmpty(), "Key column names cannot be null or empty");
+
+        String columns = String.join(", ", columnNames);
+
+        String parameterizedInsertValues = columnNames.stream()
+            .map(__ -> "?")
+            .collect(Collectors.joining(", "));
+
+        String updateValues = columnNames.stream()
+            .map(columnName -> "EXCLUDED." + columnName)
+            .collect(Collectors.joining(", "));
+
+        // PostgreSQL 10+ rejects the parenthesized form for a single column ("source for a multiple-column
+        // UPDATE item must be a sub-SELECT or ROW() expression"), so a lone column is assigned directly;
+        // "col = EXCLUDED.col" is valid on every supported version.
+        String updateAssignments = columnNames.size() == 1
+            ? columns + " = " + updateValues
+            : "(" + columns + ") = (" + updateValues + ")";
+
+        String conflictClause = "(" + String.join(", ", uniqueKeyColumnNames) + ")";
+
+        StringBuilder statementStringBuilder = new StringBuilder("INSERT INTO ")
+            .append(table)
+            .append("(").append(columns).append(")")
+            .append(" VALUES ")
+            .append("(").append(parameterizedInsertValues).append(")")
+            .append(" ON CONFLICT ")
+            .append(conflictClause)
+            .append(" DO UPDATE SET ")
+            .append(updateAssignments);
+
+        return statementStringBuilder.toString();
+    }
+}
diff --git 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter
 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter
index 2f53cf7..f104782 100644
--- 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter
+++ 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter
@@ -17,4 +17,5 @@ 
org.apache.nifi.processors.standard.db.impl.OracleDatabaseAdapter
 org.apache.nifi.processors.standard.db.impl.Oracle12DatabaseAdapter
 org.apache.nifi.processors.standard.db.impl.MSSQLDatabaseAdapter
 org.apache.nifi.processors.standard.db.impl.MSSQL2008DatabaseAdapter
-org.apache.nifi.processors.standard.db.impl.MySQLDatabaseAdapter
\ No newline at end of file
+org.apache.nifi.processors.standard.db.impl.MySQLDatabaseAdapter
+org.apache.nifi.processors.standard.db.impl.PostgreSQLDatabaseAdapter
\ No newline at end of file
diff --git 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java
 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java
new file mode 100644
index 0000000..15fc95a
--- /dev/null
+++ 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.processors.standard.db.impl;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class TestPostgreSQLDatabaseAdapter {
+    private PostgreSQLDatabaseAdapter testSubject;
+
+    @Before
+    public void setUp() throws Exception {
+        testSubject = new PostgreSQLDatabaseAdapter();
+    }
+
+    @Test
+    public void testSupportsUpsert() throws Exception {
+        assertTrue(testSubject.getClass().getSimpleName() + " should support 
upsert", testSubject.supportsUpsert());
+    }
+
+    @Test
+    public void testGetUpsertStatementWithNullTableName() throws Exception {
+        testGetUpsertStatement(null, Arrays.asList("notEmpty"), 
Arrays.asList("notEmpty"), new IllegalArgumentException("Table name cannot be 
null or blank"));
+    }
+
+    @Test
+    public void testGetUpsertStatementWithBlankTableName() throws Exception {
+        testGetUpsertStatement("", Arrays.asList("notEmpty"), 
Arrays.asList("notEmpty"), new IllegalArgumentException("Table name cannot be 
null or blank"));
+    }
+
+    @Test
+    public void testGetUpsertStatementWithNullColumnNames() throws Exception {
+        testGetUpsertStatement("notEmpty", null, Arrays.asList("notEmpty"), 
new IllegalArgumentException("Column names cannot be null or empty"));
+    }
+
+    @Test
+    public void testGetUpsertStatementWithEmptyColumnNames() throws Exception {
+        testGetUpsertStatement("notEmpty", Collections.emptyList(), 
Arrays.asList("notEmpty"), new IllegalArgumentException("Column names cannot be 
null or empty"));
+    }
+
+    @Test
+    public void testGetUpsertStatementWithNullKeyColumnNames() throws 
Exception {
+        testGetUpsertStatement("notEmpty", Arrays.asList("notEmpty"), null, 
new IllegalArgumentException("Key column names cannot be null or empty"));
+    }
+
+    @Test
+    public void testGetUpsertStatementWithEmptyKeyColumnNames() throws 
Exception {
+        testGetUpsertStatement("notEmpty", Arrays.asList("notEmpty"), 
Collections.emptyList(), new IllegalArgumentException("Key column names cannot 
be null or empty"));
+    }
+
+    @Test
+    public void testGetUpsertStatement() throws Exception {
+        // GIVEN
+        String tableName = "table";
+        List<String> columnNames = Arrays.asList("column1","column2", 
"column3", "column4");
+        Collection<String> uniqueKeyColumnNames = 
Arrays.asList("column2","column4");
+
+        String expected = "INSERT INTO" +
+            " table(column1, column2, column3, column4) VALUES (?, ?, ?, ?)" +
+            " ON CONFLICT (column2, column4)" +
+            " DO UPDATE SET" +
+            " (column1, column2, column3, column4) = (EXCLUDED.column1, 
EXCLUDED.column2, EXCLUDED.column3, EXCLUDED.column4)";
+
+        // WHEN
+        // THEN
+        testGetUpsertStatement(tableName, columnNames, uniqueKeyColumnNames, 
expected);
+    }
+
+    private void testGetUpsertStatement(String tableName, List<String> 
columnNames, Collection<String> uniqueKeyColumnNames, IllegalArgumentException 
expected) {
+        try {
+            testGetUpsertStatement(tableName, columnNames, 
uniqueKeyColumnNames, (String)null);
+            fail();
+        } catch (IllegalArgumentException e) {
+            assertEquals(expected.getMessage(), e.getMessage());
+        }
+    }
+
+    private void testGetUpsertStatement(String tableName, List<String> 
columnNames, Collection<String> uniqueKeyColumnNames, String expected) {
+        // WHEN
+        String actual = testSubject.getUpsertStatement(tableName, columnNames, 
uniqueKeyColumnNames);
+
+        // THEN
+        assertEquals(expected, actual);
+    }
+}

Reply via email to