This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 26212fcf745 Add encryptor check for insert select statement without columns (#28430)
26212fcf745 is described below

commit 26212fcf745f16e9c0ad5aba913fd11919156ff8
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Thu Sep 14 08:16:48 2023 +0800

    Add encryptor check for insert select statement without columns (#28430)
---
 .../EncryptInsertDefaultColumnsTokenGenerator.java | 13 ++++++++
 .../fixture/EncryptGeneratorFixtureBuilder.java    | 16 +++++++---
 .../EncryptInsertCipherNameTokenGeneratorTest.java |  2 +-
 ...ryptInsertDefaultColumnsTokenGeneratorTest.java |  8 +++++
 .../statement/dml/InsertStatementBinder.java       | 22 ++++++++++++-
 .../statement/InsertStatementBinderTest.java       | 36 +++++++++++++++++-----
 .../sql/common/statement/dml/InsertStatement.java  |  2 ++
 7 files changed, 85 insertions(+), 14 deletions(-)

diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
index f9e6b71f0c2..eccaed00007 100644
--- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
+++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
@@ -20,17 +20,23 @@ package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
 import com.google.common.base.Preconditions;
 import lombok.Setter;
 import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
+import org.apache.shardingsphere.encrypt.rewrite.token.util.EncryptTokenGeneratorUtils;
 import org.apache.shardingsphere.encrypt.rule.EncryptRule;
 import org.apache.shardingsphere.encrypt.rule.EncryptTable;
 import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
+import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
 import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
+import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
+import org.apache.shardingsphere.infra.exception.core.external.sql.type.generic.UnsupportedSQLOperationException;
 import org.apache.shardingsphere.infra.rewrite.sql.token.generator.OptionalSQLTokenGenerator;
 import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.PreviousSQLTokensAware;
 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.UseDefaultInsertColumnsToken;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
 
+import java.util.Collection;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
@@ -80,6 +86,13 @@ public final class EncryptInsertDefaultColumnsTokenGenerator implements Optional
     private UseDefaultInsertColumnsToken generateNewSQLToken(final InsertStatementContext insertStatementContext, final String tableName) {
         Optional<InsertColumnsSegment> insertColumnsSegment = insertStatementContext.getSqlStatement().getInsertColumns();
         Preconditions.checkState(insertColumnsSegment.isPresent());
+        if (null != insertStatementContext.getInsertSelectContext()) {
+            Collection<ColumnSegment> derivedInsertColumns = insertStatementContext.getSqlStatement().getDerivedInsertColumns();
+            Collection<Projection> projections = insertStatementContext.getInsertSelectContext().getSelectStatementContext().getProjectionsContext().getExpandProjections();
+            ShardingSpherePreconditions.checkState(derivedInsertColumns.size() == projections.size(), () -> new UnsupportedSQLOperationException("Column count doesn't match value count."));
+            ShardingSpherePreconditions.checkState(EncryptTokenGeneratorUtils.isAllInsertSelectColumnsUseSameEncryptor(derivedInsertColumns, projections, encryptRule),
+                    () -> new UnsupportedSQLOperationException("Can not use different encryptor in insert select columns"));
+        }
         return new UseDefaultInsertColumnsToken(
                 insertColumnsSegment.get().getStopIndex(), getColumnNames(insertStatementContext, encryptRule.getEncryptTable(tableName), insertStatementContext.getColumnNames()));
     }
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
index 9293e78ff0c..d6053fac6d5 100644
--- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
@@ -129,7 +129,7 @@ public final class EncryptGeneratorFixtureBuilder {
         return result;
     }
     
-    private static InsertStatement createInsertSelectStatement() {
+    private static InsertStatement createInsertSelectStatement(final boolean 
containsInsertColumns) {
         InsertStatement result = new MySQLInsertStatement();
         result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new 
IdentifierValue("t_user"))));
         ColumnSegment userIdColumn = new ColumnSegment(0, 0, new 
IdentifierValue("user_id"));
@@ -138,8 +138,13 @@ public final class EncryptGeneratorFixtureBuilder {
         ColumnSegment userNameColumn = new ColumnSegment(0, 0, new 
IdentifierValue("user_name"));
         userNameColumn.setColumnBoundedInfo(new ColumnSegmentBoundedInfo(new 
IdentifierValue(DefaultDatabase.LOGIC_NAME), new 
IdentifierValue(DefaultDatabase.LOGIC_NAME),
                 new IdentifierValue("t_user"), new 
IdentifierValue("user_name")));
-        InsertColumnsSegment insertColumnsSegment = new 
InsertColumnsSegment(0, 0, Arrays.asList(userIdColumn, userNameColumn));
-        result.setInsertColumns(insertColumnsSegment);
+        List<ColumnSegment> insertColumns = Arrays.asList(userIdColumn, 
userNameColumn);
+        if (containsInsertColumns) {
+            result.setInsertColumns(new InsertColumnsSegment(0, 0, 
insertColumns));
+        } else {
+            result.setInsertColumns(new InsertColumnsSegment(0, 0, 
Collections.emptyList()));
+            result.getDerivedInsertColumns().addAll(insertColumns);
+        }
         MySQLSelectStatement selectStatement = new MySQLSelectStatement();
         selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 
0, new IdentifierValue("t_user"))));
         ProjectionsSegment projections = new ProjectionsSegment(0, 0);
@@ -220,10 +225,11 @@ public final class EncryptGeneratorFixtureBuilder {
      * Create insert select statement context.
      *
      * @param params parameters
+     * @param containsInsertColumns contains insert columns
      * @return created insert select statement context
      */
-    public static InsertStatementContext 
createInsertSelectStatementContext(final List<Object> params) {
-        InsertStatement insertStatement = createInsertSelectStatement();
+    public static InsertStatementContext 
createInsertSelectStatementContext(final List<Object> params, final boolean 
containsInsertColumns) {
+        InsertStatement insertStatement = 
createInsertSelectStatement(containsInsertColumns);
         ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, 
RETURNS_DEEP_STUBS);
         ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
         
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(schema);
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
index b2e6abc8def..f3242b9b9ff 100644
--- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
@@ -58,6 +58,6 @@ class EncryptInsertCipherNameTokenGeneratorTest {
     
     @Test
     void 
assertGenerateSQLTokensWhenInsertColumnsUseDifferentEncryptorWithSelectProjection()
 {
-        assertThrows(UnsupportedSQLOperationException.class, () -> 
generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList())).size());
+        assertThrows(UnsupportedSQLOperationException.class, () -> 
generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList(),
 true)));
     }
 }
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java
index 617597f6ff3..320fdeaf8bc 100644
--- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java
@@ -18,6 +18,7 @@
 package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
 
 import 
org.apache.shardingsphere.encrypt.rewrite.token.generator.fixture.EncryptGeneratorFixtureBuilder;
+import 
org.apache.shardingsphere.infra.exception.core.external.sql.type.generic.UnsupportedSQLOperationException;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
@@ -26,6 +27,7 @@ import java.util.Collections;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 class EncryptInsertDefaultColumnsTokenGeneratorTest {
     
@@ -54,4 +56,10 @@ class EncryptInsertDefaultColumnsTokenGeneratorTest {
         
assertThat(generator.generateSQLToken(EncryptGeneratorFixtureBuilder.createInsertStatementContext(Collections.emptyList())).toString(),
                 is("(id, name, status, pwd_cipher, pwd_assist, pwd_like)"));
     }
+    
+    @Test
+    void 
assertGenerateSQLTokensWhenInsertColumnsUseDifferentEncryptorWithSelectProjection()
 {
+        
generator.setPreviousSQLTokens(EncryptGeneratorFixtureBuilder.getPreviousSQLTokens());
+        assertThrows(UnsupportedSQLOperationException.class, () -> 
generator.generateSQLToken(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList(),
 false)));
+    }
 }
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
index aeb08a73358..08719ac64bf 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
@@ -25,11 +25,16 @@ import org.apache.shardingsphere.infra.binder.segment.from.impl.SimpleTableSegme
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.LinkedHashMap;
+import java.util.LinkedList;
 import java.util.Map;
 
 /**
 /**
@@ -50,7 +55,12 @@ public final class InsertStatementBinder implements SQLStatementBinder<InsertSta
         statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
         Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
         result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(), statementBinderContext, tableBinderContexts));
-        sqlStatement.getInsertColumns().ifPresent(optional -> result.setInsertColumns(InsertColumnsSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts)));
+        if (sqlStatement.getInsertColumns().isPresent() && !sqlStatement.getInsertColumns().get().getColumns().isEmpty()) {
+            result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(), statementBinderContext, tableBinderContexts));
+        } else {
+            sqlStatement.getInsertColumns().ifPresent(result::setInsertColumns);
+            tableBinderContexts.values().forEach(each -> result.getDerivedInsertColumns().addAll(getVisibleColumns(each.getProjectionSegments())));
+        }
         sqlStatement.getInsertSelect().ifPresent(optional -> result.setInsertSelect(SubquerySegmentBinder.bind(optional, statementBinderContext, tableBinderContexts)));
         result.getValues().addAll(sqlStatement.getValues());
         InsertStatementHandler.getOnDuplicateKeyColumnsSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setOnDuplicateKeyColumnsSegment(result, optional));
@@ -63,4 +73,14 @@ public final class InsertStatementBinder implements SQLStatementBinder<InsertSta
         result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
         return result;
     }
+    
+    private Collection<ColumnSegment> getVisibleColumns(final Collection<ProjectionSegment> projectionSegments) {
+        Collection<ColumnSegment> result = new LinkedList<>();
+        for (ProjectionSegment each : projectionSegments) {
+            if (each instanceof ColumnProjectionSegment && each.isVisible()) {
+                result.add(((ColumnProjectionSegment) each).getColumn());
+            }
+        }
+        return result;
+    }
 }
diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
index 540ed393c0b..035cb63b2b3 100644
--- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
+++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
@@ -65,13 +65,13 @@ class InsertStatementBinderTest {
         InsertStatement actual = new 
InsertStatementBinder().bind(insertStatement, createMetaData(), 
DefaultDatabase.LOGIC_NAME);
         assertThat(actual, not(insertStatement));
         assertThat(actual.getTable().getTableName(), 
not(insertStatement.getTable().getTableName()));
-        assertInsertColumns(actual);
+        assertTrue(actual.getInsertColumns().isPresent());
+        assertInsertColumns(actual.getInsertColumns().get().getColumns());
     }
     
-    private static void assertInsertColumns(final InsertStatement actual) {
-        assertTrue(actual.getInsertColumns().isPresent());
-        assertThat(actual.getInsertColumns().get().getColumns().size(), is(3));
-        Iterator<ColumnSegment> iterator = 
actual.getInsertColumns().get().getColumns().iterator();
+    private static void assertInsertColumns(final Collection<ColumnSegment> 
insertColumns) {
+        assertThat(insertColumns.size(), is(3));
+        Iterator<ColumnSegment> iterator = insertColumns.iterator();
         ColumnSegment orderIdColumnSegment = iterator.next();
         
assertThat(orderIdColumnSegment.getColumnBoundedInfo().getOriginalDatabase().getValue(),
 is(DefaultDatabase.LOGIC_NAME));
         
assertThat(orderIdColumnSegment.getColumnBoundedInfo().getOriginalSchema().getValue(),
 is(DefaultDatabase.LOGIC_NAME));
@@ -90,7 +90,7 @@ class InsertStatementBinderTest {
     }
     
     @Test
-    void assertBindInsertSelect() {
+    void assertBindInsertSelectWithColumns() {
         InsertStatement insertStatement = new MySQLInsertStatement();
         insertStatement.setTable(new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order"))));
         insertStatement.setInsertColumns(new InsertColumnsSegment(0, 0, 
Arrays.asList(new ColumnSegment(0, 0, new IdentifierValue("order_id")),
@@ -108,7 +108,29 @@ class InsertStatementBinderTest {
         InsertStatement actual = new 
InsertStatementBinder().bind(insertStatement, createMetaData(), 
DefaultDatabase.LOGIC_NAME);
         assertThat(actual, not(insertStatement));
         assertThat(actual.getTable().getTableName(), 
not(insertStatement.getTable().getTableName()));
-        assertInsertColumns(actual);
+        assertTrue(actual.getInsertColumns().isPresent());
+        assertInsertColumns(actual.getInsertColumns().get().getColumns());
+        assertInsertSelect(actual);
+    }
+    
+    @Test
+    void assertBindInsertSelectWithoutColumns() {
+        InsertStatement insertStatement = new MySQLInsertStatement();
+        insertStatement.setTable(new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+        MySQLSelectStatement subSelectStatement = new MySQLSelectStatement();
+        subSelectStatement.setFrom(new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+        ProjectionsSegment projections = new ProjectionsSegment(0, 0);
+        projections.getProjections().add(new ColumnProjectionSegment(new 
ColumnSegment(0, 0, new IdentifierValue("order_id"))));
+        projections.getProjections().add(new ColumnProjectionSegment(new 
ColumnSegment(0, 0, new IdentifierValue("user_id"))));
+        projections.getProjections().add(new ColumnProjectionSegment(new 
ColumnSegment(0, 0, new IdentifierValue("status"))));
+        subSelectStatement.setProjections(projections);
+        insertStatement.setInsertSelect(new SubquerySegment(0, 0, 
subSelectStatement));
+        insertStatement.getValues().add(new InsertValuesSegment(0, 0, 
Arrays.asList(new LiteralExpressionSegment(0, 0, 1),
+                new LiteralExpressionSegment(0, 0, 1), new 
LiteralExpressionSegment(0, 0, "OK"))));
+        InsertStatement actual = new 
InsertStatementBinder().bind(insertStatement, createMetaData(), 
DefaultDatabase.LOGIC_NAME);
+        assertThat(actual, not(insertStatement));
+        assertThat(actual.getTable().getTableName(), 
not(insertStatement.getTable().getTableName()));
+        assertInsertColumns(actual.getDerivedInsertColumns());
         assertInsertSelect(actual);
     }
     
diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java
index 4db2c6e26cb..98b3fa54ed8 100644
--- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java
+++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java
@@ -46,6 +46,8 @@ public abstract class InsertStatement extends AbstractSQLStatement implements DM
     
     private final Collection<InsertValuesSegment> values = new LinkedList<>();
     
+    private final Collection<ColumnSegment> derivedInsertColumns = new LinkedList<>();
+    
     /**
      * Get insert columns segment.
      * 

Reply via email to