This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new c3a7984  fix encrypt on duplicate key update (#14309)
c3a7984 is described below

commit c3a7984756d1c9c822e3e2feaf174ee5360235e6
Author: cheese8 <[email protected]>
AuthorDate: Wed Jan 5 08:55:27 2022 +0800

    fix encrypt on duplicate key update (#14309)
    
    * fix Encrypt on duplicate key update
    
    * finish
    
    * fix ci
    
    * fix review
    
    * recover sharding ut
    
    * fix ci
    
    * improve & retrigger ci
    
    * improve & retrigger ci
    
    * improve & retrigger ci
    
    * refactor
    
    * add blank line
    
    * throw clear exceptions on not matched values
    
    * fix review
    
    * add EncryptFunctionAssignmentToken
    
    * remove empty lines
    
    * fixbug
    
    * fixbug
    
    * retrigger ci
---
 ...OnDuplicateKeyUpdateValueParameterRewriter.java |  4 ++
 .../impl/EncryptInsertOnUpdateTokenGenerator.java  | 57 +++++++++++++++-
 .../token/pojo/EncryptFunctionAssignmentToken.java | 76 ++++++++++++++++++++++
 .../insert/values/OnDuplicateUpdateContext.java    | 10 ++-
 .../resources/scenario/encrypt/case/insert.xml     | 12 +++-
 5 files changed, 153 insertions(+), 6 deletions(-)

diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java
index 504167e..47373e5 100644
--- a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java
@@ -27,6 +27,7 @@ import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
 
 import java.util.Collection;
@@ -59,6 +60,9 @@ public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter exten
             Optional<EncryptAlgorithm> encryptor = 
getEncryptRule().findEncryptor(schemaName, tableName, encryptLogicColumnName);
             encryptor.ifPresent(optional -> {
                 Object plainColumnValue = 
onDuplicateKeyUpdateValueContext.getValue(columnIndex);
+                if (plainColumnValue instanceof FunctionSegment && 
"VALUES".equalsIgnoreCase(((FunctionSegment) 
plainColumnValue).getFunctionName())) {
+                    return;
+                }
                 Object cipherColumnValue = 
encryptor.get().encrypt(plainColumnValue);
                 
groupedParameterBuilder.getGenericParameterBuilder().addReplacedParameters(columnIndex,
 cipherColumnValue);
                 Collection<Object> addedParameters = new LinkedList<>();
diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java
index 244b27f..f93ddbd 100644
--- a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java
@@ -20,13 +20,17 @@ package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;
 import com.google.common.base.Preconditions;
 import 
org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
 import 
org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptAssignmentToken;
+import 
org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptFunctionAssignmentToken;
 import 
org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptLiteralAssignmentToken;
 import 
org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptParameterAssignmentToken;
-import 
org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
+import org.apache.shardingsphere.infra.exception.ShardingSphereException;
+import 
org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
@@ -61,9 +65,18 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
         }
         String schemaName = insertStatementContext.getSchemaName();
         for (AssignmentSegment each : onDuplicateKeyColumnsSegments) {
-            if (getEncryptRule().findEncryptor(schemaName, tableName, 
each.getColumns().get(0).getIdentifier().getValue()).isPresent()) {
-                generateSQLToken(schemaName, tableName, 
each).ifPresent(result::add);
+            boolean leftEncryptorPresent = 
getEncryptRule().findEncryptor(schemaName, tableName, 
each.getColumns().get(0).getIdentifier().getValue()).isPresent();
+            if (each.getValue() instanceof FunctionSegment && 
"VALUES".equalsIgnoreCase(((FunctionSegment) 
each.getValue()).getFunctionName())) {
+                ColumnSegment rightColumn = (ColumnSegment) ((FunctionSegment) 
each.getValue()).getParameters().stream().findFirst().get();
+                boolean rightEncryptorPresent = 
getEncryptRule().findEncryptor(schemaName, tableName, 
rightColumn.getIdentifier().getValue()).isPresent();
+                if (!leftEncryptorPresent && !rightEncryptorPresent) {
+                    continue;
+                }
+            }
+            if (!leftEncryptorPresent) {
+                continue;
             }
+            generateSQLToken(schemaName, tableName, 
each).ifPresent(result::add);
         }
         return result;
     }
@@ -72,6 +85,9 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
         if (assignmentSegment.getValue() instanceof 
ParameterMarkerExpressionSegment) {
             return Optional.of(generateParameterSQLToken(tableName, 
assignmentSegment));
         }
+        if (assignmentSegment.getValue() instanceof FunctionSegment && 
"VALUES".equalsIgnoreCase(((FunctionSegment) 
assignmentSegment.getValue()).getFunctionName())) {
+            return Optional.of(generateValuesSQLToken(schemaName, tableName, 
assignmentSegment, (FunctionSegment) assignmentSegment.getValue()));
+        }
         if (assignmentSegment.getValue() instanceof LiteralExpressionSegment) {
             return Optional.of(generateLiteralSQLToken(schemaName, tableName, 
assignmentSegment));
         }
@@ -95,6 +111,41 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
         return result;
     }
     
+    private EncryptAssignmentToken generateValuesSQLToken(final String 
schemaName, final String tableName, final AssignmentSegment assignmentSegment, 
final FunctionSegment functionSegment) {
+        ColumnSegment columnSegment = assignmentSegment.getColumns().get(0);
+        String column = columnSegment.getIdentifier().getValue();
+        ColumnSegment valueColumnSegment = (ColumnSegment) 
functionSegment.getParameters().stream().findFirst().get();
+        String valueColumn = valueColumnSegment.getIdentifier().getValue();
+        EncryptFunctionAssignmentToken result = new 
EncryptFunctionAssignmentToken(columnSegment.getStartIndex(), 
assignmentSegment.getStopIndex());
+        boolean cipherColumnPresent = 
getEncryptRule().findEncryptor(schemaName, tableName, column).isPresent();
+        boolean cipherValueColumnPresent = 
getEncryptRule().findEncryptor(schemaName, tableName, valueColumn).isPresent();
+        if (cipherColumnPresent && cipherValueColumnPresent) {
+            String cipherColumn = getEncryptRule().getCipherColumn(tableName, 
column);
+            String cipherValueColumn = 
getEncryptRule().getCipherColumn(tableName, valueColumn);
+            result.addAssignment(cipherColumn, String.format("VALUES(%s)", 
cipherValueColumn));
+        } else if (cipherColumnPresent != cipherValueColumnPresent) {
+            throw new ShardingSphereException("The SQL clause `%s` is 
unsupported in encrypt rule.", String.format("%s=VALUES(%s)", column, 
valueColumn));
+        }
+        Optional<String> assistedQueryColumn = 
getEncryptRule().findAssistedQueryColumn(tableName, column);
+        Optional<String> valueAssistedQueryColumn = 
getEncryptRule().findAssistedQueryColumn(tableName, valueColumn);
+        if (assistedQueryColumn.isPresent() && 
valueAssistedQueryColumn.isPresent()) {
+            result.addAssignment(assistedQueryColumn.get(), 
String.format("VALUES(%s)", valueAssistedQueryColumn.get()));
+        } else if (assistedQueryColumn.isPresent() != 
valueAssistedQueryColumn.isPresent()) {
+            throw new ShardingSphereException("The SQL clause `%s` is 
unsupported in encrypt rule.", String.format("%s=VALUES(%s)", column, 
valueColumn));
+        }
+        Optional<String> plainColumn = 
getEncryptRule().findPlainColumn(tableName, column);
+        Optional<String> valuePlainColumn = 
getEncryptRule().findPlainColumn(tableName, valueColumn);
+        if (plainColumn.isPresent() && valuePlainColumn.isPresent()) {
+            result.addAssignment(plainColumn.get(), 
String.format("VALUES(%s)", valuePlainColumn.get()));
+        } else if (plainColumn.isPresent() != valuePlainColumn.isPresent()) {
+            throw new ShardingSphereException("The SQL clause `%s` is 
unsupported in encrypt rule.", String.format("%s=VALUES(%s)", column, 
valueColumn));
+        }
+        if (result.getAssignment().isEmpty()) {
+            throw new ShardingSphereException("The SQL clause `%s` is 
unsupported in encrypt rule.", String.format("%s=VALUES(%s)", column, 
valueColumn));
+        }
+        return result;
+    }
+    
     private void addCipherColumn(final String tableName, final String 
columnName, final EncryptParameterAssignmentToken token) {
         token.addColumnName(getEncryptRule().getCipherColumn(tableName, 
columnName));
     }
diff --git a/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptFunctionAssignmentToken.java b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptFunctionAssignmentToken.java
new file mode 100644
index 0000000..5268da0
--- /dev/null
+++ b/shardingsphere-features/shardingsphere-encrypt/shardingsphere-encrypt-core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptFunctionAssignmentToken.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.encrypt.rewrite.token.pojo;
+
+import com.google.common.base.Joiner;
+import lombok.RequiredArgsConstructor;
+
+import java.util.Collection;
+import java.util.LinkedList;
+
+/**
+ * Function assignment token for encrypt.
+ */
+public final class EncryptFunctionAssignmentToken extends 
EncryptAssignmentToken {
+    
+    private final Collection<FunctionAssignment> assignments = new 
LinkedList<>();
+    
+    public EncryptFunctionAssignmentToken(final int startIndex, final int 
stopIndex) {
+        super(startIndex, stopIndex);
+    }
+    
+    /**
+     * Add assignment.
+     *
+     * @param columnName column name
+     * @param value assignment value
+     */
+    public void addAssignment(final String columnName, final Object value) {
+        assignments.add(new FunctionAssignment(columnName, value));
+    }
+    
+    /**
+     * Get assignments.
+     * @return FunctionAssignment collection
+     */
+    public Collection<FunctionAssignment> getAssignment() {
+        return assignments;
+    }
+    
+    @Override
+    public String toString() {
+        return Joiner.on(", ").join(assignments);
+    }
+    
+    @RequiredArgsConstructor
+    private static final class FunctionAssignment {
+        
+        private final String columnName;
+        
+        private final Object value;
+        
+        @Override
+        public String toString() {
+            return String.format("%s = %s", columnName, toString(value));
+        }
+    
+        private String toString(final Object value) {
+            return String.class == value.getClass() ? String.format("%s", 
value) : value.toString();
+        }
+    }
+}
diff --git a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
index c0e50ec..1db5ada 100644
--- a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
+++ b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
@@ -23,6 +23,7 @@ import lombok.ToString;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtil;
@@ -79,8 +80,13 @@ public final class OnDuplicateUpdateContext {
      */
     public Object getValue(final int index) {
         ExpressionSegment valueExpression = valueExpressions.get(index);
-        return valueExpression instanceof ParameterMarkerExpressionSegment 
-                ? 
parameters.get(getParameterIndex((ParameterMarkerExpressionSegment) 
valueExpression)) : ((LiteralExpressionSegment) valueExpression).getLiterals();
+        if (valueExpression instanceof ParameterMarkerExpressionSegment) {
+            return 
parameters.get(getParameterIndex((ParameterMarkerExpressionSegment) 
valueExpression));
+        }
+        if (valueExpression instanceof FunctionSegment) {
+            return valueExpression;
+        }
+        return ((LiteralExpressionSegment) valueExpression).getLiterals();
     }
     
     private int getParameterIndex(final ParameterMarkerExpressionSegment 
parameterMarkerExpression) {
diff --git a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/insert.xml b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/insert.xml
index a1f21d4..bb79f1f 100644
--- a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/insert.xml
+++ b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/encrypt/case/insert.xml
@@ -21,7 +21,17 @@
         <input sql="INSERT INTO t_account(account_id, certificate_number, 
password, amount, status) VALUES (?, ?, ?, ?, ?), (2, '222X', 'bbb', 2000, 
'OK'), (?, ?, ?, ?, ?), (4, '444X', 'ddd', 4000, 'OK')" parameters="1, 111X, 
aaa, 1000, OK, 3, 333X, ccc, 3000, OK" />
         <output sql="INSERT INTO t_account(account_id, 
cipher_certificate_number, assisted_query_certificate_number, cipher_password, 
assisted_query_password, cipher_amount, status) VALUES (?, ?, ?, ?, ?, ?, ?), 
(2, 'encrypt_222X', 'assisted_query_222X', 'encrypt_bbb', 'assisted_query_bbb', 
'encrypt_2000', 'OK'), (?, ?, ?, ?, ?, ?, ?), (4, 'encrypt_444X', 
'assisted_query_444X', 'encrypt_ddd', 'assisted_query_ddd', 'encrypt_4000', 
'OK')" parameters="1, encrypt_111X, assisted_query_111X, e [...]
     </rewrite-assertion>
-    
+
+    <rewrite-assertion id="insert_values_on_duplicated_update_values" 
db-types="MySQL">
+        <input sql="INSERT INTO t_account(account_id, certificate_number, 
password, amount, status) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE 
certificate_number = VALUES(certificate_number)" parameters="1, 111X, aaa, 
1000, OK" />
+        <output sql="INSERT INTO t_account(account_id, 
cipher_certificate_number, assisted_query_certificate_number, cipher_password, 
assisted_query_password, cipher_amount, status) VALUES (?, ?, ?, ?, ?, ?, ?) ON 
DUPLICATE KEY UPDATE cipher_certificate_number = 
VALUES(cipher_certificate_number), assisted_query_certificate_number = 
VALUES(assisted_query_certificate_number)" parameters="1, encrypt_111X, 
assisted_query_111X, encrypt_aaa, assisted_query_aaa, encrypt_1000, OK" />
+    </rewrite-assertion>
+
+    <rewrite-assertion 
id="insert_values_on_duplicated_update_values_wrong_match" db-types="MySQL">
+        <input sql="INSERT INTO t_account(account_id, certificate_number, 
password, amount, status) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE status 
= VALUES(status)" parameters="1, 111X, aaa, 1000, OK" />
+        <output sql="INSERT INTO t_account(account_id, 
cipher_certificate_number, assisted_query_certificate_number, cipher_password, 
assisted_query_password, cipher_amount, status) VALUES (?, ?, ?, ?, ?, ?, ?) ON 
DUPLICATE KEY UPDATE status = VALUES(status)" parameters="1, encrypt_111X, 
assisted_query_111X, encrypt_aaa, assisted_query_aaa, encrypt_1000, OK" />
+    </rewrite-assertion>
+
     <rewrite-assertion id="insert_values_with_columns_for_literals" 
db-types="MySQL">
         <input sql="INSERT INTO t_account(account_id, certificate_number, 
password, amount, status) VALUES (1, '111X', 'aaa', 1000, 'OK'), (2, '222X', 
'bbb', 2000, 'OK'), (3, '333X', 'ccc', 3000, 'OK'), (4, '444X', 'ddd', 4000, 
'OK')" />
         <output sql="INSERT INTO t_account(account_id, 
cipher_certificate_number, assisted_query_certificate_number, cipher_password, 
assisted_query_password, cipher_amount, status) VALUES (1, 'encrypt_111X', 
'assisted_query_111X', 'encrypt_aaa', 'assisted_query_aaa', 'encrypt_1000', 
'OK'), (2, 'encrypt_222X', 'assisted_query_222X', 'encrypt_bbb', 
'assisted_query_bbb', 'encrypt_2000', 'OK'), (3, 'encrypt_333X', 
'assisted_query_333X', 'encrypt_ccc', 'assisted_query_ccc', 'encrypt_3000', ' 
[...]

Reply via email to