This is an automated email from the ASF dual-hosted git repository.

chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new f4cbe5aeadd Support extract join table and get where segments from sub table join.
     new d787b262a6f Merge pull request #25459 from tuichenchuxin/dev
f4cbe5aeadd is described below

commit f4cbe5aeaddee89169c4047c6fe361f00267fe33
Author: tuichenchuxin <[email protected]>
AuthorDate: Fri May 5 11:12:13 2023 +0800

    Support extract join table and get where segments from sub table join.
---
 .../parser/sql/common/extractor/TableExtractor.java  |  4 ++++
 .../parser/sql/common/util/WhereExtractUtils.java    |  1 +
 .../sql/common/extractor/TableExtractorTest.java     | 17 +++++++++++++++++
 .../sql/common/util/WhereExtractUtilsTest.java       | 20 ++++++++++++++++++++
 4 files changed, 42 insertions(+)

diff --git 
a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
 
b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
index c137c0e6e3e..bf9d68ee10f 100644
--- 
a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
+++ 
b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java
@@ -68,6 +68,8 @@ public final class TableExtractor {
     
     private final Collection<TableSegment> tableContext = new LinkedList<>();
     
+    /** Join table segments found during extraction, including those collected by the nested extractor created for subquery tables. */
+    private final Collection<JoinTableSegment> joinTableSegments = new 
LinkedList<>();
+    
     /**
      * Extract table that should be rewritten from select statement.
      *
@@ -108,8 +110,10 @@ public final class TableExtractor {
             TableExtractor tableExtractor = new TableExtractor();
             tableExtractor.extractTablesFromSelect(((SubqueryTableSegment) 
tableSegment).getSubquery().getSelect());
             rewriteTables.addAll(tableExtractor.rewriteTables);
+            joinTableSegments.addAll(tableExtractor.joinTableSegments);
         }
         if (tableSegment instanceof JoinTableSegment) {
+            joinTableSegments.add((JoinTableSegment) tableSegment);
             extractTablesFromJoinTableSegment((JoinTableSegment) tableSegment);
         }
         if (tableSegment instanceof DeleteMultiTableSegment) {
diff --git 
a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java
 
b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java
index 0c7388467dd..964dab98f2d 100644
--- 
a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java
+++ 
b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java
@@ -73,6 +73,7 @@ public final class WhereExtractUtils {
         Collection<WhereSegment> result = new LinkedList<>();
         for (SubquerySegment each : 
SubqueryExtractUtils.getSubquerySegments(selectStatement)) {
             each.getSelect().getWhere().ifPresent(result::add);
+            result.addAll(getJoinWhereSegments(each.getSelect()));
         }
         return result;
     }
diff --git 
a/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
 
b/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
index 80b36fa1409..9b69c0d8223 100644
--- 
a/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
+++ 
b/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractorTest.java
@@ -26,6 +26,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.Co
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.AggregationProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
@@ -34,6 +35,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.Shorthan
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.CreateTableStatement;
@@ -170,4 +172,19 @@ class TableExtractorTest {
         result.setFrom(tableSegment);
         return result;
     }
+    
+    /**
+     * Verifies that extracting tables from a SELECT whose FROM clause is a single INNER JOIN of
+     * {@code t_order} and {@code t_order_item} records exactly that join segment in
+     * {@code getJoinTableSegments()}.
+     *
+     * NOTE(review): the condition's column segments carry no owner even though the expression text is
+     * {@code oi.order_id = o.order_id}; add OwnerSegments if owner handling should be covered.
+     * NOTE(review): the size assertion assumes the shared {@code tableExtractor} field is a fresh
+     * instance per test; confirm against its declaration.
+     */
+    @Test
+    void assertExtractJoinTableSegmentsFromSelect() {
+        JoinTableSegment joinTableSegment = new JoinTableSegment();
+        joinTableSegment.setLeft(new SimpleTableSegment(new 
TableNameSegment(16, 22, new IdentifierValue("t_order"))));
+        joinTableSegment.setRight(new SimpleTableSegment(new 
TableNameSegment(37, 48, new IdentifierValue("t_order_item"))));
+        joinTableSegment.setJoinType("INNER");
+        joinTableSegment.setCondition(new BinaryOperationExpression(56, 79, 
new ColumnSegment(56, 65, new IdentifierValue("order_id")),
+                new ColumnSegment(69, 79, new IdentifierValue("order_id")), 
"=", "oi.order_id = o.order_id"));
+        MySQLSelectStatement selectStatement = new MySQLSelectStatement();
+        selectStatement.setFrom(joinTableSegment);
+        tableExtractor.extractTablesFromSelect(selectStatement);
+        assertThat(tableExtractor.getJoinTableSegments().size(), is(1));
+        assertThat(tableExtractor.getJoinTableSegments().iterator().next(), 
is(joinTableSegment));
+    }
 }
diff --git 
a/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java
 
b/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java
index f5a07c67638..c0dbc3f25fd 100644
--- 
a/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java
+++ 
b/sql-parser/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java
@@ -24,6 +24,9 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.Projecti
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.SubqueryProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
@@ -71,4 +74,21 @@ class WhereExtractUtilsTest {
         WhereSegment actual = subqueryWhereSegments.iterator().next();
         assertThat(actual.getExpr(), 
is(subQuerySelectStatement.getWhere().get().getExpr()));
     }
+    
+    /**
+     * Verifies that {@code getSubqueryWhereSegments} returns the ON condition of an INNER JOIN that
+     * sits inside a FROM-clause subquery. The subquery statement has no WHERE clause set, so the first
+     * returned segment is expected to come from the join condition.
+     *
+     * NOTE(review): only the first returned segment is checked, not the collection size.
+     * NOTE(review): the TableNameSegment index ranges (37-39, 54-56) are shorter than the identifiers
+     * "t_order" / "t_order_item"; confirm the positions are intended.
+     */
+    @Test
+    void assertGetWhereSegmentsFromSubQueryJoin() {
+        JoinTableSegment joinTableSegment = new JoinTableSegment();
+        joinTableSegment.setLeft(new SimpleTableSegment(new 
TableNameSegment(37, 39, new IdentifierValue("t_order"))));
+        joinTableSegment.setRight(new SimpleTableSegment(new 
TableNameSegment(54, 56, new IdentifierValue("t_order_item"))));
+        joinTableSegment.setJoinType("INNER");
+        joinTableSegment.setCondition(new BinaryOperationExpression(63, 83, 
new ColumnSegment(63, 71, new IdentifierValue("order_id")),
+                new ColumnSegment(75, 83, new IdentifierValue("order_id")), 
"=", "oi.order_id = o.order_id"));
+        MySQLSelectStatement subQuerySelectStatement = new 
MySQLSelectStatement();
+        subQuerySelectStatement.setFrom(joinTableSegment);
+        MySQLSelectStatement mySQLSelectStatement = new MySQLSelectStatement();
+        mySQLSelectStatement.setFrom(new SubqueryTableSegment(new 
SubquerySegment(20, 84, subQuerySelectStatement)));
+        Collection<WhereSegment> subqueryWhereSegments = 
WhereExtractUtils.getSubqueryWhereSegments(mySQLSelectStatement);
+        WhereSegment actual = subqueryWhereSegments.iterator().next();
+        assertThat(actual.getExpr(), is(((JoinTableSegment) 
subQuerySelectStatement.getFrom()).getCondition()));
+    }
 }

Reply via email to