This is an automated email from the ASF dual-hosted git repository.
chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 4e850741132 Refactor select combine statement parse result to SubquerySegment (#30693)
4e850741132 is described below
commit 4e8507411320b097b33067051d655afe827985e4
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Fri Mar 29 16:52:16 2024 +0800
Refactor select combine statement parse result to SubquerySegment (#30693)
* Refactor select combine statement parse result to SubquerySegment
* Call setSubqueryType when binding a CombineSegment
---
.../segment/combine/CombineSegmentBinder.java | 19 +++++++++++++-----
.../statement/select/SelectStatementConverter.java | 2 +-
.../visitor/statement/MySQLStatementVisitor.java | 23 ++++++++++++----------
.../statement/OpenGaussStatementVisitor.java | 8 ++++++--
.../statement/type/OracleDMLStatementVisitor.java | 7 ++++++-
.../statement/PostgreSQLStatementVisitor.java | 8 ++++++--
.../sql/common/extractor/TableExtractor.java | 4 ++--
.../common/segment/dml/combine/CombineSegment.java | 6 +++---
.../parser/sql/common/util/ColumnExtractor.java | 2 +-
.../sql/common/util/SubqueryExtractUtils.java | 4 ++--
.../sql/common/extractor/TableExtractorTest.java | 12 ++++++++---
.../sql/common/util/SubqueryExtractUtilsTest.java | 12 +++++------
.../statement/dml/impl/SelectStatementAssert.java | 4 ++--
13 files changed, 71 insertions(+), 40 deletions(-)
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
index b305daf2e3e..e738508e1bd 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/combine/CombineSegmentBinder.java
@@ -19,12 +19,16 @@ package
org.apache.shardingsphere.infra.binder.segment.combine;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
+import
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
import
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementBinder;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
+import java.util.Map;
+
/**
* Combine segment binder.
*/
@@ -41,10 +45,15 @@ public final class CombineSegmentBinder {
public static CombineSegment bind(final CombineSegment segment, final
SQLStatementBinderContext statementBinderContext) {
ShardingSphereMetaData metaData = statementBinderContext.getMetaData();
String defaultDatabaseName =
statementBinderContext.getDefaultDatabaseName();
- SelectStatement boundedLeftSelect = new
SelectStatementBinder().bindWithExternalTableContexts(segment.getLeft(),
metaData, defaultDatabaseName,
- statementBinderContext.getExternalTableBinderContexts());
- SelectStatement boundedRightSelect = new
SelectStatementBinder().bindWithExternalTableContexts(segment.getRight(),
metaData, defaultDatabaseName,
- statementBinderContext.getExternalTableBinderContexts());
- return new CombineSegment(segment.getStartIndex(),
segment.getStopIndex(), boundedLeftSelect, segment.getCombineType(),
boundedRightSelect);
+ Map<String, TableSegmentBinderContext> externalTableBinderContexts =
statementBinderContext.getExternalTableBinderContexts();
+ SelectStatement boundedLeftSelect = new
SelectStatementBinder().bindWithExternalTableContexts(segment.getLeft().getSelect(),
metaData, defaultDatabaseName, externalTableBinderContexts);
+ SelectStatement boundedRightSelect = new
SelectStatementBinder().bindWithExternalTableContexts(segment.getRight().getSelect(),
metaData, defaultDatabaseName, externalTableBinderContexts);
+ SubquerySegment boundedLeft = new
SubquerySegment(segment.getLeft().getStartIndex(),
segment.getLeft().getStopIndex(), segment.getLeft().getText());
+ boundedLeft.setSelect(boundedLeftSelect);
+ boundedLeft.setSubqueryType(segment.getLeft().getSubqueryType());
+ SubquerySegment boundedRight = new
SubquerySegment(segment.getRight().getStartIndex(),
segment.getRight().getStopIndex(), segment.getRight().getText());
+ boundedRight.setSelect(boundedRightSelect);
+ boundedRight.setSubqueryType(segment.getRight().getSubqueryType());
+ return new CombineSegment(segment.getStartIndex(),
segment.getStopIndex(), boundedLeft, segment.getCombineType(), boundedRight);
}
}
diff --git
a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
index 9fd87e15c02..8577dfab1b2 100644
---
a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
+++
b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java
@@ -82,7 +82,7 @@ public final class SelectStatementConverter implements
SQLStatementConverter<Sel
if (selectStatement.getCombine().isPresent()) {
CombineSegment combineSegment = selectStatement.getCombine().get();
return new
SqlBasicCall(CombineOperatorConverter.convert(combineSegment.getCombineType()),
- Arrays.asList(convert(combineSegment.getLeft()),
convert(combineSegment.getRight())), SqlParserPos.ZERO);
+
Arrays.asList(convert(combineSegment.getLeft().getSelect()),
convert(combineSegment.getRight().getSelect())), SqlParserPos.ZERO);
}
return sqlNode;
}
diff --git
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
index 632f4e58a2d..218102adcc1 100644
---
a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
+++
b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java
@@ -766,26 +766,28 @@ public abstract class MySQLStatementVisitor extends
MySQLStatementBaseVisitor<AS
}
if (null != ctx.queryExpressionBody()) {
MySQLSelectStatement result = new MySQLSelectStatement();
- MySQLSelectStatement left = (MySQLSelectStatement)
visit(ctx.queryExpressionBody());
- result.setProjections(left.getProjections());
- left.getFrom().ifPresent(result::setFrom);
- left.getTable().ifPresent(result::setTable);
+ SubquerySegment left = new
SubquerySegment(ctx.queryExpressionBody().start.getStartIndex(),
ctx.queryExpressionBody().stop.getStopIndex(),
+ (MySQLSelectStatement) visit(ctx.queryExpressionBody()),
getOriginalText(ctx.queryExpressionBody()));
+ result.setProjections(left.getSelect().getProjections());
+ left.getSelect().getFrom().ifPresent(result::setFrom);
+ ((MySQLSelectStatement)
left.getSelect()).getTable().ifPresent(result::setTable);
result.setCombine(createCombineSegment(ctx.combineClause(), left));
return result;
}
if (null != ctx.queryExpressionParens()) {
MySQLSelectStatement result = new MySQLSelectStatement();
- MySQLSelectStatement left = (MySQLSelectStatement)
visit(ctx.queryExpressionParens());
- result.setProjections(left.getProjections());
- left.getFrom().ifPresent(result::setFrom);
- left.getTable().ifPresent(result::setTable);
+ SubquerySegment left = new
SubquerySegment(ctx.queryExpressionParens().start.getStartIndex(),
ctx.queryExpressionParens().stop.getStopIndex(),
+ (MySQLSelectStatement) visit(ctx.queryExpressionParens()),
getOriginalText(ctx.queryExpressionParens()));
+ result.setProjections(left.getSelect().getProjections());
+ left.getSelect().getFrom().ifPresent(result::setFrom);
+ ((MySQLSelectStatement)
left.getSelect()).getTable().ifPresent(result::setTable);
result.setCombine(createCombineSegment(ctx.combineClause(), left));
return result;
}
return visit(ctx.queryExpressionParens());
}
- private CombineSegment createCombineSegment(final CombineClauseContext
ctx, final MySQLSelectStatement left) {
+ private CombineSegment createCombineSegment(final CombineClauseContext
ctx, final SubquerySegment left) {
CombineType combineType;
if (null != ctx.EXCEPT()) {
combineType = CombineType.EXCEPT;
@@ -794,7 +796,8 @@ public abstract class MySQLStatementVisitor extends
MySQLStatementBaseVisitor<AS
} else {
combineType = null == ctx.combineOption() || null ==
ctx.combineOption().ALL() ? CombineType.UNION : CombineType.UNION_ALL;
}
- MySQLSelectStatement right = null == ctx.queryPrimary() ?
(MySQLSelectStatement) visit(ctx.queryExpressionParens()) :
(MySQLSelectStatement) visit(ctx.queryPrimary());
+ ParserRuleContext ruleContext = null == ctx.queryPrimary() ?
ctx.queryExpressionParens() : ctx.queryPrimary();
+ SubquerySegment right = new
SubquerySegment(ruleContext.start.getStartIndex(),
ruleContext.stop.getStopIndex(), (MySQLSelectStatement) visit(ruleContext),
getOriginalText(ruleContext));
return new CombineSegment(ctx.getStart().getStartIndex(),
ctx.getStop().getStopIndex(), left, combineType, right);
}
diff --git
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
index 82118214a14..9aabbd67c21 100644
---
a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
+++
b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java
@@ -969,14 +969,18 @@ public abstract class OpenGaussStatementVisitor extends
OpenGaussStatementBaseVi
OpenGaussSelectStatement left = (OpenGaussSelectStatement)
visit(ctx.selectClauseN(0));
result.setProjections(left.getProjections());
left.getFrom().ifPresent(result::setFrom);
- CombineSegment combineSegment = new CombineSegment(((TerminalNode)
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
left, getCombineType(ctx),
- (OpenGaussSelectStatement) visit(ctx.selectClauseN(1)));
+ CombineSegment combineSegment = new CombineSegment(((TerminalNode)
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
+ createSubquerySegment(ctx.selectClauseN(0), left),
getCombineType(ctx), createSubquerySegment(ctx.selectClauseN(1),
(OpenGaussSelectStatement) visit(ctx.selectClauseN(1))));
result.setCombine(combineSegment);
return result;
}
return visit(ctx.selectWithParens());
}
+ private SubquerySegment createSubquerySegment(final SelectClauseNContext
ctx, final OpenGaussSelectStatement selectStatement) {
+ return new SubquerySegment(ctx.start.getStartIndex(),
ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
+ }
+
private CombineType getCombineType(final SelectClauseNContext ctx) {
boolean isDistinct = null == ctx.allOrDistinct() || null !=
ctx.allOrDistinct().DISTINCT();
if (null != ctx.UNION()) {
diff --git
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
index b86a9bbf0f6..ba12a7ed06d 100644
---
a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
+++
b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
@@ -615,7 +615,12 @@ public final class OracleDMLStatementVisitor extends
OracleStatementVisitor impl
} else {
combineType = CombineType.MINUS;
}
- result.setCombine(new CombineSegment(ctx.getStart().getStartIndex(),
ctx.getStop().getStopIndex(), left, combineType, (OracleSelectStatement)
visit(ctx.selectSubquery(1))));
+ result.setCombine(new CombineSegment(ctx.getStart().getStartIndex(),
ctx.getStop().getStopIndex(), createSubquerySegment(ctx.selectSubquery(0),
left), combineType,
+ createSubquerySegment(ctx.selectSubquery(1),
(OracleSelectStatement) visit(ctx.selectSubquery(1)))));
+ }
+
+ private SubquerySegment createSubquerySegment(final SelectSubqueryContext
ctx, final OracleSelectStatement selectStatement) {
+ return new SubquerySegment(ctx.start.getStartIndex(),
ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
}
@Override
diff --git
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
index a12503d0da5..9871be259bd 100644
---
a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
+++
b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java
@@ -939,14 +939,18 @@ public abstract class PostgreSQLStatementVisitor extends
PostgreSQLStatementPars
PostgreSQLSelectStatement left = (PostgreSQLSelectStatement)
visit(ctx.selectClauseN(0));
result.setProjections(left.getProjections());
left.getFrom().ifPresent(result::setFrom);
- CombineSegment combineSegment = new CombineSegment(((TerminalNode)
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
left, getCombineType(ctx),
- (PostgreSQLSelectStatement) visit(ctx.selectClauseN(1)));
+ CombineSegment combineSegment = new CombineSegment(((TerminalNode)
ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
+ createSubquerySegment(ctx.selectClauseN(0), left),
getCombineType(ctx), createSubquerySegment(ctx.selectClauseN(1),
(PostgreSQLSelectStatement) visit(ctx.selectClauseN(1))));
result.setCombine(combineSegment);
return result;
}
return visit(ctx.selectWithParens());
}
+ private SubquerySegment createSubquerySegment(final SelectClauseNContext
ctx, final PostgreSQLSelectStatement selectStatement) {
+ return new SubquerySegment(ctx.start.getStartIndex(),
ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
+ }
+
private CombineType getCombineType(final SelectClauseNContext ctx) {
boolean isDistinct = null == ctx.allOrDistinct() || null !=
ctx.allOrDistinct().DISTINCT();
if (null != ctx.UNION()) {
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
index 89e88b8f458..49871eb2690 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
@@ -80,8 +80,8 @@ public final class TableExtractor {
public void extractTablesFromSelect(final SelectStatement selectStatement)
{
if (selectStatement.getCombine().isPresent()) {
CombineSegment combineSegment = selectStatement.getCombine().get();
- extractTablesFromSelect(combineSegment.getLeft());
- extractTablesFromSelect(combineSegment.getRight());
+ extractTablesFromSelect(combineSegment.getLeft().getSelect());
+ extractTablesFromSelect(combineSegment.getRight().getSelect());
}
if (selectStatement.getFrom().isPresent() &&
!selectStatement.getCombine().isPresent()) {
extractTablesFromTableSegment(selectStatement.getFrom().get());
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
index 79371d4f8e5..9d51eaa0396 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/combine/CombineSegment.java
@@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
-import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
/**
* Combine segment.
@@ -34,9 +34,9 @@ public final class CombineSegment implements SQLSegment {
private final int stopIndex;
- private final SelectStatement left;
+ private final SubquerySegment left;
private final CombineType combineType;
- private final SelectStatement right;
+ private final SubquerySegment right;
}
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
index 19af0d652e2..db9175d5d06 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java
@@ -131,7 +131,7 @@ public final class ColumnExtractor {
statement.getGroupBy().ifPresent(optional ->
extractFromGroupBy(columnSegments, optional, containsSubQuery));
statement.getHaving().ifPresent(optional ->
extractFromHaving(columnSegments, optional, containsSubQuery));
statement.getOrderBy().ifPresent(optional ->
extractFromOrderBy(columnSegments, optional, containsSubQuery));
- statement.getCombine().ifPresent(optional ->
extractFromSelectStatement(columnSegments, optional.getRight(),
containsSubQuery));
+ statement.getCombine().ifPresent(optional ->
extractFromSelectStatement(columnSegments, optional.getRight().getSelect(),
containsSubQuery));
}
/**
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
index c3be2352d33..0df855ff77d 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java
@@ -161,7 +161,7 @@ public final class SubqueryExtractUtils {
}
private static void extractSubquerySegmentsFromCombine(final
List<SubquerySegment> result, final CombineSegment combineSegment) {
- extractSubquerySegments(result, combineSegment.getLeft());
- extractSubquerySegments(result, combineSegment.getRight());
+ extractSubquerySegments(result, combineSegment.getLeft().getSelect());
+ extractSubquerySegments(result, combineSegment.getRight().getSelect());
}
}
diff --git
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
index 7ddae54892a..fe607fd1e4c 100644
---
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
+++
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
@@ -151,7 +151,9 @@ class TableExtractorTest {
@Test
void assertExtractTablesFromCombineSegment() {
MySQLSelectStatement selectStatement =
createSelectStatement("t_order");
- selectStatement.setCombine(new CombineSegment(0, 0,
createSelectStatement("t_order"), CombineType.UNION,
createSelectStatement("t_order_item")));
+ SubquerySegment left = new SubquerySegment(0, 0,
createSelectStatement("t_order"), "");
+ SubquerySegment right = new SubquerySegment(0, 0,
createSelectStatement("t_order_item"), "");
+ selectStatement.setCombine(new CombineSegment(0, 0, left,
CombineType.UNION, right));
tableExtractor.extractTablesFromSelect(selectStatement);
Collection<SimpleTableSegment> actual =
tableExtractor.getRewriteTables();
assertThat(actual.size(), is(2));
@@ -172,7 +174,9 @@ class TableExtractorTest {
@Test
void assertExtractTablesFromCombineSegmentWithColumnProjection() {
MySQLSelectStatement selectStatement =
createSelectStatementWithColumnProjection("t_order");
- selectStatement.setCombine(new CombineSegment(0, 0,
createSelectStatementWithColumnProjection("t_order"), CombineType.UNION,
createSelectStatementWithColumnProjection("t_order_item")));
+ SubquerySegment left = new SubquerySegment(0, 0,
createSelectStatementWithColumnProjection("t_order"), "");
+ SubquerySegment right = new SubquerySegment(0, 0,
createSelectStatementWithColumnProjection("t_order_item"), "");
+ selectStatement.setCombine(new CombineSegment(0, 0, left,
CombineType.UNION, right));
tableExtractor.extractTablesFromSelect(selectStatement);
Collection<SimpleTableSegment> actual =
tableExtractor.getRewriteTables();
assertThat(actual.size(), is(2));
@@ -197,7 +201,9 @@ class TableExtractorTest {
@Test
void assertExtractTablesFromCombineWithSubQueryProjection() {
MySQLSelectStatement selectStatement =
createSelectStatementWithSubQueryProjection("t_order");
- selectStatement.setCombine(new CombineSegment(0, 0,
createSelectStatementWithSubQueryProjection("t_order"), CombineType.UNION,
createSelectStatementWithSubQueryProjection("t_order_item")));
+ SubquerySegment left = new SubquerySegment(0, 0,
createSelectStatementWithSubQueryProjection("t_order"), "");
+ SubquerySegment right = new SubquerySegment(0, 0,
createSelectStatementWithSubQueryProjection("t_order_item"), "");
+ selectStatement.setCombine(new CombineSegment(0, 0, left,
CombineType.UNION, right));
tableExtractor.extractTablesFromSelect(selectStatement);
Collection<SimpleTableSegment> actual =
tableExtractor.getRewriteTables();
assertThat(actual.size(), is(2));
diff --git
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
index 425f52e370a..7b9fadcd595 100644
---
a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
+++
b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java
@@ -171,17 +171,17 @@ class SubqueryExtractUtilsTest {
@Test
void assertGetSubquerySegmentsWithCombineSegment() {
SelectStatement selectStatement = new MySQLSelectStatement();
- selectStatement.setCombine(new CombineSegment(0, 0, new
MySQLSelectStatement(), CombineType.UNION,
createSelectStatementForCombineSegment()));
+ SubquerySegment left = new SubquerySegment(0, 0, new
MySQLSelectStatement(), "");
+ selectStatement.setCombine(new CombineSegment(0, 0, left,
CombineType.UNION, createSelectStatementForCombineSegment()));
Collection<SubquerySegment> actual =
SubqueryExtractUtils.getSubquerySegments(selectStatement);
assertThat(actual.size(), is(1));
}
- private SelectStatement createSelectStatementForCombineSegment() {
- SelectStatement result = new MySQLSelectStatement();
+ private SubquerySegment createSelectStatementForCombineSegment() {
+ SelectStatement selectStatement = new MySQLSelectStatement();
ExpressionSegment left = new ColumnSegment(0, 0, new
IdentifierValue("order_id"));
- result.setWhere(new WhereSegment(0, 0, new InExpression(0, 0,
- left, new SubqueryExpressionSegment(new SubquerySegment(0, 0,
new MySQLSelectStatement(), "")), false)));
- return result;
+ selectStatement.setWhere(new WhereSegment(0, 0, new InExpression(0, 0,
left, new SubqueryExpressionSegment(new SubquerySegment(0, 0, new
MySQLSelectStatement(), "")), false)));
+ return new SubquerySegment(0, 0, selectStatement, "");
}
@Test
diff --git
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
index 9d15e2c2843..e99bfe26ec7 100644
---
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
+++
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java
@@ -192,8 +192,8 @@ public final class SelectStatementAssert {
assertTrue(combineSegment.isPresent(),
assertContext.getText("Actual combine segment should exist."));
assertThat(assertContext.getText("Combine type assertion error:
"), combineSegment.get().getCombineType().name(),
is(expected.getCombineClause().getCombineType()));
SQLSegmentAssert.assertIs(assertContext, combineSegment.get(),
expected.getCombineClause());
- assertIs(assertContext, combineSegment.get().getLeft(),
expected.getCombineClause().getLeft());
- assertIs(assertContext, combineSegment.get().getRight(),
expected.getCombineClause().getRight());
+ assertIs(assertContext,
combineSegment.get().getLeft().getSelect(),
expected.getCombineClause().getLeft());
+ assertIs(assertContext,
combineSegment.get().getRight().getSelect(),
expected.getCombineClause().getRight());
}
}