This is an automated email from the ASF dual-hosted git repository.

zhaojinchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cbd8410aaa Refactor BaseDMLE2EIT and insert select statement parse logic (#28457)
3cbd8410aaa is described below

commit 3cbd8410aaa362a790a7dfb1bb479dd02e66ed4b
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Mon Sep 18 18:07:38 2023 +0800

    Refactor BaseDMLE2EIT and insert select statement parse logic (#28457)
---
 .../insert/EncryptInsertDefaultColumnsTokenGenerator.java  |  2 +-
 .../insert/keygen/engine/GeneratedKeyContextEngine.java    |  2 +-
 .../context/statement/dml/InsertStatementContext.java      |  5 +++--
 .../infra/binder/statement/dml/InsertStatementBinder.java  |  3 ++-
 .../context/statement/dml/InsertStatementContextTest.java  |  6 ++++--
 .../infra/rewrite/context/SQLRewriteContext.java           | 12 +++++++++---
 .../infra/rewrite/engine/RouteSQLRewriteEngineTest.java    |  6 ++++++
 .../mysql/visitor/statement/MySQLStatementVisitor.java     |  5 +++--
 .../visitor/statement/OpenGaussStatementVisitor.java       | 14 +++++++-------
 .../visitor/statement/type/OracleDMLStatementVisitor.java  |  4 +---
 .../visitor/statement/PostgreSQLStatementVisitor.java      | 14 +++++++-------
 .../test/e2e/engine/type/dml/BaseDMLE2EIT.java             |  8 ++++++--
 .../asserts/statement/dml/impl/InsertStatementAssert.java  |  2 ++
 test/it/parser/src/main/resources/case/dml/insert.xml      | 10 +++++-----
 test/it/parser/src/main/resources/case/dml/replace.xml     |  8 ++++----
 15 files changed, 61 insertions(+), 40 deletions(-)

diff --git 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
index eccaed00007..c8e8e8e4ea4 100644
--- 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
+++ 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java
@@ -59,7 +59,7 @@ public final class EncryptInsertDefaultColumnsTokenGenerator 
implements Optional
     
     @Override
     public UseDefaultInsertColumnsToken generateSQLToken(final 
InsertStatementContext insertStatementContext) {
-        String tableName = 
insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
+        String tableName = 
Optional.ofNullable(insertStatementContext.getSqlStatement().getTable()).map(optional
 -> optional.getTableName().getIdentifier().getValue()).orElse("");
         Optional<UseDefaultInsertColumnsToken> previousSQLToken = 
findInsertColumnsToken();
         if (previousSQLToken.isPresent()) {
             processPreviousSQLToken(previousSQLToken.get(), 
insertStatementContext, tableName);
diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
index edcb1764881..fa7921f4cdc 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java
@@ -52,7 +52,7 @@ public final class GeneratedKeyContextEngine {
      * @return generate key context
      */
     public Optional<GeneratedKeyContext> createGenerateKeyContext(final 
List<String> insertColumnNames, final List<List<ExpressionSegment>> 
valueExpressions, final List<Object> params) {
-        String tableName = 
insertStatement.getTable().getTableName().getIdentifier().getValue();
+        String tableName = 
Optional.ofNullable(insertStatement.getTable()).map(optional -> 
optional.getTableName().getIdentifier().getValue()).orElse("");
         return findGenerateKeyColumn(tableName).map(optional -> 
containsGenerateKey(insertColumnNames, optional)
                 ? findGeneratedKey(insertColumnNames, valueExpressions, 
params, optional)
                 : new GeneratedKeyContext(optional, true));
diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
index fb5a29bb906..2099a80a8f2 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java
@@ -94,7 +94,8 @@ public final class InsertStatementContext extends 
CommonSQLStatementContext impl
         onDuplicateKeyUpdateValueContext = 
getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
         tablesContext = new TablesContext(getAllSimpleTableSegments(), 
getDatabaseType());
         ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
-        columnNames = containsInsertColumns() ? insertColumnNames : 
schema.getVisibleColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue().toLowerCase());
+        columnNames = containsInsertColumns() ? insertColumnNames
+                : Optional.ofNullable(sqlStatement.getTable()).map(optional -> 
schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
         generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, 
schema).createGenerateKeyContext(insertColumnNames, 
getAllValueExpressions(sqlStatement), params).orElse(null);
     }
     
@@ -166,7 +167,7 @@ public final class InsertStatementContext extends 
CommonSQLStatementContext impl
         for (InsertValueContext each : insertValueContexts) {
             result.add(each.getParameters());
         }
-        if (null != insertSelectContext) {
+        if (null != insertSelectContext && 
!insertSelectContext.getParameters().isEmpty()) {
             result.add(insertSelectContext.getParameters());
         }
         return result;
diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
index 8d7f8eedf0a..efdce0773da 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java
@@ -36,6 +36,7 @@ import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.Map;
+import java.util.Optional;
 
 /**
  * Select statement binder.
@@ -54,7 +55,7 @@ public final class InsertStatementBinder implements 
SQLStatementBinder<InsertSta
         SQLStatementBinderContext statementBinderContext = new 
SQLStatementBinderContext(metaData, defaultDatabaseName, 
sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
         
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
         Map<String, TableSegmentBinderContext> tableBinderContexts = new 
LinkedHashMap<>();
-        result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(), 
statementBinderContext, tableBinderContexts));
+        Optional.ofNullable(sqlStatement.getTable()).ifPresent(optional -> 
result.setTable(SimpleTableSegmentBinder.bind(optional, statementBinderContext, 
tableBinderContexts)));
         if (sqlStatement.getInsertColumns().isPresent() && 
!sqlStatement.getInsertColumns().get().getColumns().isEmpty()) {
             
result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(),
 statementBinderContext, tableBinderContexts));
         } else {
diff --git 
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
 
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
index 71787a68a5b..0d34002d2ec 100644
--- 
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
+++ 
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java
@@ -24,6 +24,7 @@ import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
 import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
 import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
+import 
org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
@@ -161,15 +162,16 @@ class InsertStatementContextTest {
     void assertInsertSelect() {
         InsertStatement insertStatement = new MySQLInsertStatement();
         SelectStatement selectStatement = new MySQLSelectStatement();
+        selectStatement.addParameterMarkerSegments(Collections.singleton(new 
ParameterMarkerExpressionSegment(0, 0, 0, ParameterMarkerType.QUESTION)));
         selectStatement.setProjections(new ProjectionsSegment(0, 0));
         SubquerySegment insertSelect = new SubquerySegment(0, 0, 
selectStatement);
         insertStatement.setInsertSelect(insertSelect);
         insertStatement.setTable(new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("tbl"))));
         InsertStatementContext actual = 
createInsertStatementContext(Collections.singletonList("param"), 
insertStatement);
         actual.setUpParameters(Collections.singletonList("param"));
-        assertThat(actual.getInsertSelectContext().getParameterCount(), is(0));
+        assertThat(actual.getInsertSelectContext().getParameterCount(), is(1));
         assertThat(actual.getGroupedParameters().size(), is(1));
-        assertThat(actual.getGroupedParameters().iterator().next(), 
is(Collections.emptyList()));
+        assertThat(actual.getGroupedParameters().iterator().next(), 
is(Collections.singletonList("param")));
     }
     
     private void setUpInsertValues(final InsertStatement insertStatement) {
diff --git 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
index 3ea33d8edde..7625e227500 100644
--- 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
+++ 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
@@ -69,12 +69,18 @@ public final class SQLRewriteContext {
         if (!hintValueContext.isSkipSQLRewrite()) {
             addSQLTokenGenerators(new 
DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
         }
-        parameterBuilder = sqlStatementContext instanceof 
InsertStatementContext && null == ((InsertStatementContext) 
sqlStatementContext).getInsertSelectContext()
-                ? new GroupedParameterBuilder(
-                        ((InsertStatementContext) 
sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) 
sqlStatementContext).getOnDuplicateKeyUpdateParameters())
+        parameterBuilder = containsInsertValues(sqlStatementContext)
+                ? new GroupedParameterBuilder(((InsertStatementContext) 
sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) 
sqlStatementContext).getOnDuplicateKeyUpdateParameters())
                 : new StandardParameterBuilder(params);
     }
     
+    private boolean containsInsertValues(final SQLStatementContext 
sqlStatementContext) {
+        if (!(sqlStatementContext instanceof InsertStatementContext)) {
+            return false;
+        }
+        return null == ((InsertStatementContext) 
sqlStatementContext).getInsertSelectContext();
+    }
+    
     /**
      * Add SQL token generators.
      *
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
index d1b6b6be822..9cb4985b2be 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
@@ -89,7 +89,9 @@ class RouteSQLRewriteEngineTest {
     void assertRewriteWithGroupedParameterBuilderForBroadcast() {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
+        when(statementContext.getInsertSelectContext()).thenReturn(null);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
+        
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
         SQLRewriteContext sqlRewriteContext =
                 new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
@@ -107,7 +109,9 @@ class RouteSQLRewriteEngineTest {
     void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
+        when(statementContext.getInsertSelectContext()).thenReturn(null);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
+        
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
         SQLRewriteContext sqlRewriteContext =
                 new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
@@ -127,7 +131,9 @@ class RouteSQLRewriteEngineTest {
     void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
+        when(statementContext.getInsertSelectContext()).thenReturn(null);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
+        
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
         SQLRewriteContext sqlRewriteContext =
                 new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
diff --git 
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
 
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
index fccea89e6fe..ae5ce5b7a28 100644
--- 
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
+++ 
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
@@ -177,9 +177,9 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpres
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MatchAgainstExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.NotExpression;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.RowExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
@@ -1361,6 +1361,7 @@ public abstract class MySQLStatementVisitor extends 
MySQLStatementBaseVisitor<AS
     @Override
     public ASTNode visitInsertSelectClause(final InsertSelectClauseContext 
ctx) {
         MySQLInsertStatement result = new MySQLInsertStatement();
+        result.setInsertSelect(createInsertSelectSegment(ctx));
         if (null != ctx.LP_()) {
             if (null != ctx.fields()) {
                 result.setInsertColumns(new 
InsertColumnsSegment(ctx.LP_().getSymbol().getStartIndex(), 
ctx.RP_().getSymbol().getStopIndex(), createInsertColumns(ctx.fields())));
@@ -1370,12 +1371,12 @@ public abstract class MySQLStatementVisitor extends 
MySQLStatementBaseVisitor<AS
         } else {
             result.setInsertColumns(new 
InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 
1, Collections.emptyList()));
         }
-        result.setInsertSelect(createInsertSelectSegment(ctx));
         return result;
     }
     
     private SubquerySegment createInsertSelectSegment(final 
InsertSelectClauseContext ctx) {
         MySQLSelectStatement selectStatement = (MySQLSelectStatement) 
visit(ctx.select());
+        
selectStatement.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
         return new SubquerySegment(ctx.select().start.getStartIndex(), 
ctx.select().stop.getStopIndex(), selectStatement);
     }
     
diff --git 
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
 
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
index a06f67b24ec..b185b4c3a98 100644
--- 
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
+++ 
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
@@ -751,6 +751,13 @@ public abstract class OpenGaussStatementVisitor extends 
OpenGaussStatementBaseVi
     @Override
     public ASTNode visitInsertRest(final InsertRestContext ctx) {
         OpenGaussInsertStatement result = new OpenGaussInsertStatement();
+        ValuesClauseContext valuesClause = 
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
+        if (null == valuesClause) {
+            OpenGaussSelectStatement selectStatement = 
(OpenGaussSelectStatement) visit(ctx.select());
+            result.setInsertSelect(new 
SubquerySegment(ctx.select().start.getStartIndex(), 
ctx.select().stop.getStopIndex(), selectStatement));
+        } else {
+            
result.getValues().addAll(createInsertValuesSegments(valuesClause));
+        }
         if (null == ctx.insertColumnList()) {
             result.setInsertColumns(new 
InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 
1, Collections.emptyList()));
         } else {
@@ -759,13 +766,6 @@ public abstract class OpenGaussStatementVisitor extends 
OpenGaussStatementBaseVi
             InsertColumnsSegment insertColumnsSegment = new 
InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, 
insertColumns.stop.getStopIndex() + 1, columns.getValue());
             result.setInsertColumns(insertColumnsSegment);
         }
-        ValuesClauseContext valuesClause = 
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
-        if (null == valuesClause) {
-            OpenGaussSelectStatement selectStatement = 
(OpenGaussSelectStatement) visit(ctx.select());
-            result.setInsertSelect(new 
SubquerySegment(ctx.select().start.getStartIndex(), 
ctx.select().stop.getStopIndex(), selectStatement));
-        } else {
-            
result.getValues().addAll(createInsertValuesSegments(valuesClause));
-        }
         return result;
     }
     
diff --git 
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
 
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
index b35aa4e345c..39c6223ae87 100644
--- 
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
+++ 
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
@@ -326,6 +326,7 @@ public final class OracleDMLStatementVisitor extends 
OracleStatementVisitor impl
     @Override
     public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) {
         OracleInsertStatement result = new OracleInsertStatement();
+        result.setInsertSelect(new 
SubquerySegment(ctx.selectSubquery().start.getStartIndex(), 
ctx.selectSubquery().stop.getStopIndex(), (OracleSelectStatement) 
visit(ctx.selectSubquery())));
         result.setMultiTableInsertType(null != ctx.conditionalInsertClause() 
&& null != ctx.conditionalInsertClause().FIRST() ? MultiTableInsertType.FIRST : 
MultiTableInsertType.ALL);
         List<MultiTableElementContext> multiTableElementContexts = 
ctx.multiTableElement();
         if (null != multiTableElementContexts && 
!multiTableElementContexts.isEmpty()) {
@@ -336,9 +337,6 @@ public final class OracleDMLStatementVisitor extends 
OracleStatementVisitor impl
         } else {
             
result.setMultiTableConditionalIntoSegment((MultiTableConditionalIntoSegment) 
visit(ctx.conditionalInsertClause()));
         }
-        OracleSelectStatement subquery = (OracleSelectStatement) 
visit(ctx.selectSubquery());
-        SubquerySegment subquerySegment = new 
SubquerySegment(ctx.selectSubquery().start.getStartIndex(), 
ctx.selectSubquery().stop.getStopIndex(), subquery);
-        result.setInsertSelect(subquerySegment);
         result.addParameterMarkerSegments(getParameterMarkerSegments());
         return result;
     }
diff --git 
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
 
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
index 4cdde01be62..57e5005b2f5 100644
--- 
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
+++ 
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
@@ -756,6 +756,13 @@ public abstract class PostgreSQLStatementVisitor extends 
PostgreSQLStatementPars
     @Override
     public ASTNode visitInsertRest(final InsertRestContext ctx) {
         PostgreSQLInsertStatement result = new PostgreSQLInsertStatement();
+        ValuesClauseContext valuesClause = 
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
+        if (null == valuesClause) {
+            PostgreSQLSelectStatement selectStatement = 
(PostgreSQLSelectStatement) visit(ctx.select());
+            result.setInsertSelect(new 
SubquerySegment(ctx.select().start.getStartIndex(), 
ctx.select().stop.getStopIndex(), selectStatement));
+        } else {
+            
result.getValues().addAll(createInsertValuesSegments(valuesClause));
+        }
         if (null == ctx.insertColumnList()) {
             result.setInsertColumns(new 
InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 
1, Collections.emptyList()));
         } else {
@@ -764,13 +771,6 @@ public abstract class PostgreSQLStatementVisitor extends 
PostgreSQLStatementPars
             InsertColumnsSegment insertColumnsSegment = new 
InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, 
insertColumns.stop.getStopIndex() + 1, columns.getValue());
             result.setInsertColumns(insertColumnsSegment);
         }
-        ValuesClauseContext valuesClause = 
ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
-        if (null == valuesClause) {
-            PostgreSQLSelectStatement selectStatement = 
(PostgreSQLSelectStatement) visit(ctx.select());
-            result.setInsertSelect(new 
SubquerySegment(ctx.select().start.getStartIndex(), 
ctx.select().stop.getStopIndex(), selectStatement));
-        } else {
-            
result.getValues().addAll(createInsertValuesSegments(valuesClause));
-        }
         return result;
     }
     
diff --git 
a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
 
b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
index c1dc45a7633..e1a71a5cf6d 100644
--- 
a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
+++ 
b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java
@@ -86,9 +86,13 @@ public abstract class BaseDMLE2EIT {
     }
     
     protected final void assertDataSet(final AssertionTestParameter testParam, 
final SingleE2EContainerComposer containerComposer, final int 
actualUpdateCount) throws SQLException {
-        assertThat("Only support single table for DML.", 
containerComposer.getDataSet().getMetaDataList().size(), is(1));
         assertThat(actualUpdateCount, 
is(containerComposer.getDataSet().getUpdateCount()));
-        DataSetMetaData expectedDataSetMetaData = 
containerComposer.getDataSet().getMetaDataList().get(0);
+        for (DataSetMetaData each : 
containerComposer.getDataSet().getMetaDataList()) {
+            assertDataSet(testParam, containerComposer, each);
+        }
+    }
+    
+    private void assertDataSet(final AssertionTestParameter testParam, final 
SingleE2EContainerComposer containerComposer, final DataSetMetaData 
expectedDataSetMetaData) throws SQLException {
         for (String each : 
InlineExpressionParserFactory.newInstance().splitAndEvaluate(expectedDataSetMetaData.getDataNodes()))
 {
             DataNode dataNode = new DataNode(each);
             DataSource dataSource = 
containerComposer.getActualDataSourceMap().get(dataNode.getDataSourceName());
diff --git 
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
 
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
index 2cdb42d37a1..1df0c82214a 100644
--- 
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
+++ 
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java
@@ -36,6 +36,7 @@ import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.ins
 import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.insert.MultiTableInsertIntoClauseAssert;
 import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.insert.OnDuplicateKeyColumnsAssert;
 import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.output.OutputClauseAssert;
+import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.parameter.ParameterMarkerAssert;
 import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.returning.ReturningClauseAssert;
 import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.set.SetClauseAssert;
 import 
org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.table.TableAssert;
@@ -119,6 +120,7 @@ public final class InsertStatementAssert {
             assertFalse(actual.getInsertSelect().isPresent(), 
assertContext.getText("Actual insert select segment should not exist."));
         } else {
             assertTrue(actual.getInsertSelect().isPresent(), 
assertContext.getText("Actual insert select segment should exist."));
+            ParameterMarkerAssert.assertCount(assertContext, 
actual.getInsertSelect().get().getSelect().getParameterCount(), 
expected.getSelectTestCase().getParameters().size());
             SelectStatementAssert.assertIs(assertContext, 
actual.getInsertSelect().get().getSelect(), expected.getSelectTestCase());
         }
     }
diff --git a/test/it/parser/src/main/resources/case/dml/insert.xml 
b/test/it/parser/src/main/resources/case/dml/insert.xml
index a7b8ca3c729..34c2b6b6107 100644
--- a/test/it/parser/src/main/resources/case/dml/insert.xml
+++ b/test/it/parser/src/main/resources/case/dml/insert.xml
@@ -1493,7 +1493,7 @@
             <column name="user_id" start-index="31" stop-index="37" />
             <column name="status" start-index="40" stop-index="45" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order" start-index="86" stop-index="92" 
/>
             </from>
@@ -1522,7 +1522,7 @@
     <insert sql-case-id="insert_select_without_columns" parameters="100">
         <table name="t_order" start-index="12" stop-index="18" />
         <columns start-index="19" stop-index="19" />
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order" start-index="58" stop-index="64" 
/>
             </from>
@@ -1557,7 +1557,7 @@
             <column name="status" start-index="53" stop-index="58" />
             <column name="creation_date" start-index="61" stop-index="73" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order_item" start-index="139" 
stop-index="150" />
             </from>
@@ -1593,7 +1593,7 @@
             <column name="status" start-index="44" stop-index="49" />
             <column name="creation_date" start-index="52" stop-index="64" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order_item" start-index="121" 
stop-index="132" />
             </from>
@@ -1627,7 +1627,7 @@
             <column name="user_id" start-index="30" stop-index="36" />
             <column name="status" start-index="39" stop-index="44" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order" start-index="85" stop-index="91" 
/>
             </from>
diff --git a/test/it/parser/src/main/resources/case/dml/replace.xml 
b/test/it/parser/src/main/resources/case/dml/replace.xml
index e3dccb901e7..415a09c21c3 100644
--- a/test/it/parser/src/main/resources/case/dml/replace.xml
+++ b/test/it/parser/src/main/resources/case/dml/replace.xml
@@ -867,7 +867,7 @@
             <column name="user_id" start-index="32" stop-index="38" />
             <column name="status" start-index="41" stop-index="46" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order" start-index="87" stop-index="93" 
/>
             </from>
@@ -896,7 +896,7 @@
     <insert sql-case-id="replace_select_without_columns" parameters="100">
         <table name="t_order" start-index="13" stop-index="19" />
         <columns start-index="20" stop-index="20" />
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order" start-index="59" stop-index="65" 
/>
             </from>
@@ -931,7 +931,7 @@
             <column name="status" start-index="54" stop-index="59" />
             <column name="creation_date" start-index="62" stop-index="74" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order_item" start-index="140" 
stop-index="151" />
             </from>
@@ -967,7 +967,7 @@
             <column name="status" start-index="45" stop-index="50" />
             <column name="creation_date" start-index="53" stop-index="65" />
         </columns>
-        <select>
+        <select parameters="100">
             <from>
                 <simple-table name="t_order_item" start-index="122" 
stop-index="133" />
             </from>

Reply via email to