This is an automated email from the ASF dual-hosted git repository.

chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 4e850741132 Refactor select combine statement parse result to SubquerySegment (#30693)
4e850741132 is described below

commit 4e8507411320b097b33067051d655afe827985e4
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Fri Mar 29 16:52:16 2024 +0800

    Refactor select combine statement parse result to SubquerySegment (#30693)
    
    * Refactor select combine statement parse result to SubquerySegment
    
    * setSubqueryType when CombineSegment bind
---
 .../segment/combine/CombineSegmentBinder.java      | 19 +++++++++++++-----
 .../statement/select/SelectStatementConverter.java |  2 +-
 .../visitor/statement/MySQLStatementVisitor.java   | 23 ++++++++++++----------
 .../statement/OpenGaussStatementVisitor.java       |  8 ++++++--
 .../statement/type/OracleDMLStatementVisitor.java  |  7 ++++++-
 .../statement/PostgreSQLStatementVisitor.java      |  8 ++++++--
 .../sql/common/extractor/TableExtractor.java       |  4 ++--
 .../common/segment/dml/combine/CombineSegment.java |  6 +++---
 .../parser/sql/common/util/ColumnExtractor.java    |  2 +-
 .../sql/common/util/SubqueryExtractUtils.java      |  4 ++--
 .../sql/common/extractor/TableExtractorTest.java   | 12 ++++++++---
 .../sql/common/util/SubqueryExtractUtilsTest.java  | 12 +++++------
 .../statement/dml/impl/SelectStatementAssert.java  |  4 ++--
 13 files changed, 71 insertions(+), 40 deletions(-)

diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
index b305daf2e3e..e738508e1bd 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
@@ -19,12 +19,16 @@ package 
org.apache.shardingsphere.infra.binder.segment.combine;
 
 import lombok.AccessLevel;
 import lombok.NoArgsConstructor;
+import 
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
 import 
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementBinder;
 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 
+import java.util.Map;
+
 /**
  * Combine segment binder.
  */
@@ -41,10 +45,15 @@ public final class CombineSegmentBinder {
     public static CombineSegment bind(final CombineSegment segment, final 
SQLStatementBinderContext statementBinderContext) {
         ShardingSphereMetaData metaData = statementBinderContext.getMetaData();
         String defaultDatabaseName = 
statementBinderContext.getDefaultDatabaseName();
-        SelectStatement boundedLeftSelect = new 
SelectStatementBinder().bindWithExternalTableContexts(segment.getLeft(), 
metaData, defaultDatabaseName,
-                statementBinderContext.getExternalTableBinderContexts());
-        SelectStatement boundedRightSelect = new 
SelectStatementBinder().bindWithExternalTableContexts(segment.getRight(), 
metaData, defaultDatabaseName,
-                statementBinderContext.getExternalTableBinderContexts());
-        return new CombineSegment(segment.getStartIndex(), 
segment.getStopIndex(), boundedLeftSelect, segment.getCombineType(), 
boundedRightSelect);
+        Map<String, TableSegmentBinderContext> externalTableBinderContexts = 
statementBinderContext.getExternalTableBinderContexts();
+        SelectStatement boundedLeftSelect = new 
SelectStatementBinder().bindWithExternalTableContexts(segment.getLeft().getSelect(),
 metaData, defaultDatabaseName, externalTableBinderContexts);
+        SelectStatement boundedRightSelect = new 
SelectStatementBinder().bindWithExternalTableContexts(segment.getRight().getSelect(),
 metaData, defaultDatabaseName, externalTableBinderContexts);
+        SubquerySegment boundedLeft = new 
SubquerySegment(segment.getLeft().getStartIndex(), 
segment.getLeft().getStopIndex(), segment.getLeft().getText());
+        boundedLeft.setSelect(boundedLeftSelect);
+        boundedLeft.setSubqueryType(segment.getLeft().getSubqueryType());
+        SubquerySegment boundedRight = new 
SubquerySegment(segment.getRight().getStartIndex(), 
segment.getRight().getStopIndex(), segment.getRight().getText());
+        boundedRight.setSelect(boundedRightSelect);
+        boundedRight.setSubqueryType(segment.getRight().getSubqueryType());
+        return new CombineSegment(segment.getStartIndex(), 
segment.getStopIndex(), boundedLeft, segment.getCombineType(), boundedRight);
     }
 }
diff --git 
a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
 
b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
index 9fd87e15c02..8577dfab1b2 100644
--- 
a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
+++ 
b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
@@ -82,7 +82,7 @@ public final class SelectStatementConverter implements 
SQLStatementConverter<Sel
         if (selectStatement.getCombine().isPresent()) {
             CombineSegment combineSegment = selectStatement.getCombine().get();
             return new 
SqlBasicCall(CombineOperatorConverter.convert(combineSegment.getCombineType()),
-                    Arrays.asList(convert(combineSegment.getLeft()), 
convert(combineSegment.getRight())), SqlParserPos.ZERO);
+                    
Arrays.asList(convert(combineSegment.getLeft().getSelect()), 
convert(combineSegment.getRight().getSelect())), SqlParserPos.ZERO);
         }
         return sqlNode;
     }
diff --git 
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
 
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
index 632f4e58a2d..218102adcc1 100644
--- 
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
+++ 
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
@@ -766,26 +766,28 @@ public abstract class MySQLStatementVisitor extends 
MySQLStatementBaseVisitor<AS
         }
         if (null != ctx.queryExpressionBody()) {
             MySQLSelectStatement result = new MySQLSelectStatement();
-            MySQLSelectStatement left = (MySQLSelectStatement) 
visit(ctx.queryExpressionBody());
-            result.setProjections(left.getProjections());
-            left.getFrom().ifPresent(result::setFrom);
-            left.getTable().ifPresent(result::setTable);
+            SubquerySegment left = new 
SubquerySegment(ctx.queryExpressionBody().start.getStartIndex(), 
ctx.queryExpressionBody().stop.getStopIndex(),
+                    (MySQLSelectStatement) visit(ctx.queryExpressionBody()), 
getOriginalText(ctx.queryExpressionBody()));
+            result.setProjections(left.getSelect().getProjections());
+            left.getSelect().getFrom().ifPresent(result::setFrom);
+            ((MySQLSelectStatement) 
left.getSelect()).getTable().ifPresent(result::setTable);
             result.setCombine(createCombineSegment(ctx.combineClause(), left));
             return result;
         }
         if (null != ctx.queryExpressionParens()) {
             MySQLSelectStatement result = new MySQLSelectStatement();
-            MySQLSelectStatement left = (MySQLSelectStatement) 
visit(ctx.queryExpressionParens());
-            result.setProjections(left.getProjections());
-            left.getFrom().ifPresent(result::setFrom);
-            left.getTable().ifPresent(result::setTable);
+            SubquerySegment left = new 
SubquerySegment(ctx.queryExpressionParens().start.getStartIndex(), 
ctx.queryExpressionParens().stop.getStopIndex(),
+                    (MySQLSelectStatement) visit(ctx.queryExpressionParens()), 
getOriginalText(ctx.queryExpressionParens()));
+            result.setProjections(left.getSelect().getProjections());
+            left.getSelect().getFrom().ifPresent(result::setFrom);
+            ((MySQLSelectStatement) 
left.getSelect()).getTable().ifPresent(result::setTable);
             result.setCombine(createCombineSegment(ctx.combineClause(), left));
             return result;
         }
         return visit(ctx.queryExpressionParens());
     }
     
-    private CombineSegment createCombineSegment(final CombineClauseContext 
ctx, final MySQLSelectStatement left) {
+    private CombineSegment createCombineSegment(final CombineClauseContext 
ctx, final SubquerySegment left) {
         CombineType combineType;
         if (null != ctx.EXCEPT()) {
             combineType = CombineType.EXCEPT;
@@ -794,7 +796,8 @@ public abstract class MySQLStatementVisitor extends 
MySQLStatementBaseVisitor<AS
         } else {
             combineType = null == ctx.combineOption() || null == 
ctx.combineOption().ALL() ? CombineType.UNION : CombineType.UNION_ALL;
         }
-        MySQLSelectStatement right = null == ctx.queryPrimary() ? 
(MySQLSelectStatement) visit(ctx.queryExpressionParens()) : 
(MySQLSelectStatement) visit(ctx.queryPrimary());
+        ParserRuleContext ruleContext = null == ctx.queryPrimary() ? 
ctx.queryExpressionParens() : ctx.queryPrimary();
+        SubquerySegment right = new 
SubquerySegment(ruleContext.start.getStartIndex(), 
ruleContext.stop.getStopIndex(), (MySQLSelectStatement) visit(ruleContext), 
getOriginalText(ruleContext));
         return new CombineSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), left, combineType, right);
     }
     
diff --git 
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
 
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
index 82118214a14..9aabbd67c21 100644
--- 
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
+++ 
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
@@ -969,14 +969,18 @@ public abstract class OpenGaussStatementVisitor extends 
OpenGaussStatementBaseVi
             OpenGaussSelectStatement left = (OpenGaussSelectStatement) 
visit(ctx.selectClauseN(0));
             result.setProjections(left.getProjections());
             left.getFrom().ifPresent(result::setFrom);
-            CombineSegment combineSegment = new CombineSegment(((TerminalNode) 
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(), 
left, getCombineType(ctx),
-                    (OpenGaussSelectStatement) visit(ctx.selectClauseN(1)));
+            CombineSegment combineSegment = new CombineSegment(((TerminalNode) 
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
+                    createSubquerySegment(ctx.selectClauseN(0), left), 
getCombineType(ctx), createSubquerySegment(ctx.selectClauseN(1), 
(OpenGaussSelectStatement) visit(ctx.selectClauseN(1))));
             result.setCombine(combineSegment);
             return result;
         }
         return visit(ctx.selectWithParens());
     }
     
+    private SubquerySegment createSubquerySegment(final SelectClauseNContext 
ctx, final OpenGaussSelectStatement selectStatement) {
+        return new SubquerySegment(ctx.start.getStartIndex(), 
ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
+    }
+    
     private CombineType getCombineType(final SelectClauseNContext ctx) {
         boolean isDistinct = null == ctx.allOrDistinct() || null != 
ctx.allOrDistinct().DISTINCT();
         if (null != ctx.UNION()) {
diff --git 
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
 
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
index b86a9bbf0f6..ba12a7ed06d 100644
--- 
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
+++ 
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
@@ -615,7 +615,12 @@ public final class OracleDMLStatementVisitor extends 
OracleStatementVisitor impl
         } else {
             combineType = CombineType.MINUS;
         }
-        result.setCombine(new CombineSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), left, combineType, (OracleSelectStatement) 
visit(ctx.selectSubquery(1))));
+        result.setCombine(new CombineSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), createSubquerySegment(ctx.selectSubquery(0), 
left), combineType,
+                createSubquerySegment(ctx.selectSubquery(1), 
(OracleSelectStatement) visit(ctx.selectSubquery(1)))));
+    }
+    
+    private SubquerySegment createSubquerySegment(final SelectSubqueryContext 
ctx, final OracleSelectStatement selectStatement) {
+        return new SubquerySegment(ctx.start.getStartIndex(), 
ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
     }
     
     @Override
diff --git 
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
 
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
index a12503d0da5..9871be259bd 100644
--- 
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
+++ 
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
@@ -939,14 +939,18 @@ public abstract class PostgreSQLStatementVisitor extends 
PostgreSQLStatementPars
             PostgreSQLSelectStatement left = (PostgreSQLSelectStatement) 
visit(ctx.selectClauseN(0));
             result.setProjections(left.getProjections());
             left.getFrom().ifPresent(result::setFrom);
-            CombineSegment combineSegment = new CombineSegment(((TerminalNode) 
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(), 
left, getCombineType(ctx),
-                    (PostgreSQLSelectStatement) visit(ctx.selectClauseN(1)));
+            CombineSegment combineSegment = new CombineSegment(((TerminalNode) 
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
+                    createSubquerySegment(ctx.selectClauseN(0), left), 
getCombineType(ctx), createSubquerySegment(ctx.selectClauseN(1), 
(PostgreSQLSelectStatement) visit(ctx.selectClauseN(1))));
             result.setCombine(combineSegment);
             return result;
         }
         return visit(ctx.selectWithParens());
     }
     
+    private SubquerySegment createSubquerySegment(final SelectClauseNContext 
ctx, final PostgreSQLSelectStatement selectStatement) {
+        return new SubquerySegment(ctx.start.getStartIndex(), 
ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
+    }
+    
     private CombineType getCombineType(final SelectClauseNContext ctx) {
         boolean isDistinct = null == ctx.allOrDistinct() || null != 
ctx.allOrDistinct().DISTINCT();
         if (null != ctx.UNION()) {
diff --git 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
index 89e88b8f458..49871eb2690 100644
--- 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
+++ 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
@@ -80,8 +80,8 @@ public final class TableExtractor {
     public void extractTablesFromSelect(final SelectStatement selectStatement) 
{
         if (selectStatement.getCombine().isPresent()) {
             CombineSegment combineSegment = selectStatement.getCombine().get();
-            extractTablesFromSelect(combineSegment.getLeft());
-            extractTablesFromSelect(combineSegment.getRight());
+            extractTablesFromSelect(combineSegment.getLeft().getSelect());
+            extractTablesFromSelect(combineSegment.getRight().getSelect());
         }
         if (selectStatement.getFrom().isPresent() && 
!selectStatement.getCombine().isPresent()) {
             extractTablesFromTableSegment(selectStatement.getFrom().get());
diff --git 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
index 79371d4f8e5..9d51eaa0396 100644
--- 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
+++ 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
@@ -21,7 +21,7 @@ import lombok.Getter;
 import lombok.RequiredArgsConstructor;
 import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
-import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
 
 /**
  * Combine segment.
@@ -34,9 +34,9 @@ public final class CombineSegment implements SQLSegment {
     
     private final int stopIndex;
     
-    private final SelectStatement left;
+    private final SubquerySegment left;
     
     private final CombineType combineType;
     
-    private final SelectStatement right;
+    private final SubquerySegment right;
 }
diff --git 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
index 19af0d652e2..db9175d5d06 100644
--- 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
+++ 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
@@ -131,7 +131,7 @@ public final class ColumnExtractor {
         statement.getGroupBy().ifPresent(optional -> 
extractFromGroupBy(columnSegments, optional, containsSubQuery));
         statement.getHaving().ifPresent(optional -> 
extractFromHaving(columnSegments, optional, containsSubQuery));
         statement.getOrderBy().ifPresent(optional -> 
extractFromOrderBy(columnSegments, optional, containsSubQuery));
-        statement.getCombine().ifPresent(optional -> 
extractFromSelectStatement(columnSegments, optional.getRight(), 
containsSubQuery));
+        statement.getCombine().ifPresent(optional -> 
extractFromSelectStatement(columnSegments, optional.getRight().getSelect(), 
containsSubQuery));
     }
     
     /**
diff --git 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
index c3be2352d33..0df855ff77d 100644
--- 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
+++ 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
@@ -161,7 +161,7 @@ public final class SubqueryExtractUtils {
     }
     
     private static void extractSubquerySegmentsFromCombine(final 
List<SubquerySegment> result, final CombineSegment combineSegment) {
-        extractSubquerySegments(result, combineSegment.getLeft());
-        extractSubquerySegments(result, combineSegment.getRight());
+        extractSubquerySegments(result, combineSegment.getLeft().getSelect());
+        extractSubquerySegments(result, combineSegment.getRight().getSelect());
     }
 }
diff --git 
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
 
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
index 7ddae54892a..fe607fd1e4c 100644
--- 
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
+++ 
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
@@ -151,7 +151,9 @@ class TableExtractorTest {
     @Test
     void assertExtractTablesFromCombineSegment() {
         MySQLSelectStatement selectStatement = 
createSelectStatement("t_order");
-        selectStatement.setCombine(new CombineSegment(0, 0, 
createSelectStatement("t_order"), CombineType.UNION, 
createSelectStatement("t_order_item")));
+        SubquerySegment left = new SubquerySegment(0, 0, 
createSelectStatement("t_order"), "");
+        SubquerySegment right = new SubquerySegment(0, 0, 
createSelectStatement("t_order_item"), "");
+        selectStatement.setCombine(new CombineSegment(0, 0, left, 
CombineType.UNION, right));
         tableExtractor.extractTablesFromSelect(selectStatement);
         Collection<SimpleTableSegment> actual = 
tableExtractor.getRewriteTables();
         assertThat(actual.size(), is(2));
@@ -172,7 +174,9 @@ class TableExtractorTest {
     @Test
     void assertExtractTablesFromCombineSegmentWithColumnProjection() {
         MySQLSelectStatement selectStatement = 
createSelectStatementWithColumnProjection("t_order");
-        selectStatement.setCombine(new CombineSegment(0, 0, 
createSelectStatementWithColumnProjection("t_order"), CombineType.UNION, 
createSelectStatementWithColumnProjection("t_order_item")));
+        SubquerySegment left = new SubquerySegment(0, 0, 
createSelectStatementWithColumnProjection("t_order"), "");
+        SubquerySegment right = new SubquerySegment(0, 0, 
createSelectStatementWithColumnProjection("t_order_item"), "");
+        selectStatement.setCombine(new CombineSegment(0, 0, left, 
CombineType.UNION, right));
         tableExtractor.extractTablesFromSelect(selectStatement);
         Collection<SimpleTableSegment> actual = 
tableExtractor.getRewriteTables();
         assertThat(actual.size(), is(2));
@@ -197,7 +201,9 @@ class TableExtractorTest {
     @Test
     void assertExtractTablesFromCombineWithSubQueryProjection() {
         MySQLSelectStatement selectStatement = 
createSelectStatementWithSubQueryProjection("t_order");
-        selectStatement.setCombine(new CombineSegment(0, 0, 
createSelectStatementWithSubQueryProjection("t_order"), CombineType.UNION, 
createSelectStatementWithSubQueryProjection("t_order_item")));
+        SubquerySegment left = new SubquerySegment(0, 0, 
createSelectStatementWithSubQueryProjection("t_order"), "");
+        SubquerySegment right = new SubquerySegment(0, 0, 
createSelectStatementWithSubQueryProjection("t_order_item"), "");
+        selectStatement.setCombine(new CombineSegment(0, 0, left, 
CombineType.UNION, right));
         tableExtractor.extractTablesFromSelect(selectStatement);
         Collection<SimpleTableSegment> actual = 
tableExtractor.getRewriteTables();
         assertThat(actual.size(), is(2));
diff --git 
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
 
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
index 425f52e370a..7b9fadcd595 100644
--- 
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
+++ 
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
@@ -171,17 +171,17 @@ class SubqueryExtractUtilsTest {
     @Test
     void assertGetSubquerySegmentsWithCombineSegment() {
         SelectStatement selectStatement = new MySQLSelectStatement();
-        selectStatement.setCombine(new CombineSegment(0, 0, new 
MySQLSelectStatement(), CombineType.UNION, 
createSelectStatementForCombineSegment()));
+        SubquerySegment left = new SubquerySegment(0, 0, new 
MySQLSelectStatement(), "");
+        selectStatement.setCombine(new CombineSegment(0, 0, left, 
CombineType.UNION, createSelectStatementForCombineSegment()));
         Collection<SubquerySegment> actual = 
SubqueryExtractUtils.getSubquerySegments(selectStatement);
         assertThat(actual.size(), is(1));
     }
     
-    private SelectStatement createSelectStatementForCombineSegment() {
-        SelectStatement result = new MySQLSelectStatement();
+    private SubquerySegment createSelectStatementForCombineSegment() {
+        SelectStatement selectStatement = new MySQLSelectStatement();
         ExpressionSegment left = new ColumnSegment(0, 0, new 
IdentifierValue("order_id"));
-        result.setWhere(new WhereSegment(0, 0, new InExpression(0, 0,
-                left, new SubqueryExpressionSegment(new SubquerySegment(0, 0, 
new MySQLSelectStatement(), "")), false)));
-        return result;
+        selectStatement.setWhere(new WhereSegment(0, 0, new InExpression(0, 0, 
left, new SubqueryExpressionSegment(new SubquerySegment(0, 0, new 
MySQLSelectStatement(), "")), false)));
+        return new SubquerySegment(0, 0, selectStatement, "");
     }
     
     @Test
diff --git 
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
 
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
index 9d15e2c2843..e99bfe26ec7 100644
--- 
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
+++ 
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
@@ -192,8 +192,8 @@ public final class SelectStatementAssert {
             assertTrue(combineSegment.isPresent(), 
assertContext.getText("Actual combine segment should exist."));
             assertThat(assertContext.getText("Combine type assertion error: 
"), combineSegment.get().getCombineType().name(), 
is(expected.getCombineClause().getCombineType()));
             SQLSegmentAssert.assertIs(assertContext, combineSegment.get(), 
expected.getCombineClause());
-            assertIs(assertContext, combineSegment.get().getLeft(), 
expected.getCombineClause().getLeft());
-            assertIs(assertContext, combineSegment.get().getRight(), 
expected.getCombineClause().getRight());
+            assertIs(assertContext, 
combineSegment.get().getLeft().getSelect(), 
expected.getCombineClause().getLeft());
+            assertIs(assertContext, 
combineSegment.get().getRight().getSelect(), 
expected.getCombineClause().getRight());
         }
     }
     

Reply via email to