This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 26212fcf745 Add encryptor check for insert select statement without columns (#28430)
26212fcf745 is described below
commit 26212fcf745f16e9c0ad5aba913fd11919156ff8
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Thu Sep 14 08:16:48 2023 +0800
Add encryptor check for insert select statement without columns (#28430)
---
.../EncryptInsertDefaultColumnsTokenGenerator.java | 13 ++++++++
.../fixture/EncryptGeneratorFixtureBuilder.java | 16 +++++++---
.../EncryptInsertCipherNameTokenGeneratorTest.java | 2 +-
...ryptInsertDefaultColumnsTokenGeneratorTest.java | 8 +++++
.../statement/dml/InsertStatementBinder.java | 22 ++++++++++++-
.../statement/InsertStatementBinderTest.java | 36 +++++++++++++++++-----
.../sql/common/statement/dml/InsertStatement.java | 2 ++
7 files changed, 85 insertions(+), 14 deletions(-)
diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
index f9e6b71f0c2..eccaed00007 100644
--- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
+++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
@@ -20,17 +20,23 @@ package
org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
import com.google.common.base.Preconditions;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
+import
org.apache.shardingsphere.encrypt.rewrite.token.util.EncryptTokenGeneratorUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.EncryptTable;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
+import
org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
import
org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import
org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
+import
org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
+import
org.apache.shardingsphere.infra.exception.core.external.sql.type.generic.UnsupportedSQLOperationException;
import
org.apache.shardingsphere.infra.rewrite.sql.token.generator.OptionalSQLTokenGenerator;
import
org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.PreviousSQLTokensAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
import
org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.UseDefaultInsertColumnsToken;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
+import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
@@ -80,6 +86,13 @@ public final class EncryptInsertDefaultColumnsTokenGenerator
implements Optional
private UseDefaultInsertColumnsToken generateNewSQLToken(final
InsertStatementContext insertStatementContext, final String tableName) {
Optional<InsertColumnsSegment> insertColumnsSegment =
insertStatementContext.getSqlStatement().getInsertColumns();
Preconditions.checkState(insertColumnsSegment.isPresent());
+ if (null != insertStatementContext.getInsertSelectContext()) {
+ Collection<ColumnSegment> derivedInsertColumns =
insertStatementContext.getSqlStatement().getDerivedInsertColumns();
+ Collection<Projection> projections =
insertStatementContext.getInsertSelectContext().getSelectStatementContext().getProjectionsContext().getExpandProjections();
+ ShardingSpherePreconditions.checkState(derivedInsertColumns.size()
== projections.size(), () -> new UnsupportedSQLOperationException("Column count
doesn't match value count."));
+
ShardingSpherePreconditions.checkState(EncryptTokenGeneratorUtils.isAllInsertSelectColumnsUseSameEncryptor(derivedInsertColumns,
projections, encryptRule),
+ () -> new UnsupportedSQLOperationException("Can not use
different encryptor in insert select columns"));
+ }
return new UseDefaultInsertColumnsToken(
insertColumnsSegment.get().getStopIndex(),
getColumnNames(insertStatementContext, encryptRule.getEncryptTable(tableName),
insertStatementContext.getColumnNames()));
}
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
index 9293e78ff0c..d6053fac6d5 100644
--- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
@@ -129,7 +129,7 @@ public final class EncryptGeneratorFixtureBuilder {
return result;
}
- private static InsertStatement createInsertSelectStatement() {
+ private static InsertStatement createInsertSelectStatement(final boolean
containsInsertColumns) {
InsertStatement result = new MySQLInsertStatement();
result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new
IdentifierValue("t_user"))));
ColumnSegment userIdColumn = new ColumnSegment(0, 0, new
IdentifierValue("user_id"));
@@ -138,8 +138,13 @@ public final class EncryptGeneratorFixtureBuilder {
ColumnSegment userNameColumn = new ColumnSegment(0, 0, new
IdentifierValue("user_name"));
userNameColumn.setColumnBoundedInfo(new ColumnSegmentBoundedInfo(new
IdentifierValue(DefaultDatabase.LOGIC_NAME), new
IdentifierValue(DefaultDatabase.LOGIC_NAME),
new IdentifierValue("t_user"), new
IdentifierValue("user_name")));
- InsertColumnsSegment insertColumnsSegment = new
InsertColumnsSegment(0, 0, Arrays.asList(userIdColumn, userNameColumn));
- result.setInsertColumns(insertColumnsSegment);
+ List<ColumnSegment> insertColumns = Arrays.asList(userIdColumn,
userNameColumn);
+ if (containsInsertColumns) {
+ result.setInsertColumns(new InsertColumnsSegment(0, 0,
insertColumns));
+ } else {
+ result.setInsertColumns(new InsertColumnsSegment(0, 0,
Collections.emptyList()));
+ result.getDerivedInsertColumns().addAll(insertColumns);
+ }
MySQLSelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0,
0, new IdentifierValue("t_user"))));
ProjectionsSegment projections = new ProjectionsSegment(0, 0);
@@ -220,10 +225,11 @@ public final class EncryptGeneratorFixtureBuilder {
* Create insert select statement context.
*
* @param params parameters
+ * @param containsInsertColumns contains insert columns
* @return created insert select statement context
*/
- public static InsertStatementContext
createInsertSelectStatementContext(final List<Object> params) {
- InsertStatement insertStatement = createInsertSelectStatement();
+ public static InsertStatementContext
createInsertSelectStatementContext(final List<Object> params, final boolean
containsInsertColumns) {
+ InsertStatement insertStatement =
createInsertSelectStatement(containsInsertColumns);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class,
RETURNS_DEEP_STUBS);
ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(schema);
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
index b2e6abc8def..f3242b9b9ff 100644
--- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
@@ -58,6 +58,6 @@ class EncryptInsertCipherNameTokenGeneratorTest {
@Test
void
assertGenerateSQLTokensWhenInsertColumnsUseDifferentEncryptorWithSelectProjection()
{
- assertThrows(UnsupportedSQLOperationException.class, () ->
generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList())).size());
+ assertThrows(UnsupportedSQLOperationException.class, () ->
generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList(),
true)));
}
}
diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java
index 617597f6ff3..320fdeaf8bc 100644
--- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java
+++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGeneratorTest.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
import
org.apache.shardingsphere.encrypt.rewrite.token.generator.fixture.EncryptGeneratorFixtureBuilder;
+import
org.apache.shardingsphere.infra.exception.core.external.sql.type.generic.UnsupportedSQLOperationException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -26,6 +27,7 @@ import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
class EncryptInsertDefaultColumnsTokenGeneratorTest {
@@ -54,4 +56,10 @@ class EncryptInsertDefaultColumnsTokenGeneratorTest {
assertThat(generator.generateSQLToken(EncryptGeneratorFixtureBuilder.createInsertStatementContext(Collections.emptyList())).toString(),
is("(id, name, status, pwd_cipher, pwd_assist, pwd_like)"));
}
+
+ @Test
+ void
assertGenerateSQLTokensWhenInsertColumnsUseDifferentEncryptorWithSelectProjection()
{
+
generator.setPreviousSQLTokens(EncryptGeneratorFixtureBuilder.getPreviousSQLTokens());
+ assertThrows(UnsupportedSQLOperationException.class, () ->
generator.generateSQLToken(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList(),
false)));
+ }
}
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
index aeb08a73358..08719ac64bf 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
@@ -25,11 +25,16 @@ import
org.apache.shardingsphere.infra.binder.segment.from.impl.SimpleTableSegme
import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
import
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
+import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
+import java.util.LinkedList;
import java.util.Map;
/**
@@ -50,7 +55,12 @@ public final class InsertStatementBinder implements
SQLStatementBinder<InsertSta
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
Map<String, TableSegmentBinderContext> tableBinderContexts = new
LinkedHashMap<>();
result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(),
statementBinderContext, tableBinderContexts));
- sqlStatement.getInsertColumns().ifPresent(optional ->
result.setInsertColumns(InsertColumnsSegmentBinder.bind(optional,
statementBinderContext, tableBinderContexts)));
+ if (sqlStatement.getInsertColumns().isPresent() &&
!sqlStatement.getInsertColumns().get().getColumns().isEmpty()) {
+
result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(),
statementBinderContext, tableBinderContexts));
+ } else {
+
sqlStatement.getInsertColumns().ifPresent(result::setInsertColumns);
+ tableBinderContexts.values().forEach(each ->
result.getDerivedInsertColumns().addAll(getVisibleColumns(each.getProjectionSegments())));
+ }
sqlStatement.getInsertSelect().ifPresent(optional ->
result.setInsertSelect(SubquerySegmentBinder.bind(optional,
statementBinderContext, tableBinderContexts)));
result.getValues().addAll(sqlStatement.getValues());
InsertStatementHandler.getOnDuplicateKeyColumnsSegment(sqlStatement).ifPresent(optional
-> InsertStatementHandler.setOnDuplicateKeyColumnsSegment(result, optional));
@@ -63,4 +73,14 @@ public final class InsertStatementBinder implements
SQLStatementBinder<InsertSta
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
return result;
}
+
+ private Collection<ColumnSegment> getVisibleColumns(final
Collection<ProjectionSegment> projectionSegments) {
+ Collection<ColumnSegment> result = new LinkedList<>();
+ for (ProjectionSegment each : projectionSegments) {
+ if (each instanceof ColumnProjectionSegment && each.isVisible()) {
+ result.add(((ColumnProjectionSegment) each).getColumn());
+ }
+ }
+ return result;
+ }
}
diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
index 540ed393c0b..035cb63b2b3 100644
--- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
+++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
@@ -65,13 +65,13 @@ class InsertStatementBinderTest {
InsertStatement actual = new
InsertStatementBinder().bind(insertStatement, createMetaData(),
DefaultDatabase.LOGIC_NAME);
assertThat(actual, not(insertStatement));
assertThat(actual.getTable().getTableName(),
not(insertStatement.getTable().getTableName()));
- assertInsertColumns(actual);
+ assertTrue(actual.getInsertColumns().isPresent());
+ assertInsertColumns(actual.getInsertColumns().get().getColumns());
}
- private static void assertInsertColumns(final InsertStatement actual) {
- assertTrue(actual.getInsertColumns().isPresent());
- assertThat(actual.getInsertColumns().get().getColumns().size(), is(3));
- Iterator<ColumnSegment> iterator =
actual.getInsertColumns().get().getColumns().iterator();
+ private static void assertInsertColumns(final Collection<ColumnSegment>
insertColumns) {
+ assertThat(insertColumns.size(), is(3));
+ Iterator<ColumnSegment> iterator = insertColumns.iterator();
ColumnSegment orderIdColumnSegment = iterator.next();
assertThat(orderIdColumnSegment.getColumnBoundedInfo().getOriginalDatabase().getValue(),
is(DefaultDatabase.LOGIC_NAME));
assertThat(orderIdColumnSegment.getColumnBoundedInfo().getOriginalSchema().getValue(),
is(DefaultDatabase.LOGIC_NAME));
@@ -90,7 +90,7 @@ class InsertStatementBinderTest {
}
@Test
- void assertBindInsertSelect() {
+ void assertBindInsertSelectWithColumns() {
InsertStatement insertStatement = new MySQLInsertStatement();
insertStatement.setTable(new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("t_order"))));
insertStatement.setInsertColumns(new InsertColumnsSegment(0, 0,
Arrays.asList(new ColumnSegment(0, 0, new IdentifierValue("order_id")),
@@ -108,7 +108,29 @@ class InsertStatementBinderTest {
InsertStatement actual = new
InsertStatementBinder().bind(insertStatement, createMetaData(),
DefaultDatabase.LOGIC_NAME);
assertThat(actual, not(insertStatement));
assertThat(actual.getTable().getTableName(),
not(insertStatement.getTable().getTableName()));
- assertInsertColumns(actual);
+ assertTrue(actual.getInsertColumns().isPresent());
+ assertInsertColumns(actual.getInsertColumns().get().getColumns());
+ assertInsertSelect(actual);
+ }
+
+ @Test
+ void assertBindInsertSelectWithoutColumns() {
+ InsertStatement insertStatement = new MySQLInsertStatement();
+ insertStatement.setTable(new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+ MySQLSelectStatement subSelectStatement = new MySQLSelectStatement();
+ subSelectStatement.setFrom(new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+ ProjectionsSegment projections = new ProjectionsSegment(0, 0);
+ projections.getProjections().add(new ColumnProjectionSegment(new
ColumnSegment(0, 0, new IdentifierValue("order_id"))));
+ projections.getProjections().add(new ColumnProjectionSegment(new
ColumnSegment(0, 0, new IdentifierValue("user_id"))));
+ projections.getProjections().add(new ColumnProjectionSegment(new
ColumnSegment(0, 0, new IdentifierValue("status"))));
+ subSelectStatement.setProjections(projections);
+ insertStatement.setInsertSelect(new SubquerySegment(0, 0,
subSelectStatement));
+ insertStatement.getValues().add(new InsertValuesSegment(0, 0,
Arrays.asList(new LiteralExpressionSegment(0, 0, 1),
+ new LiteralExpressionSegment(0, 0, 1), new
LiteralExpressionSegment(0, 0, "OK"))));
+ InsertStatement actual = new
InsertStatementBinder().bind(insertStatement, createMetaData(),
DefaultDatabase.LOGIC_NAME);
+ assertThat(actual, not(insertStatement));
+ assertThat(actual.getTable().getTableName(),
not(insertStatement.getTable().getTableName()));
+ assertInsertColumns(actual.getDerivedInsertColumns());
assertInsertSelect(actual);
}
diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java
index 4db2c6e26cb..98b3fa54ed8 100644
--- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java
+++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/InsertStatement.java
@@ -46,6 +46,8 @@ public abstract class InsertStatement extends
AbstractSQLStatement implements DM
private final Collection<InsertValuesSegment> values = new LinkedList<>();
+ private final Collection<ColumnSegment> derivedInsertColumns = new
LinkedList<>();
+
/**
* Get insert columns segment.
*