This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cd8123  fix binding table route logic with different sharding column (#13000)
2cd8123 is described below

commit 2cd812306f257e51dcd4a737fa9f432cbf8ac295
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Tue Oct 12 14:02:22 2021 +0800

    fix binding table route logic with different sharding column (#13000)
---
 .../route/engine/condition/ShardingConditions.java | 34 ++++++++++++----------
 .../infra/binder/segment/table/TablesContext.java  |  4 +--
 .../sql/common/util/SubqueryExtractUtil.java       | 22 ++++++++++++--
 .../sql/common/util/SubqueryExtractUtilTest.java   | 19 ++++++++++++
 .../ShardingSQLRewriterParameterizedTest.java      |  2 ++
 .../resources/scenario/sharding/case/select.xml    |  7 ++++-
 6 files changed, 67 insertions(+), 21 deletions(-)

diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java
index 6e6dab5..b1a413e 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java
@@ -31,6 +31,7 @@ import 
org.apache.shardingsphere.sharding.rule.BindingTableRule;
 import org.apache.shardingsphere.sharding.rule.ShardingRule;
 import org.apache.shardingsphere.sharding.rule.TableRule;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.util.SafeNumberOperationUtil;
 
@@ -121,6 +122,9 @@ public final class ShardingConditions {
         if (selectStatements.size() > 1) {
             Map<Integer, List<ShardingCondition>> startIndexShardingConditions 
= 
conditions.stream().collect(Collectors.groupingBy(ShardingCondition::getStartIndex));
             for (SelectStatement each : selectStatements) {
+                if (each.getFrom() instanceof SubqueryTableSegment) {
+                    continue;
+                }
                 if (!each.getWhere().isPresent() || 
!startIndexShardingConditions.containsKey(each.getWhere().get().getExpr().getStartIndex()))
 {
                     return false;
                 }
@@ -164,36 +168,36 @@ public final class ShardingConditions {
             return false;
         }
         for (int i = 0; i < shardingCondition1.getValues().size(); i++) {
-            ShardingConditionValue shardingConditionValue1 = 
shardingCondition1.getValues().get(i);
-            ShardingConditionValue shardingConditionValue2 = 
shardingCondition2.getValues().get(i);
-            if (!isSameShardingConditionValue(shardingRule, 
shardingConditionValue1, shardingConditionValue2)) {
+            ShardingConditionValue shardingValue1 = 
shardingCondition1.getValues().get(i);
+            ShardingConditionValue shardingValue2 = 
shardingCondition2.getValues().get(i);
+            if (!isSameShardingConditionValue(shardingRule, shardingValue1, 
shardingValue2)) {
                 return false;
             }
         }
         return true;
     }
     
-    private boolean isRoutingByHint(final ShardingRule shardingRule, final 
TableRule tableRule) {
-        return 
shardingRule.getDatabaseShardingStrategyConfiguration(tableRule) instanceof 
HintShardingStrategyConfiguration
-                && 
shardingRule.getTableShardingStrategyConfiguration(tableRule) instanceof 
HintShardingStrategyConfiguration;
+    private boolean isSameShardingCondition(final ShardingRule shardingRule, 
final ShardingConditionValue shardingValue1, final ShardingConditionValue 
shardingValue2) {
+        return 
shardingValue1.getTableName().equals(shardingValue2.getTableName()) 
+                && 
shardingValue1.getColumnName().equals(shardingValue2.getColumnName()) || 
isBindingTable(shardingRule, shardingValue1, shardingValue2);
     }
     
-    private boolean isSameShardingConditionValue(final ShardingRule 
shardingRule, final ShardingConditionValue shardingConditionValue1, final 
ShardingConditionValue shardingConditionValue2) {
-        return isSameLogicTable(shardingRule, shardingConditionValue1, 
shardingConditionValue2) && 
shardingConditionValue1.getColumnName().equals(shardingConditionValue2.getColumnName())
-                && isSameValue(shardingConditionValue1, 
shardingConditionValue2);
+    private boolean isBindingTable(final ShardingRule shardingRule, final 
ShardingConditionValue shardingValue1, final ShardingConditionValue 
shardingValue2) {
+        Optional<BindingTableRule> bindingRule = 
shardingRule.findBindingTableRule(shardingValue1.getTableName());
+        return bindingRule.isPresent() && 
bindingRule.get().hasLogicTable(shardingValue2.getTableName());
     }
     
-    private boolean isSameLogicTable(final ShardingRule shardingRule, final 
ShardingConditionValue shardingValue1, final ShardingConditionValue 
shardingValue2) {
-        return 
shardingValue1.getTableName().equals(shardingValue2.getTableName()) || 
isBindingTable(shardingRule, shardingValue1, shardingValue2);
+    private boolean isRoutingByHint(final ShardingRule shardingRule, final 
TableRule tableRule) {
+        return 
shardingRule.getDatabaseShardingStrategyConfiguration(tableRule) instanceof 
HintShardingStrategyConfiguration
+                && 
shardingRule.getTableShardingStrategyConfiguration(tableRule) instanceof 
HintShardingStrategyConfiguration;
     }
     
-    private boolean isBindingTable(final ShardingRule shardingRule, final 
ShardingConditionValue shardingValue1, final ShardingConditionValue 
shardingValue2) {
-        Optional<BindingTableRule> bindingRule = 
shardingRule.findBindingTableRule(shardingValue1.getTableName());
-        return bindingRule.isPresent() && 
bindingRule.get().hasLogicTable(shardingValue2.getTableName());
+    private boolean isSameShardingConditionValue(final ShardingRule 
shardingRule, final ShardingConditionValue shardingValue1, final 
ShardingConditionValue shardingValue2) {
+        return isSameShardingCondition(shardingRule, shardingValue1, 
shardingValue2) && isSameShardingValue(shardingValue1, shardingValue2);
     }
     
     @SuppressWarnings({"rawtypes", "unchecked"})
-    private boolean isSameValue(final ShardingConditionValue 
shardingConditionValue1, final ShardingConditionValue shardingConditionValue2) {
+    private boolean isSameShardingValue(final ShardingConditionValue 
shardingConditionValue1, final ShardingConditionValue shardingConditionValue2) {
         if (shardingConditionValue1 instanceof ListShardingConditionValue && 
shardingConditionValue2 instanceof ListShardingConditionValue) {
             return SafeNumberOperationUtil.safeCollectionEquals(
                     ((ListShardingConditionValue) 
shardingConditionValue1).getValues(), ((ListShardingConditionValue) 
shardingConditionValue2).getValues());
diff --git 
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
 
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
index d1e525f..50237de 100644
--- 
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
+++ 
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/table/TablesContext.java
@@ -152,8 +152,8 @@ public final class TablesContext {
             if (tableColumnNames.isEmpty()) {
                 continue;
             }
-            tableColumnNames.retainAll(columnNames);
-            for (String columnName : tableColumnNames) {
+            Collection<String> intersectionColumnNames = 
tableColumnNames.stream().filter(columnNames::contains).collect(Collectors.toList());
+            for (String columnName : intersectionColumnNames) {
                 result.put(columnName, each);
             }
         }
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtil.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtil.java
index 7000cea..0603707 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtil.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtil.java
@@ -19,8 +19,10 @@ package org.apache.shardingsphere.sql.parser.sql.common.util;
 
 import lombok.AccessLevel;
 import lombok.NoArgsConstructor;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubqueryExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
@@ -67,7 +69,9 @@ public final class SubqueryExtractUtil {
             if (!(each instanceof SubqueryProjectionSegment)) {
                 continue;
             }
-            result.add(((SubqueryProjectionSegment) each).getSubquery());
+            SubquerySegment subquery = ((SubqueryProjectionSegment) 
each).getSubquery();
+            result.add(subquery);
+            result.addAll(getSubquerySegments(subquery.getSelect()));
         }
         return result;
     }
@@ -78,7 +82,9 @@ public final class SubqueryExtractUtil {
         }
         Collection<SubquerySegment> result = new LinkedList<>();
         if (tableSegment instanceof SubqueryTableSegment) {
-            result.add(((SubqueryTableSegment) tableSegment).getSubquery());
+            SubquerySegment subquery = ((SubqueryTableSegment) 
tableSegment).getSubquery();
+            result.add(subquery);
+            result.addAll(getSubquerySegments(subquery.getSelect()));
         }
         if (tableSegment instanceof JoinTableSegment) {
             
result.addAll(getSubquerySegmentsFromTableSegment(((JoinTableSegment) 
tableSegment).getLeft()));
@@ -90,7 +96,9 @@ public final class SubqueryExtractUtil {
     private static Collection<SubquerySegment> 
getSubquerySegmentsFromExpression(final ExpressionSegment expressionSegment) {
         Collection<SubquerySegment> result = new LinkedList<>();
         if (expressionSegment instanceof SubqueryExpressionSegment) {
-            result.add(((SubqueryExpressionSegment) 
expressionSegment).getSubquery());
+            SubquerySegment subquery = ((SubqueryExpressionSegment) 
expressionSegment).getSubquery();
+            result.add(subquery);
+            result.addAll(getSubquerySegments(subquery.getSelect()));
         }
         if (expressionSegment instanceof ListExpression) {
             for (ExpressionSegment each : ((ListExpression) 
expressionSegment).getItems()) {
@@ -101,6 +109,14 @@ public final class SubqueryExtractUtil {
             
result.addAll(getSubquerySegmentsFromExpression(((BinaryOperationExpression) 
expressionSegment).getLeft()));
             
result.addAll(getSubquerySegmentsFromExpression(((BinaryOperationExpression) 
expressionSegment).getRight()));
         }
+        if (expressionSegment instanceof InExpression) {
+            result.addAll(getSubquerySegmentsFromExpression(((InExpression) 
expressionSegment).getLeft()));
+            result.addAll(getSubquerySegmentsFromExpression(((InExpression) 
expressionSegment).getRight()));
+        }
+        if (expressionSegment instanceof BetweenExpression) {
+            
result.addAll(getSubquerySegmentsFromExpression(((BetweenExpression) 
expressionSegment).getBetweenExpr()));
+            
result.addAll(getSubquerySegmentsFromExpression(((BetweenExpression) 
expressionSegment).getAndExpr()));
+        }
         return result;
     }
 }
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilTest.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilTest.java
index 2f5d68c..318cf35 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilTest.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilTest.java
@@ -20,6 +20,8 @@ package org.apache.shardingsphere.sql.parser.sql.common.util;
 import 
org.apache.shardingsphere.sql.parser.sql.common.constant.AggregationType;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubqueryExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
@@ -32,6 +34,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.Joi
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
 import org.junit.Test;
@@ -145,4 +148,20 @@ public final class SubqueryExtractUtilTest {
         assertThat(iterator.next(), is(leftSubquerySegment.getSubquery()));
         assertThat(iterator.next(), is(rightSubquerySegment.getSubquery()));
     }
+    
+    @Test
+    public void assertGetSubquerySegmentsWithMultiNestedSubquery() {
+        SelectStatement selectStatement = new MySQLSelectStatement();
+        selectStatement.setFrom(new 
SubqueryTableSegment(createSubquerySegmentForFrom()));
+        Collection<SubquerySegment> result = 
SubqueryExtractUtil.getSubquerySegments(selectStatement);
+        assertThat(result.size(), is(2));
+    }
+    
+    private SubquerySegment createSubquerySegmentForFrom() {
+        SelectStatement selectStatement = new MySQLSelectStatement();
+        ExpressionSegment left = new ColumnSegment(0, 0, new 
IdentifierValue("order_id"));
+        selectStatement.setWhere(new WhereSegment(0, 0, new InExpression(0, 0, 
+                left, new SubqueryExpressionSegment(new SubquerySegment(0, 0, 
new MySQLSelectStatement())), false)));
+        return new SubquerySegment(0, 0, selectStatement);
+    }
 }
diff --git 
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/java/org/apache/shardingsphere/sharding/rewrite/parameterized/scenario/ShardingSQLRewriterParameterizedTest.java
 
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/java/org/apache/shardingsphere/sharding/rewrite/parameterized/scenario/ShardingSQLRewriterParameterizedTest.java
index 00ebb0c..ba96087 100644
--- 
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/java/org/apache/shardingsphere/sharding/rewrite/parameterized/scenario/ShardingSQLRewriterParameterizedTest.java
+++ 
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/java/org/apache/shardingsphere/sharding/rewrite/parameterized/scenario/ShardingSQLRewriterParameterizedTest.java
@@ -93,6 +93,8 @@ public final class ShardingSQLRewriterParameterizedTest 
extends AbstractSQLRewri
         when(result.get("t_account")).thenReturn(accountTableMetaData);
         
when(result.get("t_account_detail")).thenReturn(mock(TableMetaData.class));
         when(result.getAllColumnNames("t_account")).thenReturn(new 
ArrayList<>(Arrays.asList("account_id", "amount", "status")));
+        when(result.getAllColumnNames("t_user")).thenReturn(new 
ArrayList<>(Arrays.asList("id", "content")));
+        when(result.getAllColumnNames("t_user_extend")).thenReturn(new 
ArrayList<>(Arrays.asList("user_id", "content")));
         when(result.containsColumn("t_account", 
"account_id")).thenReturn(true);
         return result;
     }
diff --git 
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
 
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
index 6ebcccf..b735cec 100644
--- 
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
+++ 
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
@@ -85,7 +85,7 @@
 
     <rewrite-assertion id="select_with_subquery_with_subquery" db-type="MySQL">
         <input sql="SELECT * FROM (select b.account_id from (select 
t_account.account_id from t_account) b where b.account_id=?) a WHERE account_id 
= 100" parameters="100" />
-        <output sql="SELECT * FROM (select b.account_id from (select 
t_account_0.account_id from t_account_0) b where b.account_id=?) a WHERE 
account_id = 100" parameters="100" />
+        <output sql="SELECT * FROM (select b.account_id from (select 
t_account.account_id from t_account) b where b.account_id=?) a WHERE account_id 
= 100" parameters="100" />
     </rewrite-assertion>
 
     <rewrite-assertion id="select_with_subquery_in_projection_and_where" 
db-type="MySQL">
@@ -506,4 +506,9 @@
         <output sql="SELECT * FROM t_account_0 WHERE amount=? OR amount=? AND 
account_id=?" parameters="1, 2, 3"/>
         <output sql="SELECT * FROM t_account_1 WHERE amount=? OR amount=? AND 
account_id=?" parameters="1, 2, 3"/>
     </rewrite-assertion>
+
+    <rewrite-assertion id="select_multi_nested_subquery_with_binding_tables">
+        <input sql="SELECT * FROM (SELECT id, content FROM t_user WHERE id = ? 
AND content IN (SELECT content FROM t_user_extend WHERE user_id = ?)) AS temp" 
parameters="1, 1"/>
+        <output sql="SELECT * FROM (SELECT id, content FROM t_user_1 WHERE id 
= ? AND content IN (SELECT content FROM t_user_extend_1 WHERE user_id = ?)) AS 
temp" parameters="1, 1"/>
+    </rewrite-assertion>
 </rewrite-assertions>

Reply via email to