This is an automated email from the ASF dual-hosted git repository.
duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 51b3b256d9d Enhance insert statement sql binder and add insert select
encrypt check when contains insert columns (#28427)
51b3b256d9d is described below
commit 51b3b256d9dd6bc3aae3c7974d19b3abcbb2c4c2
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Wed Sep 13 16:49:14 2023 +0800
Enhance insert statement sql binder and add insert select encrypt check
when contains insert columns (#28427)
* Enhance insert statement sql binder and add insert select encrypt check
when contains insert columns
* fix unit test
* fix unit test
* fix unit test
* fix unit test
---
.../EncryptInsertCipherNameTokenGenerator.java | 12 +-
.../token/util/EncryptTokenGeneratorUtils.java | 30 ++++
.../fixture/EncryptGeneratorFixtureBuilder.java | 46 +++++++
.../EncryptInsertCipherNameTokenGeneratorTest.java | 7 +
.../checker/ShardingRouteCacheableCheckerTest.java | 9 ++
.../select/projection/impl/ColumnProjection.java | 5 +-
.../infra/binder/enums/SegmentType.java | 2 +-
.../segment/column/InsertColumnsSegmentBinder.java | 54 ++++++++
.../expression/impl/ColumnSegmentBinder.java | 17 ++-
.../statement/dml/InsertStatementBinder.java | 10 +-
.../statement/SQLStatementContextFactoryTest.java | 7 +-
.../statement/InsertStatementBinderTest.java | 153 +++++++++++++++++++++
.../generic/bounded/ColumnSegmentBoundedInfo.java | 6 +-
.../bind/OpenGaussComBatchBindExecutorTest.java | 6 +
...egatedBatchedStatementsCommandExecutorTest.java | 7 +
.../PostgreSQLBatchedStatementsExecutorTest.java | 13 +-
.../parse/PostgreSQLComParseExecutorTest.java | 9 +-
17 files changed, 372 insertions(+), 21 deletions(-)
diff --git
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java
index 0ead81a6c3c..621d0b60b72 100644
---
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java
+++
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java
@@ -20,12 +20,15 @@ package
org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
import com.google.common.base.Preconditions;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
+import
org.apache.shardingsphere.encrypt.rewrite.token.util.EncryptTokenGeneratorUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.EncryptTable;
import
org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
import
org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import
org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import
org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
+import
org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
+import
org.apache.shardingsphere.infra.exception.core.external.sql.type.generic.UnsupportedSQLOperationException;
import
org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
import
org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
@@ -58,9 +61,16 @@ public final class EncryptInsertCipherNameTokenGenerator
implements CollectionSQ
public Collection<SQLToken> generateSQLTokens(final InsertStatementContext
insertStatementContext) {
Optional<InsertColumnsSegment> insertColumnsSegment =
insertStatementContext.getSqlStatement().getInsertColumns();
Preconditions.checkState(insertColumnsSegment.isPresent());
+ Collection<ColumnSegment> insertColumns =
insertColumnsSegment.get().getColumns();
+ if (null != insertStatementContext.getInsertSelectContext()) {
+ Collection<Projection> projections =
insertStatementContext.getInsertSelectContext().getSelectStatementContext().getProjectionsContext().getExpandProjections();
+ ShardingSpherePreconditions.checkState(insertColumns.size() ==
projections.size(), () -> new UnsupportedSQLOperationException("Column count
doesn't match value count."));
+
ShardingSpherePreconditions.checkState(EncryptTokenGeneratorUtils.isAllInsertSelectColumnsUseSameEncryptor(insertColumns,
projections, encryptRule),
+ () -> new UnsupportedSQLOperationException("Can not use
different encryptor in insert select columns"));
+ }
EncryptTable encryptTable =
encryptRule.getEncryptTable(insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue());
Collection<SQLToken> result = new LinkedList<>();
- for (ColumnSegment each : insertColumnsSegment.get().getColumns()) {
+ for (ColumnSegment each : insertColumns) {
String columnName = each.getIdentifier().getValue();
if (encryptTable.isEncryptColumn(columnName)) {
Collection<Projection> projections =
diff --git
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
index e191301196c..f3ad417f780 100644
---
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
+++
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
@@ -23,11 +23,15 @@ import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.EncryptTable;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.spi.EncryptAlgorithm;
+import
org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
+import
org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
+import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import java.util.Collection;
+import java.util.Iterator;
/**
* Encrypt token generator utils.
@@ -97,4 +101,37 @@ public final class EncryptTokenGeneratorUtils {
}
return encryptColumn.getCipher().getEncryptor();
}
+
+ /**
+ * Judge whether all insert select columns use same encryptor or not.
+ *
+ * <p>Insert columns and select projections are compared pairwise in iteration order. Collections of different sizes
+ * cannot be paired, so {@code false} is returned for them instead of failing with {@link java.util.NoSuchElementException}.</p>
+ *
+ * @param insertColumns insert columns
+ * @param projections projections
+ * @param encryptRule encrypt rule
+ * @return whether all insert select columns use same encryptor or not
+ */
+ public static boolean isAllInsertSelectColumnsUseSameEncryptor(final Collection<ColumnSegment> insertColumns, final Collection<Projection> projections, final EncryptRule encryptRule) {
+ if (insertColumns.size() != projections.size()) {
+ return false;
+ }
+ Iterator<ColumnSegment> insertColumnsIterator = insertColumns.iterator();
+ Iterator<Projection> projectionIterator = projections.iterator();
+ while (insertColumnsIterator.hasNext()) {
+ ColumnSegment columnSegment = insertColumnsIterator.next();
+ EncryptAlgorithm<?, ?> leftColumnEncryptor = getColumnEncryptor(columnSegment.getColumnBoundedInfo(), encryptRule);
+ Projection projection = projectionIterator.next();
+ // Column projections are resolved through their original table and column; other projection types only by their column label.
+ ColumnSegmentBoundedInfo columnBoundedInfo = projection instanceof ColumnProjection
+ ? new ColumnSegmentBoundedInfo(null, null, ((ColumnProjection) projection).getOriginalTable(), ((ColumnProjection) projection).getOriginalColumn())
+ : new ColumnSegmentBoundedInfo(new IdentifierValue(projection.getColumnLabel()));
+ EncryptAlgorithm<?, ?> rightColumnEncryptor = getColumnEncryptor(columnBoundedInfo, encryptRule);
+ if (!isSameEncryptor(leftColumnEncryptor, rightColumnEncryptor)) {
+ return false;
+ }
+ }
+ return true;
+ }
}
diff --git
a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
index 22339c8b13b..9293e78ff0c 100644
---
a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
+++
b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/fixture/EncryptGeneratorFixtureBuilder.java
@@ -48,6 +48,9 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOp
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
@@ -55,6 +58,7 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.Tab
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLInsertStatement;
+import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLUpdateStatement;
import java.util.ArrayList;
@@ -125,6 +129,30 @@ public final class EncryptGeneratorFixtureBuilder {
return result;
}
+ private static InsertStatement createInsertSelectStatement() { // INSERT INTO t_user (user_id, user_name) SELECT user_id, status FROM t_user
+ InsertStatement result = new MySQLInsertStatement();
+ result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new
IdentifierValue("t_user"))));
+ ColumnSegment userIdColumn = new ColumnSegment(0, 0, new
IdentifierValue("user_id"));
+ userIdColumn.setColumnBoundedInfo(new ColumnSegmentBoundedInfo(new
IdentifierValue(DefaultDatabase.LOGIC_NAME), new
IdentifierValue(DefaultDatabase.LOGIC_NAME), new IdentifierValue("t_user"),
+ new IdentifierValue("user_id")));
+ ColumnSegment userNameColumn = new ColumnSegment(0, 0, new
IdentifierValue("user_name"));
+ userNameColumn.setColumnBoundedInfo(new ColumnSegmentBoundedInfo(new
IdentifierValue(DefaultDatabase.LOGIC_NAME), new
IdentifierValue(DefaultDatabase.LOGIC_NAME),
+ new IdentifierValue("t_user"), new
IdentifierValue("user_name")));
+ InsertColumnsSegment insertColumnsSegment = new
InsertColumnsSegment(0, 0, Arrays.asList(userIdColumn, userNameColumn));
+ result.setInsertColumns(insertColumnsSegment);
+ MySQLSelectStatement selectStatement = new MySQLSelectStatement();
+ selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0,
0, new IdentifierValue("t_user"))));
+ ProjectionsSegment projections = new ProjectionsSegment(0, 0);
+ projections.getProjections().add(new
ColumnProjectionSegment(userIdColumn));
+ ColumnSegment statusColumn = new ColumnSegment(0, 0, new
IdentifierValue("status"));
+ statusColumn.setColumnBoundedInfo(new ColumnSegmentBoundedInfo(new
IdentifierValue(DefaultDatabase.LOGIC_NAME), new
IdentifierValue(DefaultDatabase.LOGIC_NAME), new IdentifierValue("t_user"),
+ new IdentifierValue("status")));
+ projections.getProjections().add(new
ColumnProjectionSegment(statusColumn));
+ selectStatement.setProjections(projections);
+ result.setInsertSelect(new SubquerySegment(0, 0, selectStatement));
+ return result;
+ }
+
/**
* Create update statement context.
*
@@ -187,4 +215,22 @@ public final class EncryptGeneratorFixtureBuilder {
when(result.getJoinConditions()).thenReturn(Collections.singleton(new
BinaryOperationExpression(0, 0, leftColumn, rightColumn, "=", "")));
return result;
}
+
+ /**
+ * Create insert select statement context whose select projections (user_id, status) are paired with insert columns (user_id, user_name).
+ *
+ * @param params parameters
+ * @return created insert select statement context
+ */
+ public static InsertStatementContext
createInsertSelectStatementContext(final List<Object> params) {
+ InsertStatement insertStatement = createInsertSelectStatement();
+ ShardingSphereDatabase database = mock(ShardingSphereDatabase.class,
RETURNS_DEEP_STUBS);
+ ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(schema);
+
when(schema.getAllColumnNames("t_user")).thenReturn(Arrays.asList("user_id",
"user_name", "pwd"));
+ ShardingSphereMetaData metaData = new ShardingSphereMetaData(
+ Collections.singletonMap(DefaultDatabase.LOGIC_NAME,
database), mock(ResourceMetaData.class),
+ mock(RuleMetaData.class), mock(ConfigurationProperties.class));
+ return new InsertStatementContext(metaData, params, insertStatement,
DefaultDatabase.LOGIC_NAME);
+ }
}
diff --git
a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
index c1a18ef9a84..b2e6abc8def 100644
---
a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
+++
b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGeneratorTest.java
@@ -19,6 +19,7 @@ package
org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
import
org.apache.shardingsphere.encrypt.rewrite.token.generator.fixture.EncryptGeneratorFixtureBuilder;
import
org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
+import
org.apache.shardingsphere.infra.exception.core.external.sql.type.generic.UnsupportedSQLOperationException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -27,6 +28,7 @@ import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
@@ -53,4 +55,10 @@ class EncryptInsertCipherNameTokenGeneratorTest {
void assertGenerateSQLTokensWithInsertStatementContext() {
assertThat(generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createInsertStatementContext(Collections.emptyList())).size(), is(1));
}
+
+ @Test
+ void assertGenerateSQLTokensWhenInsertColumnsUseDifferentEncryptorWithSelectProjection() {
+ // The fixture inserts into (user_id, user_name) from a select of (user_id, status), which token generation must reject.
+ assertThrows(UnsupportedSQLOperationException.class, () -> generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createInsertSelectStatementContext(Collections.emptyList())));
+ }
}
diff --git
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/checker/ShardingRouteCacheableCheckerTest.java
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/checker/ShardingRouteCacheableCheckerTest.java
index 77e9b4d2047..e3a30f754ca 100644
---
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/checker/ShardingRouteCacheableCheckerTest.java
+++
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/cache/checker/ShardingRouteCacheableCheckerTest.java
@@ -122,6 +122,15 @@ class ShardingRouteCacheableCheckerTest {
new ShardingSphereColumn("warehouse_id", Types.INTEGER, false,
false, false, true, false, false),
new ShardingSphereColumn("order_broadcast_table_id",
Types.INTEGER, true, false, false, true, false, false)),
Collections.emptyList(), Collections.emptyList()));
+ schema.getTables().put("t_non_sharding_table", new
ShardingSphereTable("t_non_sharding_table", Collections.singleton(
+ new ShardingSphereColumn("id", Types.INTEGER, false, false,
false, true, false, false)),
+ Collections.emptyList(), Collections.emptyList()));
+ schema.getTables().put("t_non_cacheable_database_sharding", new
ShardingSphereTable("t_non_cacheable_database_sharding", Collections.singleton(
+ new ShardingSphereColumn("id", Types.INTEGER, false, false,
false, true, false, false)),
+ Collections.emptyList(), Collections.emptyList()));
+ schema.getTables().put("t_non_cacheable_table_sharding", new
ShardingSphereTable("t_non_cacheable_table_sharding", Collections.singleton(
+ new ShardingSphereColumn("id", Types.INTEGER, false, false,
false, true, false, false)),
+ Collections.emptyList(), Collections.emptyList()));
return new ShardingSphereDatabase(DATABASE_NAME,
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"),
new ResourceMetaData(DATABASE_NAME, Collections.emptyMap()),
new RuleMetaData(Arrays.asList(shardingRule, timestampServiceRule)),
Collections.singletonMap(SCHEMA_NAME, schema));
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/ColumnProjection.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/ColumnProjection.java
index e6118e17afa..996a0c0d55e 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/ColumnProjection.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/ColumnProjection.java
@@ -17,6 +17,7 @@
package
org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl;
+import com.google.common.base.Strings;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
@@ -93,7 +94,8 @@ public final class ColumnProjection implements Projection {
* @return original table
*/
public IdentifierValue getOriginalTable() {
- if (null == originalTable) {
+ // A null or empty original table both mean "unknown": fall back to the owner, or to an empty identifier when there is no owner.
+ if (null == originalTable || Strings.isNullOrEmpty(originalTable.getValue())) {
return null == owner ? new IdentifierValue("") : owner;
}
return originalTable;
@@ -105,6 +107,7 @@ public final class ColumnProjection implements Projection {
* @return original column
*/
public IdentifierValue getOriginalColumn() {
- return null == originalColumn ? name : originalColumn;
+ // A null or empty original column both mean "unknown": fall back to name.
+ return null == originalColumn || Strings.isNullOrEmpty(originalColumn.getValue()) ? name : originalColumn;
}
}
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
index b647a2b7e36..0324ed740d3 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
@@ -22,5 +22,5 @@ package org.apache.shardingsphere.infra.binder.enums;
*/
public enum SegmentType {
- PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK,
SET_ASSIGNMENT, VALUES
+ PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK,
SET_ASSIGNMENT, VALUES, INSERT_COLUMNS // INSERT_COLUMNS: passed by InsertColumnsSegmentBinder when binding an insert statement's column list
}
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/column/InsertColumnsSegmentBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/column/InsertColumnsSegmentBinder.java
new file mode 100644
index 00000000000..35254713c4e
--- /dev/null
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/column/InsertColumnsSegmentBinder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.binder.segment.column;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.apache.shardingsphere.infra.binder.enums.SegmentType;
+import
org.apache.shardingsphere.infra.binder.segment.expression.impl.ColumnSegmentBinder;
+import
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
+import
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.Map;
+
+/**
+ * Insert columns segment binder, binding each column of an insert statement's column list with {@link SegmentType#INSERT_COLUMNS}.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public final class InsertColumnsSegmentBinder {
+
+ /**
+ * Bind insert columns segment with metadata.
+ *
+ * @param segment insert columns segment
+ * @param statementBinderContext statement binder context
+ * @param tableBinderContexts table binder contexts passed to {@link ColumnSegmentBinder} for every column
+ * @return new insert columns segment with the original start and stop index, holding the bound columns
+ */
+ public static InsertColumnsSegment bind(final InsertColumnsSegment
segment, final SQLStatementBinderContext statementBinderContext,
+ final Map<String,
TableSegmentBinderContext> tableBinderContexts) {
+ Collection<ColumnSegment> boundedColumns = new LinkedList<>();
+ segment.getColumns().forEach(each ->
boundedColumns.add(ColumnSegmentBinder.bind(each, SegmentType.INSERT_COLUMNS,
statementBinderContext, tableBinderContexts, Collections.emptyMap()))); // NOTE(review): emptyMap() is presumably the outer table binder contexts, which insert columns should not need; confirm against ColumnSegmentBinder.bind
+ return new InsertColumnsSegment(segment.getStartIndex(),
segment.getStopIndex(), boundedColumns);
+ }
+}
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
index 531320b6ef8..2f44a93b03d 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.infra.binder.segment.expression.impl;
+import com.google.common.base.Strings;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.groovy.util.Maps;
@@ -55,7 +56,7 @@ public final class ColumnSegmentBinder {
"LOCALTIMESTAMP", "UID", "USER", "NEXTVAL", "ROWID"));
private static final Map<SegmentType, String> SEGMENT_TYPE_MESSAGES =
Maps.of(SegmentType.PROJECTION, "field list", SegmentType.JOIN_ON, "on clause",
SegmentType.JOIN_USING, "from clause",
- SegmentType.PREDICATE, "where clause", SegmentType.ORDER_BY,
"order clause", SegmentType.GROUP_BY, "group statement");
+ SegmentType.PREDICATE, "where clause", SegmentType.ORDER_BY,
"order clause", SegmentType.GROUP_BY, "group statement",
SegmentType.INSERT_COLUMNS, "field list"); // INSERT_COLUMNS reuses the "field list" wording of PROJECTION
private static final String UNKNOWN_SEGMENT_TYPE_MESSAGE = "unknown
clause";
@@ -204,12 +205,15 @@ public final class ColumnSegmentBinder {
private static ColumnSegmentBoundedInfo createColumnSegmentBoundedInfo(final ColumnSegment segment, final ColumnSegment inputColumnSegment) {
IdentifierValue originalDatabase = null == inputColumnSegment ? null : inputColumnSegment.getColumnBoundedInfo().getOriginalDatabase();
IdentifierValue originalSchema = null == inputColumnSegment ? null : inputColumnSegment.getColumnBoundedInfo().getOriginalSchema();
- IdentifierValue originalTable =
- null == segment.getColumnBoundedInfo().getOriginalTable() ? Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundedInfo().getOriginalTable()).orElse(null)
- : segment.getColumnBoundedInfo().getOriginalTable();
- IdentifierValue originalColumn =
- null == segment.getColumnBoundedInfo().getOriginalColumn() ? Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundedInfo().getOriginalColumn()).orElse(null)
- : segment.getColumnBoundedInfo().getOriginalColumn();
+ // Prefer the segment's own original table/column; fall back to the input column's only when the segment's value is null or empty.
+ IdentifierValue segmentOriginalTable = segment.getColumnBoundedInfo().getOriginalTable();
+ IdentifierValue originalTable = null == segmentOriginalTable || Strings.isNullOrEmpty(segmentOriginalTable.getValue())
+ ? Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundedInfo().getOriginalTable()).orElse(segmentOriginalTable)
+ : segmentOriginalTable;
+ IdentifierValue segmentOriginalColumn = segment.getColumnBoundedInfo().getOriginalColumn();
+ IdentifierValue originalColumn = null == segmentOriginalColumn || Strings.isNullOrEmpty(segmentOriginalColumn.getValue())
+ ? Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundedInfo().getOriginalColumn()).orElse(segmentOriginalColumn)
+ : segmentOriginalColumn;
return new ColumnSegmentBoundedInfo(originalDatabase, originalSchema, originalTable, originalColumn);
}
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
index a0cf9de542c..aeb08a73358 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
@@ -18,8 +18,10 @@
package org.apache.shardingsphere.infra.binder.statement.dml;
import lombok.SneakyThrows;
+import
org.apache.shardingsphere.infra.binder.segment.column.InsertColumnsSegmentBinder;
import
org.apache.shardingsphere.infra.binder.segment.expression.impl.SubquerySegmentBinder;
import
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
+import
org.apache.shardingsphere.infra.binder.segment.from.impl.SimpleTableSegmentBinder;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
import
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
@@ -27,6 +29,7 @@ import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertState
import
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
import java.util.Collections;
+import java.util.LinkedHashMap;
import java.util.Map;
/**
@@ -43,11 +46,12 @@ public final class InsertStatementBinder implements
SQLStatementBinder<InsertSta
private InsertStatement bind(final InsertStatement sqlStatement, final
ShardingSphereMetaData metaData, final String defaultDatabaseName,
final Map<String, TableSegmentBinderContext>
externalTableBinderContexts) {
InsertStatement result =
sqlStatement.getClass().getDeclaredConstructor().newInstance();
- result.setTable(sqlStatement.getTable());
- sqlStatement.getInsertColumns().ifPresent(result::setInsertColumns);
SQLStatementBinderContext statementBinderContext = new
SQLStatementBinderContext(metaData, defaultDatabaseName,
sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
- sqlStatement.getInsertSelect().ifPresent(optional ->
result.setInsertSelect(SubquerySegmentBinder.bind(optional,
statementBinderContext, Collections.emptyMap())));
+ Map<String, TableSegmentBinderContext> tableBinderContexts = new
LinkedHashMap<>();
+ result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(),
statementBinderContext, tableBinderContexts));
+ sqlStatement.getInsertColumns().ifPresent(optional ->
result.setInsertColumns(InsertColumnsSegmentBinder.bind(optional,
statementBinderContext, tableBinderContexts)));
+ sqlStatement.getInsertSelect().ifPresent(optional ->
result.setInsertSelect(SubquerySegmentBinder.bind(optional,
statementBinderContext, tableBinderContexts)));
result.getValues().addAll(sqlStatement.getValues());
InsertStatementHandler.getOnDuplicateKeyColumnsSegment(sqlStatement).ifPresent(optional
-> InsertStatementHandler.setOnDuplicateKeyColumnsSegment(result, optional));
InsertStatementHandler.getSetAssignmentSegment(sqlStatement).ifPresent(optional
-> InsertStatementHandler.setSetAssignmentSegment(result, optional));
diff --git
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/SQLStatementContextFactoryTest.java
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/SQLStatementContextFactoryTest.java
index 30b8a396cca..2ff863375cc 100644
---
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/SQLStatementContextFactoryTest.java
+++
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/SQLStatementContextFactoryTest.java
@@ -150,7 +150,12 @@ class SQLStatementContextFactoryTest {
}
private ShardingSphereMetaData mockMetaData() {
- Map<String, ShardingSphereDatabase> databases =
Collections.singletonMap(DefaultDatabase.LOGIC_NAME,
mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS));
+ ShardingSphereDatabase database = mock(ShardingSphereDatabase.class,
RETURNS_DEEP_STUBS);
+
when(database.containsSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(true); // NOTE(review): schema/table existence stubs are presumably needed now that InsertStatementBinder binds the insert table via SimpleTableSegmentBinder; confirm
+ when(database.containsSchema("public")).thenReturn(true);
+
when(database.getSchema(DefaultDatabase.LOGIC_NAME).containsTable("tbl")).thenReturn(true);
+
when(database.getSchema("public").containsTable("tbl")).thenReturn(true);
+ Map<String, ShardingSphereDatabase> databases =
Collections.singletonMap(DefaultDatabase.LOGIC_NAME, database);
return new ShardingSphereMetaData(databases,
mock(ResourceMetaData.class), mock(RuleMetaData.class),
mock(ConfigurationProperties.class));
}
}
diff --git
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
new file mode 100644
index 00000000000..540ed393c0b
--- /dev/null
+++
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/InsertStatementBinderTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.binder.statement;
+
+import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementBinder;
+import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
+import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
+import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
+import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLInsertStatement;
+import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
+import org.junit.jupiter.api.Test;
+
+import java.sql.Types;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class InsertStatementBinderTest {
+    
+    @Test
+    void assertBindInsertValues() {
+        InsertStatement insertStatement = createInsertStatement();
+        insertStatement.getValues().add(new InsertValuesSegment(0, 0, Arrays.asList(new LiteralExpressionSegment(0, 0, 1),
+                new LiteralExpressionSegment(0, 0, 1), new LiteralExpressionSegment(0, 0, "OK"))));
+        InsertStatement actual = new InsertStatementBinder().bind(insertStatement, createMetaData(), DefaultDatabase.LOGIC_NAME);
+        assertThat(actual, not(insertStatement));
+        assertThat(actual.getTable().getTableName(), not(insertStatement.getTable().getTableName()));
+        assertInsertColumns(actual);
+    }
+    
+    @Test
+    void assertBindInsertSelect() {
+        InsertStatement insertStatement = createInsertStatement();
+        MySQLSelectStatement subSelectStatement = new MySQLSelectStatement();
+        subSelectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+        ProjectionsSegment projections = new ProjectionsSegment(0, 0);
+        projections.getProjections().add(new ColumnProjectionSegment(new ColumnSegment(0, 0, new IdentifierValue("order_id"))));
+        projections.getProjections().add(new ColumnProjectionSegment(new ColumnSegment(0, 0, new IdentifierValue("user_id"))));
+        projections.getProjections().add(new ColumnProjectionSegment(new ColumnSegment(0, 0, new IdentifierValue("status"))));
+        subSelectStatement.setProjections(projections);
+        // INSERT ... SELECT takes its rows from the sub query, so no VALUES segment is attached to this statement.
+        insertStatement.setInsertSelect(new SubquerySegment(0, 0, subSelectStatement));
+        InsertStatement actual = new InsertStatementBinder().bind(insertStatement, createMetaData(), DefaultDatabase.LOGIC_NAME);
+        assertThat(actual, not(insertStatement));
+        assertThat(actual.getTable().getTableName(), not(insertStatement.getTable().getTableName()));
+        assertInsertColumns(actual);
+        assertInsertSelect(actual);
+    }
+    
+    private InsertStatement createInsertStatement() {
+        InsertStatement result = new MySQLInsertStatement();
+        result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+        result.setInsertColumns(new InsertColumnsSegment(0, 0, Arrays.asList(new ColumnSegment(0, 0, new IdentifierValue("order_id")),
+                new ColumnSegment(0, 0, new IdentifierValue("user_id")), new ColumnSegment(0, 0, new IdentifierValue("status")))));
+        return result;
+    }
+    
+    private void assertInsertColumns(final InsertStatement actual) {
+        assertTrue(actual.getInsertColumns().isPresent());
+        Collection<ColumnSegment> actualColumns = actual.getInsertColumns().get().getColumns();
+        assertThat(actualColumns.size(), is(3));
+        Iterator<ColumnSegment> columnIterator = actualColumns.iterator();
+        assertColumnBoundedInfo(columnIterator.next(), "order_id");
+        assertColumnBoundedInfo(columnIterator.next(), "user_id");
+        assertColumnBoundedInfo(columnIterator.next(), "status");
+    }
+    
+    private void assertInsertSelect(final InsertStatement actual) {
+        assertTrue(actual.getInsertSelect().isPresent());
+        Collection<ProjectionSegment> actualProjections = actual.getInsertSelect().get().getSelect().getProjections().getProjections();
+        assertThat(actualProjections.size(), is(3));
+        Iterator<ProjectionSegment> projectionIterator = actualProjections.iterator();
+        assertColumnProjection(projectionIterator.next(), "order_id");
+        assertColumnProjection(projectionIterator.next(), "user_id");
+        assertColumnProjection(projectionIterator.next(), "status");
+    }
+    
+    private void assertColumnProjection(final ProjectionSegment actual, final String expectedColumnName) {
+        assertThat(actual, instanceOf(ColumnProjectionSegment.class));
+        assertColumnBoundedInfo(((ColumnProjectionSegment) actual).getColumn(), expectedColumnName);
+    }
+    
+    // Columns are expected to be bound to DefaultDatabase.LOGIC_NAME (as both database and schema) and to table t_order.
+    private void assertColumnBoundedInfo(final ColumnSegment actual, final String expectedColumnName) {
+        assertThat(actual.getColumnBoundedInfo().getOriginalDatabase().getValue(), is(DefaultDatabase.LOGIC_NAME));
+        assertThat(actual.getColumnBoundedInfo().getOriginalSchema().getValue(), is(DefaultDatabase.LOGIC_NAME));
+        assertThat(actual.getColumnBoundedInfo().getOriginalTable().getValue(), is("t_order"));
+        assertThat(actual.getColumnBoundedInfo().getOriginalColumn().getValue(), is(expectedColumnName));
+    }
+    
+    // Metadata exposing only table t_order (order_id, user_id, status) under database and schema DefaultDatabase.LOGIC_NAME.
+    private ShardingSphereMetaData createMetaData() {
+        ShardingSphereSchema schema = mock(ShardingSphereSchema.class, RETURNS_DEEP_STUBS);
+        when(schema.containsTable("t_order")).thenReturn(true);
+        when(schema.getTable("t_order").getColumnValues()).thenReturn(Arrays.asList(
+                new ShardingSphereColumn("order_id", Types.INTEGER, true, false, false, true, false, false),
+                new ShardingSphereColumn("user_id", Types.INTEGER, false, false, false, true, false, false),
+                new ShardingSphereColumn("status", Types.VARCHAR, false, false, false, true, false, false)));
+        ShardingSphereMetaData result = mock(ShardingSphereMetaData.class, RETURNS_DEEP_STUBS);
+        when(result.containsDatabase(DefaultDatabase.LOGIC_NAME)).thenReturn(true);
+        when(result.getDatabase(DefaultDatabase.LOGIC_NAME).containsSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(true);
+        when(result.getDatabase(DefaultDatabase.LOGIC_NAME).getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(schema);
+        return result;
+    }
+}
diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/bounded/ColumnSegmentBoundedInfo.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/bounded/ColumnSegmentBoundedInfo.java
index da5433edb52..b3232f8a2ec 100644
--- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/bounded/ColumnSegmentBoundedInfo.java
+++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/bounded/ColumnSegmentBoundedInfo.java
@@ -37,9 +37,15 @@ public final class ColumnSegmentBoundedInfo {
     private final IdentifierValue originalColumn;
     
+    /**
+     * Create column segment bounded info which only knows the original column.
+     * Original database, schema and table are set to empty identifiers instead of {@code null}.
+     *
+     * @param originalColumn original column
+     */
     public ColumnSegmentBoundedInfo(final IdentifierValue originalColumn) {
-        this.originalDatabase = null;
-        this.originalSchema = null;
-        this.originalTable = null;
+        this.originalDatabase = new IdentifierValue("");
+        this.originalSchema = new IdentifierValue("");
+        this.originalTable = new IdentifierValue("");
         this.originalColumn = originalColumn;
     }
 }
diff --git a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java
index f753fe7bce1..8ad1f6165ae 100644
--- a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java
+++ b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java
@@ -32,6 +32,7 @@ import org.apache.shardingsphere.infra.hint.HintValueContext;
import
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import
org.apache.shardingsphere.infra.metadata.database.resource.storage.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
+import
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.parser.ShardingSphereSQLParserEngine;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
@@ -56,6 +57,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
+import java.sql.Types;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
@@ -131,6 +133,7 @@ class OpenGaussComBatchBindExecutorTest {
new SQLTranslatorRule(new
DefaultSQLTranslatorRuleConfigurationBuilder().build()), new LoggingRule(new
DefaultLoggingRuleConfigurationBuilder().build()))));
ShardingSphereDatabase database = mockDatabase();
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db")).thenReturn(database);
+
when(result.getMetaDataContexts().getMetaData().containsDatabase("foo_db")).thenReturn(true);
return result;
}
@@ -141,6 +144,9 @@ class OpenGaussComBatchBindExecutorTest {
when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class,
"openGauss"));
when(result.getResourceMetaData().getStorageUnitMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("foo_ds",
storageUnit));
when(result.getRuleMetaData()).thenReturn(new
RuleMetaData(Collections.emptyList()));
+ when(result.containsSchema("public")).thenReturn(true);
+
when(result.getSchema("public").containsTable("bmsql")).thenReturn(true);
+
when(result.getSchema("public").getTable("bmsql").getColumnValues()).thenReturn(Collections.singleton(new
ShardingSphereColumn("id", Types.VARCHAR, false, false, false, true, false,
false)));
return result;
}
}
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java
index 25515d7db4a..979ab817ba8 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java
@@ -37,6 +37,7 @@ import org.apache.shardingsphere.infra.hint.HintValueContext;
import
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import
org.apache.shardingsphere.infra.metadata.database.resource.storage.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
+import
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.parser.ShardingSphereSQLParserEngine;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
@@ -61,6 +62,7 @@ import org.mockito.quality.Strictness;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
+import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -160,7 +162,12 @@ class
PostgreSQLAggregatedBatchedStatementsCommandExecutorTest {
when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class,
"PostgreSQL"));
when(database.getResourceMetaData().getStorageUnitMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("foo_ds",
storageUnit));
when(database.getRuleMetaData()).thenReturn(new
RuleMetaData(Collections.emptyList()));
+ when(database.containsSchema("public")).thenReturn(true);
+
when(database.getSchema("public").containsTable("t_order")).thenReturn(true);
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db")).thenReturn(database);
+
when(result.getMetaDataContexts().getMetaData().containsDatabase("foo_db")).thenReturn(true);
+
when(database.getSchema("public").getTable("t_order").getColumnValues())
+ .thenReturn(Collections.singleton(new
ShardingSphereColumn("id", Types.VARCHAR, false, false, false, true, false,
false)));
return result;
}
}
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java
index 60233dfdc7b..9a61edf1ef4 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java
@@ -30,6 +30,7 @@ import org.apache.shardingsphere.infra.hint.HintValueContext;
import
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import
org.apache.shardingsphere.infra.metadata.database.resource.storage.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
+import
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.logging.rule.LoggingRule;
@@ -39,6 +40,9 @@ import
org.apache.shardingsphere.proxy.backend.connector.ProxyDatabaseConnection
import
org.apache.shardingsphere.proxy.backend.connector.jdbc.statement.JDBCBackendStatement;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLInsertStatement;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import
org.apache.shardingsphere.sqltranslator.rule.builder.DefaultSQLTranslatorRuleConfigurationBuilder;
@@ -55,6 +59,7 @@ import org.mockito.quality.Strictness;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
+import java.sql.Types;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -113,9 +118,10 @@ class PostgreSQLBatchedStatementsExecutorTest {
private InsertStatementContext mockInsertStatementContext() {
PostgreSQLInsertStatement insertStatement =
mock(PostgreSQLInsertStatement.class, RETURNS_DEEP_STUBS);
-
when(insertStatement.getTable().getTableName().getIdentifier().getValue()).thenReturn("t");
+ when(insertStatement.getTable()).thenReturn(new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("t"))));
when(insertStatement.getValues()).thenReturn(Collections.emptyList());
when(insertStatement.getCommentSegments()).thenReturn(Collections.emptyList());
+
when(insertStatement.getDatabaseType()).thenReturn(TypedSPILoader.getService(DatabaseType.class,
"PostgreSQL"));
InsertStatementContext result = mock(InsertStatementContext.class);
when(result.getSqlStatement()).thenReturn(insertStatement);
return result;
@@ -132,6 +138,11 @@ class PostgreSQLBatchedStatementsExecutorTest {
when(database.getResourceMetaData().getStorageUnitMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("ds_0",
storageUnit));
when(database.getResourceMetaData().getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("ds_0"));
when(database.getRuleMetaData()).thenReturn(new
RuleMetaData(Collections.emptyList()));
+ when(database.containsSchema("public")).thenReturn(true);
+ when(database.getSchema("public").containsTable("t")).thenReturn(true);
+
when(database.getSchema("public").getTable("t").getColumnValues()).thenReturn(Arrays.asList(new
ShardingSphereColumn("id", Types.VARCHAR, false, false, false, true, false,
false),
+ new ShardingSphereColumn("col", Types.VARCHAR, false, false,
false, true, false, false)));
+
when(result.getMetaDataContexts().getMetaData().containsDatabase("db")).thenReturn(true);
when(result.getMetaDataContexts().getMetaData().getDatabase("db")).thenReturn(database);
RuleMetaData globalRuleMetaData = new RuleMetaData(Arrays.asList(new
SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()),
new LoggingRule(new
DefaultLoggingRuleConfigurationBuilder().build())));
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
index 56b62df345c..c93955d7c3f 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
@@ -182,11 +182,16 @@ class PostgreSQLComParseExecutorTest {
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getProtocolType()).thenReturn(TypedSPILoader.getService(DatabaseType.class,
"PostgreSQL"));
when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData())
.thenReturn(new RuleMetaData(Collections.singleton(new
SQLParserRule(new DefaultSQLParserRuleConfigurationBuilder().build()))));
- ShardingSphereTable table = new ShardingSphereTable("t_test",
Arrays.asList(new ShardingSphereColumn("id", Types.BIGINT, true, false, false,
false, true, false),
+ ShardingSphereTable testTable = new ShardingSphereTable("t_test",
Arrays.asList(new ShardingSphereColumn("id", Types.BIGINT, true, false, false,
false, true, false),
new ShardingSphereColumn("name", Types.VARCHAR, false, false,
false, false, false, false),
new ShardingSphereColumn("age", Types.SMALLINT, false, false,
false, false, true, false)), Collections.emptyList(), Collections.emptyList());
+ ShardingSphereTable sbTestTable = new ShardingSphereTable("sbtest1",
Arrays.asList(new ShardingSphereColumn("id", Types.BIGINT, true, false, false,
false, true, false),
+ new ShardingSphereColumn("k", Types.VARCHAR, false, false,
false, false, false, false),
+ new ShardingSphereColumn("c", Types.VARCHAR, false, false,
false, false, true, false),
+ new ShardingSphereColumn("pad", Types.VARCHAR, false, false,
false, false, true, false)), Collections.emptyList(), Collections.emptyList());
ShardingSphereSchema schema = new ShardingSphereSchema();
- schema.getTables().put("t_test", table);
+ schema.getTables().put("t_test", testTable);
+ schema.getTables().put("sbtest1", sbTestTable);
ShardingSphereDatabase database = new ShardingSphereDatabase("foo_db",
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"),
new ResourceMetaData("foo_db", Collections.emptyMap()), new
RuleMetaData(Collections.emptyList()), Collections.singletonMap("public",
schema));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db")).thenReturn(database);