This is an automated email from the ASF dual-hosted git repository.
zhaojinchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 3cbd8410aaa Refactor BaseDMLE2EIT and insert select statement parse logic (#28457)
3cbd8410aaa is described below
commit 3cbd8410aaa362a790a7dfb1bb479dd02e66ed4b
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Mon Sep 18 18:07:38 2023 +0800
Refactor BaseDMLE2EIT and insert select statement parse logic (#28457)
---
.../insert/EncryptInsertDefaultColumnsTokenGenerator.java | 2 +-
.../insert/keygen/engine/GeneratedKeyContextEngine.java | 2 +-
.../context/statement/dml/InsertStatementContext.java | 5 +++--
.../infra/binder/statement/dml/InsertStatementBinder.java | 3 ++-
.../context/statement/dml/InsertStatementContextTest.java | 6 ++++--
.../infra/rewrite/context/SQLRewriteContext.java | 12 +++++++++---
.../infra/rewrite/engine/RouteSQLRewriteEngineTest.java | 6 ++++++
.../mysql/visitor/statement/MySQLStatementVisitor.java | 5 +++--
.../visitor/statement/OpenGaussStatementVisitor.java | 14 +++++++-------
.../visitor/statement/type/OracleDMLStatementVisitor.java | 4 +---
.../visitor/statement/PostgreSQLStatementVisitor.java | 14 +++++++-------
.../test/e2e/engine/type/dml/BaseDMLE2EIT.java | 8 ++++++--
.../asserts/statement/dml/impl/InsertStatementAssert.java | 2 ++
test/it/parser/src/main/resources/case/dml/insert.xml | 10 +++++-----
test/it/parser/src/main/resources/case/dml/replace.xml | 8 ++++----
15 files changed, 61 insertions(+), 40 deletions(-)
diff --git
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
index eccaed00007..c8e8e8e4ea4 100644
---
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
+++
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
@@ -59,7 +59,7 @@ public final class EncryptInsertDefaultColumnsTokenGenerator
implements Optional
@Override
public UseDefaultInsertColumnsToken generateSQLToken(final
InsertStatementContext insertStatementContext) {
- String tableName =
insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
+ String tableName =
Optional.ofNullable(insertStatementContext.getSqlStatement().getTable()).map(optional
-> optional.getTableName().getIdentifier().getValue()).orElse("");
Optional<UseDefaultInsertColumnsToken> previousSQLToken =
findInsertColumnsToken();
if (previousSQLToken.isPresent()) {
processPreviousSQLToken(previousSQLToken.get(),
insertStatementContext, tableName);
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
index edcb1764881..fa7921f4cdc 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
@@ -52,7 +52,7 @@ public final class GeneratedKeyContextEngine {
* @return generate key context
*/
public Optional<GeneratedKeyContext> createGenerateKeyContext(final
List<String> insertColumnNames, final List<List<ExpressionSegment>>
valueExpressions, final List<Object> params) {
- String tableName =
insertStatement.getTable().getTableName().getIdentifier().getValue();
+ String tableName =
Optional.ofNullable(insertStatement.getTable()).map(optional ->
optional.getTableName().getIdentifier().getValue()).orElse("");
return findGenerateKeyColumn(tableName).map(optional ->
containsGenerateKey(insertColumnNames, optional)
? findGeneratedKey(insertColumnNames, valueExpressions,
params, optional)
: new GeneratedKeyContext(optional, true));
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
index fb5a29bb906..2099a80a8f2 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
@@ -94,7 +94,8 @@ public final class InsertStatementContext extends
CommonSQLStatementContext impl
onDuplicateKeyUpdateValueContext =
getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
tablesContext = new TablesContext(getAllSimpleTableSegments(),
getDatabaseType());
ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
- columnNames = containsInsertColumns() ? insertColumnNames :
schema.getVisibleColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue().toLowerCase());
+ columnNames = containsInsertColumns() ? insertColumnNames
+ : Optional.ofNullable(sqlStatement.getTable()).map(optional ->
schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement,
schema).createGenerateKeyContext(insertColumnNames,
getAllValueExpressions(sqlStatement), params).orElse(null);
}
@@ -166,7 +167,7 @@ public final class InsertStatementContext extends
CommonSQLStatementContext impl
for (InsertValueContext each : insertValueContexts) {
result.add(each.getParameters());
}
- if (null != insertSelectContext) {
+ if (null != insertSelectContext &&
!insertSelectContext.getParameters().isEmpty()) {
result.add(insertSelectContext.getParameters());
}
return result;
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
index 8d7f8eedf0a..efdce0773da 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
@@ -36,6 +36,7 @@ import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
+import java.util.Optional;
/**
* Select statement binder.
@@ -54,7 +55,7 @@ public final class InsertStatementBinder implements
SQLStatementBinder<InsertSta
SQLStatementBinderContext statementBinderContext = new
SQLStatementBinderContext(metaData, defaultDatabaseName,
sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
Map<String, TableSegmentBinderContext> tableBinderContexts = new
LinkedHashMap<>();
- result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(),
statementBinderContext, tableBinderContexts));
+ Optional.ofNullable(sqlStatement.getTable()).ifPresent(optional ->
result.setTable(SimpleTableSegmentBinder.bind(optional, statementBinderContext,
tableBinderContexts)));
if (sqlStatement.getInsertColumns().isPresent() &&
!sqlStatement.getInsertColumns().get().getColumns().isEmpty()) {
result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(),
statementBinderContext, tableBinderContexts));
} else {
diff --git
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
index 71787a68a5b..0d34002d2ec 100644
---
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
+++
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
@@ -24,6 +24,7 @@ import
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import
org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
+import
org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
@@ -161,15 +162,16 @@ class InsertStatementContextTest {
void assertInsertSelect() {
InsertStatement insertStatement = new MySQLInsertStatement();
SelectStatement selectStatement = new MySQLSelectStatement();
+ selectStatement.addParameterMarkerSegments(Collections.singleton(new
ParameterMarkerExpressionSegment(0, 0, 0, ParameterMarkerType.QUESTION)));
selectStatement.setProjections(new ProjectionsSegment(0, 0));
SubquerySegment insertSelect = new SubquerySegment(0, 0,
selectStatement);
insertStatement.setInsertSelect(insertSelect);
insertStatement.setTable(new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("tbl"))));
InsertStatementContext actual =
createInsertStatementContext(Collections.singletonList("param"),
insertStatement);
actual.setUpParameters(Collections.singletonList("param"));
- assertThat(actual.getInsertSelectContext().getParameterCount(), is(0));
+ assertThat(actual.getInsertSelectContext().getParameterCount(), is(1));
assertThat(actual.getGroupedParameters().size(), is(1));
- assertThat(actual.getGroupedParameters().iterator().next(),
is(Collections.emptyList()));
+ assertThat(actual.getGroupedParameters().iterator().next(),
is(Collections.singletonList("param")));
}
private void setUpInsertValues(final InsertStatement insertStatement) {
diff --git
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
index 3ea33d8edde..7625e227500 100644
---
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
+++
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
@@ -69,12 +69,18 @@ public final class SQLRewriteContext {
if (!hintValueContext.isSkipSQLRewrite()) {
addSQLTokenGenerators(new
DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
}
- parameterBuilder = sqlStatementContext instanceof
InsertStatementContext && null == ((InsertStatementContext)
sqlStatementContext).getInsertSelectContext()
- ? new GroupedParameterBuilder(
- ((InsertStatementContext)
sqlStatementContext).getGroupedParameters(), ((InsertStatementContext)
sqlStatementContext).getOnDuplicateKeyUpdateParameters())
+ parameterBuilder = containsInsertValues(sqlStatementContext)
+ ? new GroupedParameterBuilder(((InsertStatementContext)
sqlStatementContext).getGroupedParameters(), ((InsertStatementContext)
sqlStatementContext).getOnDuplicateKeyUpdateParameters())
: new StandardParameterBuilder(params);
}
+ private boolean containsInsertValues(final SQLStatementContext
sqlStatementContext) {
+ if (!(sqlStatementContext instanceof InsertStatementContext)) {
+ return false;
+ }
+ return null == ((InsertStatementContext)
sqlStatementContext).getInsertSelectContext();
+ }
+
/**
* Add SQL token generators.
*
diff --git
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
index d1b6b6be822..9cb4985b2be 100644
---
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
+++
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
@@ -89,7 +89,9 @@ class RouteSQLRewriteEngineTest {
void assertRewriteWithGroupedParameterBuilderForBroadcast() {
InsertStatementContext statementContext =
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable)
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
+ when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
+
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext,
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1),
mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"),
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
@@ -107,7 +109,9 @@ class RouteSQLRewriteEngineTest {
void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
InsertStatementContext statementContext =
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable)
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
+ when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
+
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext,
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1),
mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"),
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
@@ -127,7 +131,9 @@ class RouteSQLRewriteEngineTest {
void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() {
InsertStatementContext statementContext =
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable)
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
+ when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
+
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext,
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1),
mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"),
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
diff --git
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
index fccea89e6fe..ae5ce5b7a28 100644
---
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
+++
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
@@ -177,9 +177,9 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpres
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MatchAgainstExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.NotExpression;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.RowExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
@@ -1361,6 +1361,7 @@ public abstract class MySQLStatementVisitor extends
MySQLStatementBaseVisitor<AS
@Override
public ASTNode visitInsertSelectClause(final InsertSelectClauseContext
ctx) {
MySQLInsertStatement result = new MySQLInsertStatement();
+ result.setInsertSelect(createInsertSelectSegment(ctx));
if (null != ctx.LP_()) {
if (null != ctx.fields()) {
result.setInsertColumns(new
InsertColumnsSegment(ctx.LP_().getSymbol().getStartIndex(),
ctx.RP_().getSymbol().getStopIndex(), createInsertColumns(ctx.fields())));
@@ -1370,12 +1371,12 @@ public abstract class MySQLStatementVisitor extends
MySQLStatementBaseVisitor<AS
} else {
result.setInsertColumns(new
InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() -
1, Collections.emptyList()));
}
- result.setInsertSelect(createInsertSelectSegment(ctx));
return result;
}
private SubquerySegment createInsertSelectSegment(final
InsertSelectClauseContext ctx) {
MySQLSelectStatement selectStatement = (MySQLSelectStatement)
visit(ctx.select());
+
selectStatement.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
return new SubquerySegment(ctx.select().start.getStartIndex(),
ctx.select().stop.getStopIndex(), selectStatement);
}
diff --git
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
index a06f67b24ec..b185b4c3a98 100644
---
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
+++
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
@@ -751,6 +751,13 @@ public abstract class OpenGaussStatementVisitor extends
OpenGaussStatementBaseVi
@Override
public ASTNode visitInsertRest(final InsertRestContext ctx) {
OpenGaussInsertStatement result = new OpenGaussInsertStatement();
+ ValuesClauseContext valuesClause =
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
+ if (null == valuesClause) {
+ OpenGaussSelectStatement selectStatement =
(OpenGaussSelectStatement) visit(ctx.select());
+ result.setInsertSelect(new
SubquerySegment(ctx.select().start.getStartIndex(),
ctx.select().stop.getStopIndex(), selectStatement));
+ } else {
+
result.getValues().addAll(createInsertValuesSegments(valuesClause));
+ }
if (null == ctx.insertColumnList()) {
result.setInsertColumns(new
InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() -
1, Collections.emptyList()));
} else {
@@ -759,13 +766,6 @@ public abstract class OpenGaussStatementVisitor extends
OpenGaussStatementBaseVi
InsertColumnsSegment insertColumnsSegment = new
InsertColumnsSegment(insertColumns.start.getStartIndex() - 1,
insertColumns.stop.getStopIndex() + 1, columns.getValue());
result.setInsertColumns(insertColumnsSegment);
}
- ValuesClauseContext valuesClause =
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
- if (null == valuesClause) {
- OpenGaussSelectStatement selectStatement =
(OpenGaussSelectStatement) visit(ctx.select());
- result.setInsertSelect(new
SubquerySegment(ctx.select().start.getStartIndex(),
ctx.select().stop.getStopIndex(), selectStatement));
- } else {
-
result.getValues().addAll(createInsertValuesSegments(valuesClause));
- }
return result;
}
diff --git
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
index b35aa4e345c..39c6223ae87 100644
---
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
+++
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
@@ -326,6 +326,7 @@ public final class OracleDMLStatementVisitor extends
OracleStatementVisitor impl
@Override
public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) {
OracleInsertStatement result = new OracleInsertStatement();
+ result.setInsertSelect(new
SubquerySegment(ctx.selectSubquery().start.getStartIndex(),
ctx.selectSubquery().stop.getStopIndex(), (OracleSelectStatement)
visit(ctx.selectSubquery())));
result.setMultiTableInsertType(null != ctx.conditionalInsertClause()
&& null != ctx.conditionalInsertClause().FIRST() ? MultiTableInsertType.FIRST :
MultiTableInsertType.ALL);
List<MultiTableElementContext> multiTableElementContexts =
ctx.multiTableElement();
if (null != multiTableElementContexts &&
!multiTableElementContexts.isEmpty()) {
@@ -336,9 +337,6 @@ public final class OracleDMLStatementVisitor extends
OracleStatementVisitor impl
} else {
result.setMultiTableConditionalIntoSegment((MultiTableConditionalIntoSegment)
visit(ctx.conditionalInsertClause()));
}
- OracleSelectStatement subquery = (OracleSelectStatement)
visit(ctx.selectSubquery());
- SubquerySegment subquerySegment = new
SubquerySegment(ctx.selectSubquery().start.getStartIndex(),
ctx.selectSubquery().stop.getStopIndex(), subquery);
- result.setInsertSelect(subquerySegment);
result.addParameterMarkerSegments(getParameterMarkerSegments());
return result;
}
diff --git
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
index 4cdde01be62..57e5005b2f5 100644
---
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
+++
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
@@ -756,6 +756,13 @@ public abstract class PostgreSQLStatementVisitor extends
PostgreSQLStatementPars
@Override
public ASTNode visitInsertRest(final InsertRestContext ctx) {
PostgreSQLInsertStatement result = new PostgreSQLInsertStatement();
+ ValuesClauseContext valuesClause =
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
+ if (null == valuesClause) {
+ PostgreSQLSelectStatement selectStatement =
(PostgreSQLSelectStatement) visit(ctx.select());
+ result.setInsertSelect(new
SubquerySegment(ctx.select().start.getStartIndex(),
ctx.select().stop.getStopIndex(), selectStatement));
+ } else {
+
result.getValues().addAll(createInsertValuesSegments(valuesClause));
+ }
if (null == ctx.insertColumnList()) {
result.setInsertColumns(new
InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() -
1, Collections.emptyList()));
} else {
@@ -764,13 +771,6 @@ public abstract class PostgreSQLStatementVisitor extends
PostgreSQLStatementPars
InsertColumnsSegment insertColumnsSegment = new
InsertColumnsSegment(insertColumns.start.getStartIndex() - 1,
insertColumns.stop.getStopIndex() + 1, columns.getValue());
result.setInsertColumns(insertColumnsSegment);
}
- ValuesClauseContext valuesClause =
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
- if (null == valuesClause) {
- PostgreSQLSelectStatement selectStatement =
(PostgreSQLSelectStatement) visit(ctx.select());
- result.setInsertSelect(new
SubquerySegment(ctx.select().start.getStartIndex(),
ctx.select().stop.getStopIndex(), selectStatement));
- } else {
-
result.getValues().addAll(createInsertValuesSegments(valuesClause));
- }
return result;
}
diff --git
a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
index c1dc45a7633..e1a71a5cf6d 100644
---
a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
+++
b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
@@ -86,9 +86,13 @@ public abstract class BaseDMLE2EIT {
}
protected final void assertDataSet(final AssertionTestParameter testParam,
final SingleE2EContainerComposer containerComposer, final int
actualUpdateCount) throws SQLException {
- assertThat("Only support single table for DML.",
containerComposer.getDataSet().getMetaDataList().size(), is(1));
assertThat(actualUpdateCount,
is(containerComposer.getDataSet().getUpdateCount()));
- DataSetMetaData expectedDataSetMetaData =
containerComposer.getDataSet().getMetaDataList().get(0);
+ for (DataSetMetaData each :
containerComposer.getDataSet().getMetaDataList()) {
+ assertDataSet(testParam, containerComposer, each);
+ }
+ }
+
+ private void assertDataSet(final AssertionTestParameter testParam, final
SingleE2EContainerComposer containerComposer, final DataSetMetaData
expectedDataSetMetaData) throws SQLException {
for (String each :
InlineExpressionParserFactory.newInstance().splitAndEvaluate(expectedDataSetMetaData.getDataNodes()))
{
DataNode dataNode = new DataNode(each);
DataSource dataSource =
containerComposer.getActualDataSourceMap().get(dataNode.getDataSourceName());
diff --git
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
index 2cdb42d37a1..1df0c82214a 100644
---
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
+++
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
@@ -36,6 +36,7 @@ import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.ins
import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.insert.MultiTableInsertIntoClauseAssert;
import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.insert.OnDuplicateKeyColumnsAssert;
import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.output.OutputClauseAssert;
+import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.parameter.ParameterMarkerAssert;
import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.returning.ReturningClauseAssert;
import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.set.SetClauseAssert;
import
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.table.TableAssert;
@@ -119,6 +120,7 @@ public final class InsertStatementAssert {
assertFalse(actual.getInsertSelect().isPresent(),
assertContext.getText("Actual insert select segment should not exist."));
} else {
assertTrue(actual.getInsertSelect().isPresent(),
assertContext.getText("Actual insert select segment should exist."));
+ ParameterMarkerAssert.assertCount(assertContext,
actual.getInsertSelect().get().getSelect().getParameterCount(),
expected.getSelectTestCase().getParameters().size());
SelectStatementAssert.assertIs(assertContext,
actual.getInsertSelect().get().getSelect(), expected.getSelectTestCase());
}
}
diff --git a/test/it/parser/src/main/resources/case/dml/insert.xml
b/test/it/parser/src/main/resources/case/dml/insert.xml
index a7b8ca3c729..34c2b6b6107 100644
--- a/test/it/parser/src/main/resources/case/dml/insert.xml
+++ b/test/it/parser/src/main/resources/case/dml/insert.xml
@@ -1493,7 +1493,7 @@
<column name="user_id" start-index="31" stop-index="37" />
<column name="status" start-index="40" stop-index="45" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order" start-index="86" stop-index="92"
/>
</from>
@@ -1522,7 +1522,7 @@
<insert sql-case-id="insert_select_without_columns" parameters="100">
<table name="t_order" start-index="12" stop-index="18" />
<columns start-index="19" stop-index="19" />
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order" start-index="58" stop-index="64"
/>
</from>
@@ -1557,7 +1557,7 @@
<column name="status" start-index="53" stop-index="58" />
<column name="creation_date" start-index="61" stop-index="73" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order_item" start-index="139"
stop-index="150" />
</from>
@@ -1593,7 +1593,7 @@
<column name="status" start-index="44" stop-index="49" />
<column name="creation_date" start-index="52" stop-index="64" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order_item" start-index="121"
stop-index="132" />
</from>
@@ -1627,7 +1627,7 @@
<column name="user_id" start-index="30" stop-index="36" />
<column name="status" start-index="39" stop-index="44" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order" start-index="85" stop-index="91"
/>
</from>
diff --git a/test/it/parser/src/main/resources/case/dml/replace.xml
b/test/it/parser/src/main/resources/case/dml/replace.xml
index e3dccb901e7..415a09c21c3 100644
--- a/test/it/parser/src/main/resources/case/dml/replace.xml
+++ b/test/it/parser/src/main/resources/case/dml/replace.xml
@@ -867,7 +867,7 @@
<column name="user_id" start-index="32" stop-index="38" />
<column name="status" start-index="41" stop-index="46" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order" start-index="87" stop-index="93"
/>
</from>
@@ -896,7 +896,7 @@
<insert sql-case-id="replace_select_without_columns" parameters="100">
<table name="t_order" start-index="13" stop-index="19" />
<columns start-index="20" stop-index="20" />
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order" start-index="59" stop-index="65"
/>
</from>
@@ -931,7 +931,7 @@
<column name="status" start-index="54" stop-index="59" />
<column name="creation_date" start-index="62" stop-index="74" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order_item" start-index="140"
stop-index="151" />
</from>
@@ -967,7 +967,7 @@
<column name="status" start-index="45" stop-index="50" />
<column name="creation_date" start-index="53" stop-index="65" />
</columns>
- <select>
+ <select parameters="100">
<from>
<simple-table name="t_order_item" start-index="122"
stop-index="133" />
</from>