This is an automated email from the ASF dual-hosted git repository.

jianglongtao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new f4742d1  optimize binding table route logic (#13892)
f4742d1 is described below

commit f4742d1b396c64ba09f59b6bf5ad774d7c45b242
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Thu Dec 2 14:03:53 2021 +0800

    optimize binding table route logic (#13892)
    
    * optimize binding table route logic
    
    * fix rewrite test
    
    * update binding table doc
    
    * refactor unit test
---
 .../content/features/sharding/concept/table.cn.md  |   5 +-
 .../content/features/sharding/concept/table.en.md  |   5 +-
 .../engine/type/ShardingRouteEngineFactory.java    |  89 +---------------
 .../shardingsphere/sharding/rule/ShardingRule.java | 106 ++++++++++++++++---
 .../sharding/rule/ShardingRuleTest.java            | 117 ++++++++++++++++++++-
 .../resources/scenario/sharding/case/select.xml    |   1 -
 6 files changed, 217 insertions(+), 106 deletions(-)

diff --git a/docs/document/content/features/sharding/concept/table.cn.md 
b/docs/document/content/features/sharding/concept/table.cn.md
index ae1337a..f0c8639 100644
--- a/docs/document/content/features/sharding/concept/table.cn.md
+++ b/docs/document/content/features/sharding/concept/table.cn.md
@@ -19,7 +19,8 @@ Apache ShardingSphere 通过提供多样化的表类型,适配不同场景下
 ## 绑定表
 
 指分片规则一致的主表和子表。
-例如:`t_order` 表和 `t_order_item` 表,均按照 `order_id` 分片,则此两张表互为绑定表关系。
+使用绑定表进行多表关联查询时,必须使用分片键进行关联,否则会出现笛卡尔积关联或跨库关联,从而影响查询效率。
+例如:`t_order` 表和 `t_order_item` 表,均按照 `order_id` 分片,并且使用 `order_id` 
进行关联,则此两张表互为绑定表关系。
 绑定表之间的多表关联查询不会出现笛卡尔积关联,关联查询效率将大大提升。
 举例说明,如果 SQL 为:
 
@@ -39,7 +40,7 @@ SELECT i.* FROM t_order_1 o JOIN t_order_item_0 i ON 
o.order_id=i.order_id WHERE
 SELECT i.* FROM t_order_1 o JOIN t_order_item_1 i ON o.order_id=i.order_id 
WHERE o.order_id in (10, 11);
 ```
 
-在配置绑定表关系后,路由的 SQL 应该为 2 条:
+在配置绑定表关系,并且使用 `order_id` 进行关联后,路由的 SQL 应该为 2 条:
 
 ```sql
 SELECT i.* FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id=i.order_id 
WHERE o.order_id in (10, 11);
diff --git a/docs/document/content/features/sharding/concept/table.en.md 
b/docs/document/content/features/sharding/concept/table.en.md
index 406cb24..7357d70 100644
--- a/docs/document/content/features/sharding/concept/table.en.md
+++ b/docs/document/content/features/sharding/concept/table.en.md
@@ -18,7 +18,8 @@ The physical table that really exists in the horizontal 
sharding database, i.e.,
 ## Binding Table
 
 It refers to the primary table and the joiner table with the same sharding 
rules.
-for example, `t_order` and `t_order_item` are both sharded by `order_id`, so 
they are binding tables with each other. 
+When binding tables are used in a multi-table correlating query, they must be correlated on the sharding key; otherwise, Cartesian product correlation or cross-database correlation will occur, which reduces query efficiency.
+For example, `t_order` and `t_order_item` are both sharded by `order_id` and are correlated on `order_id`, so they are binding tables of each other. 
 Cartesian product correlation will not appear in the multi-tables correlating 
query, so the query efficiency will increase greatly.
 Take this one for example, if SQL is:
 
@@ -38,7 +39,7 @@ SELECT i.* FROM t_order_1 o JOIN t_order_item_0 i ON 
o.order_id=i.order_id WHERE
 SELECT i.* FROM t_order_1 o JOIN t_order_item_1 i ON o.order_id=i.order_id 
WHERE o.order_id in (10, 11);
 ```
 
-With binding table configuration, there should be 2 SQLs after routing:
+With binding tables configured and `order_id` used for correlation, there should be 2 SQLs after routing:
 
 ```sql
 SELECT i.* FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id=i.order_id 
WHERE o.order_id in (10, 11);
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
index cc9e2a9..1e46cd3 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
@@ -19,7 +19,6 @@ package org.apache.shardingsphere.sharding.route.engine.type;
 
 import lombok.AccessLevel;
 import lombok.NoArgsConstructor;
-import 
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
 import org.apache.shardingsphere.infra.binder.type.TableAvailable;
@@ -27,8 +26,6 @@ import 
org.apache.shardingsphere.infra.config.properties.ConfigurationProperties
 import 
org.apache.shardingsphere.infra.config.properties.ConfigurationPropertyKey;
 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
 import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
-import 
org.apache.shardingsphere.sharding.api.config.strategy.sharding.ShardingStrategyConfiguration;
-import 
org.apache.shardingsphere.sharding.api.config.strategy.sharding.StandardShardingStrategyConfiguration;
 import 
org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
 import 
org.apache.shardingsphere.sharding.route.engine.condition.ShardingConditions;
 import 
org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
@@ -42,12 +39,6 @@ import 
org.apache.shardingsphere.sharding.route.engine.type.ignore.ShardingIgnor
 import 
org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine;
 import 
org.apache.shardingsphere.sharding.route.engine.type.unicast.ShardingUnicastRoutingEngine;
 import org.apache.shardingsphere.sharding.rule.ShardingRule;
-import org.apache.shardingsphere.sharding.rule.TableRule;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dal.AnalyzeTableStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dal.DALStatement;
@@ -68,20 +59,13 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.DropTablesp
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.tcl.TCLStatement;
-import 
org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtil;
-import org.apache.shardingsphere.sql.parser.sql.common.util.WhereExtractUtil;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLCreateResourceGroupStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLOptimizeTableStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLSetResourceGroupStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLShowDatabasesStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLUseStatement;
 
-import java.util.Arrays;
 import java.util.Collection;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.Map;
-import java.util.Optional;
 import java.util.stream.Collectors;
 
 /**
@@ -90,8 +74,6 @@ import java.util.stream.Collectors;
 @NoArgsConstructor(access = AccessLevel.PRIVATE)
 public final class ShardingRouteEngineFactory {
     
-    private static final String EQUAL = "=";
-    
     /**
      * Create new instance of routing engine.
      *
@@ -212,7 +194,7 @@ public final class ShardingRouteEngineFactory {
     
     private static ShardingRouteEngine getDQLRouteEngineForShardingTable(final 
ShardingRule shardingRule, final ShardingSphereSchema schema, final 
SQLStatementContext<?> sqlStatementContext, 
                                                                          final 
ShardingConditions shardingConditions, final ConfigurationProperties props, 
final Collection<String> tableNames) {
-        boolean allBindingTables = tableNames.size() > 1 && 
isAllBindingTables(shardingRule, schema, sqlStatementContext, tableNames);
+        boolean allBindingTables = tableNames.size() > 1 && 
shardingRule.isAllBindingTables(schema, sqlStatementContext, tableNames);
         if (isShardingFederatedQuery(shardingRule, sqlStatementContext, 
shardingConditions, props, tableNames, allBindingTables)) {
             return new ShardingFederatedRoutingEngine(tableNames);
         }
@@ -254,73 +236,4 @@ public final class ShardingRouteEngineFactory {
         }
         return tableNames.size() > 1 && !allBindingTables;
     }
-    
-    private static boolean isAllBindingTables(final ShardingRule shardingRule, 
final ShardingSphereSchema schema, 
-                                              final SQLStatementContext<?> 
sqlStatementContext, final Collection<String> tableNames) {
-        if (!(sqlStatementContext instanceof SelectStatementContext)) {
-            return shardingRule.isAllBindingTables(tableNames); 
-        }
-        return shardingRule.isAllBindingTables(tableNames) && 
isJoinConditionContainsShardingColumns(shardingRule, schema, 
(SelectStatementContext) sqlStatementContext, tableNames);
-    }
-    
-    private static boolean isJoinConditionContainsShardingColumns(final 
ShardingRule shardingRule, final ShardingSphereSchema schema,
-                                                                  final 
SelectStatementContext select, final Collection<String> tableNames) {
-        Collection<String> databaseJoinConditionTables = new 
HashSet<>(tableNames.size());
-        Collection<String> tableJoinConditionTables = new 
HashSet<>(tableNames.size());
-        for (WhereSegment each : 
WhereExtractUtil.getJoinWhereSegments(select.getSqlStatement())) {
-            Collection<AndPredicate> andPredicates = 
ExpressionExtractUtil.getAndPredicates(each.getExpr());
-            if (andPredicates.size() > 1) {
-                return false;
-            }
-            for (AndPredicate andPredicate : andPredicates) {
-                
databaseJoinConditionTables.addAll(getJoinConditionTables(shardingRule, schema, 
select, andPredicate.getPredicates(), true));
-                
tableJoinConditionTables.addAll(getJoinConditionTables(shardingRule, schema, 
select, andPredicate.getPredicates(), false));
-            }
-        }
-        TableRule tableRule = 
shardingRule.getTableRule(tableNames.iterator().next());
-        boolean containsDatabaseShardingColumns = 
!(tableRule.getDatabaseShardingStrategyConfig() instanceof 
StandardShardingStrategyConfiguration) 
-                || databaseJoinConditionTables.containsAll(tableNames);
-        boolean containsTableShardingColumns = 
!(tableRule.getTableShardingStrategyConfig() instanceof 
StandardShardingStrategyConfiguration) || 
tableJoinConditionTables.containsAll(tableNames);
-        return containsDatabaseShardingColumns && containsTableShardingColumns;
-    }
-    
-    private static Collection<String> getJoinConditionTables(final 
ShardingRule shardingRule, final ShardingSphereSchema schema, final 
SelectStatementContext select, 
-                                                             final 
Collection<ExpressionSegment> predicates, final boolean 
isDatabaseJoinCondition) {
-        Collection<String> result = new LinkedList<>();
-        for (ExpressionSegment expression : predicates) {
-            if (!isJoinTableConditionExpression(expression)) {
-                continue;
-            }
-            ColumnProjection leftColumn = 
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression) 
expression).getLeft());
-            ColumnProjection rightColumn = 
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression) 
expression).getRight());
-            Map<String, String> columnTableNames = 
select.getTablesContext().findTableName(Arrays.asList(leftColumn, rightColumn), 
schema);
-            Optional<TableRule> leftTableRule = 
shardingRule.findTableRule(columnTableNames.get(leftColumn.getExpression()));
-            Optional<TableRule> rightTableRule = 
shardingRule.findTableRule(columnTableNames.get(rightColumn.getExpression()));
-            if (!leftTableRule.isPresent() || !rightTableRule.isPresent()) {
-                continue;
-            }
-            ShardingStrategyConfiguration leftConfiguration = 
isDatabaseJoinCondition
-                    ? 
shardingRule.getDatabaseShardingStrategyConfiguration(leftTableRule.get()) : 
shardingRule.getTableShardingStrategyConfiguration(leftTableRule.get());
-            ShardingStrategyConfiguration rightConfiguration = 
isDatabaseJoinCondition
-                    ? 
shardingRule.getDatabaseShardingStrategyConfiguration(rightTableRule.get()) : 
shardingRule.getTableShardingStrategyConfiguration(rightTableRule.get());
-            if (shardingRule.isShardingColumn(leftConfiguration, 
leftColumn.getName()) && shardingRule.isShardingColumn(rightConfiguration, 
rightColumn.getName())) {
-                result.add(columnTableNames.get(leftColumn.getExpression()));
-                result.add(columnTableNames.get(rightColumn.getExpression()));
-            }
-        }
-        return result;
-    }
-    
-    private static ColumnProjection buildColumnProjection(final ColumnSegment 
segment) {
-        String owner = segment.getOwner().map(optional -> 
optional.getIdentifier().getValue()).orElse(null);
-        return new ColumnProjection(owner, segment.getIdentifier().getValue(), 
null);
-    }
-    
-    private static boolean isJoinTableConditionExpression(final 
ExpressionSegment expression) {
-        if (!(expression instanceof BinaryOperationExpression)) {
-            return false;
-        }
-        BinaryOperationExpression binaryExpression = 
(BinaryOperationExpression) expression;
-        return binaryExpression.getLeft() instanceof ColumnSegment && 
binaryExpression.getRight() instanceof ColumnSegment && 
EQUAL.equals(binaryExpression.getOperator());
-    }
 }
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
index ff4b9b9..700b065 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
@@ -20,9 +20,13 @@ package org.apache.shardingsphere.sharding.rule;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Splitter;
 import lombok.Getter;
+import 
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
 import 
org.apache.shardingsphere.infra.config.algorithm.ShardingSphereAlgorithmFactory;
 import 
org.apache.shardingsphere.infra.config.exception.ShardingSphereConfigurationException;
 import org.apache.shardingsphere.infra.datanode.DataNode;
+import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
 import org.apache.shardingsphere.infra.rule.identifier.scope.SchemaRule;
 import 
org.apache.shardingsphere.infra.rule.identifier.type.DataNodeContainedRule;
 import org.apache.shardingsphere.infra.rule.identifier.type.TableContainedRule;
@@ -41,7 +45,15 @@ import 
org.apache.shardingsphere.sharding.spi.ShardingAlgorithm;
 import org.apache.shardingsphere.sharding.support.InlineExpressionParser;
 import org.apache.shardingsphere.spi.ShardingSphereServiceLoader;
 import org.apache.shardingsphere.spi.required.RequiredSPIRegistry;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtil;
+import org.apache.shardingsphere.sql.parser.sql.common.util.WhereExtractUtil;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -61,6 +73,8 @@ import java.util.stream.Collectors;
 @Getter
 public final class ShardingRule implements SchemaRule, DataNodeContainedRule, 
TableContainedRule {
     
+    private static final String EQUAL = "=";
+    
     static {
         ShardingSphereServiceLoader.register(ShardingAlgorithm.class);
         ShardingSphereServiceLoader.register(KeyGenerateAlgorithm.class);
@@ -244,10 +258,10 @@ public final class ShardingRule implements SchemaRule, 
DataNodeContainedRule, Ta
     }
     
     /**
-     * Judge whether logic table is all binding encryptors or not.
+     * Judge whether logic table is all binding tables or not.
      *
      * @param logicTableNames logic table names
-     * @return whether logic table is all binding encryptors or not
+     * @return whether logic table is all binding tables or not
      */
     public boolean isAllBindingTables(final Collection<String> 
logicTableNames) {
         if (logicTableNames.isEmpty()) {
@@ -262,6 +276,21 @@ public final class ShardingRule implements SchemaRule, 
DataNodeContainedRule, Ta
         return !result.isEmpty() && result.containsAll(logicTableNames);
     }
     
+    /**
+     * Judge whether logic table is all binding tables.
+     *
+     * @param schema schema
+     * @param sqlStatementContext sqlStatementContext
+     * @param logicTableNames logic table names
+     * @return whether logic table is all binding tables.
+     */
+    public boolean isAllBindingTables(final ShardingSphereSchema schema, final 
SQLStatementContext<?> sqlStatementContext, final Collection<String> 
logicTableNames) {
+        if (!(sqlStatementContext instanceof SelectStatementContext && 
((SelectStatementContext) sqlStatementContext).isContainsJoinQuery())) {
+            return isAllBindingTables(logicTableNames);
+        }
+        return isAllBindingTables(logicTableNames) && 
isJoinConditionContainsShardingColumns(schema, (SelectStatementContext) 
sqlStatementContext, logicTableNames);
+    }
+    
     private Optional<BindingTableRule> findBindingTableRule(final 
Collection<String> logicTableNames) {
         return 
logicTableNames.stream().map(this::findBindingTableRule).filter(Optional::isPresent).findFirst().orElse(Optional.empty());
     }
@@ -282,10 +311,10 @@ public final class ShardingRule implements SchemaRule, 
DataNodeContainedRule, Ta
     }
     
     /**
-     * Judge whether logic table is all broadcast encryptors or not.
+     * Judge whether logic table is all broadcast tables or not.
      *
      * @param logicTableNames logic table names
-     * @return whether logic table is all broadcast encryptors or not
+     * @return whether logic table is all broadcast tables or not
      */
     public boolean isAllBroadcastTables(final Collection<String> 
logicTableNames) {
         return !logicTableNames.isEmpty() && 
broadcastTables.containsAll(logicTableNames);
@@ -366,14 +395,7 @@ public final class ShardingRule implements SchemaRule, 
DataNodeContainedRule, Ta
         return 
isShardingColumn(getDatabaseShardingStrategyConfiguration(tableRule), 
columnName) || 
isShardingColumn(getTableShardingStrategyConfiguration(tableRule), columnName);
     }
     
-    /**
-     * Judge whether given logic table column is sharding column or not.
-     * 
-     * @param shardingStrategyConfig sharding strategy config
-     * @param columnName column name
-     * @return whether given logic table column is sharding column or not
-     */
-    public boolean isShardingColumn(final ShardingStrategyConfiguration 
shardingStrategyConfig, final String columnName) {
+    private boolean isShardingColumn(final ShardingStrategyConfiguration 
shardingStrategyConfig, final String columnName) {
         if (shardingStrategyConfig instanceof 
StandardShardingStrategyConfiguration) {
             String shardingColumn = null == 
((StandardShardingStrategyConfiguration) 
shardingStrategyConfig).getShardingColumn()
                     ? defaultShardingColumn : 
((StandardShardingStrategyConfiguration) 
shardingStrategyConfig).getShardingColumn();
@@ -531,4 +553,64 @@ public final class ShardingRule implements SchemaRule, 
DataNodeContainedRule, Ta
     public String getType() {
         return ShardingRule.class.getSimpleName();
     }
+    
+    private boolean isJoinConditionContainsShardingColumns(final 
ShardingSphereSchema schema, final SelectStatementContext select, final 
Collection<String> tableNames) {
+        Collection<String> databaseJoinConditionTables = new 
HashSet<>(tableNames.size());
+        Collection<String> tableJoinConditionTables = new 
HashSet<>(tableNames.size());
+        for (WhereSegment each : 
WhereExtractUtil.getJoinWhereSegments(select.getSqlStatement())) {
+            Collection<AndPredicate> andPredicates = 
ExpressionExtractUtil.getAndPredicates(each.getExpr());
+            if (andPredicates.size() > 1) {
+                return false;
+            }
+            for (AndPredicate andPredicate : andPredicates) {
+                
databaseJoinConditionTables.addAll(getJoinConditionTables(schema, select, 
andPredicate.getPredicates(), true));
+                tableJoinConditionTables.addAll(getJoinConditionTables(schema, 
select, andPredicate.getPredicates(), false));
+            }
+        }
+        TableRule tableRule = getTableRule(tableNames.iterator().next());
+        boolean containsDatabaseShardingColumns = 
!(tableRule.getDatabaseShardingStrategyConfig() instanceof 
StandardShardingStrategyConfiguration)
+                || databaseJoinConditionTables.containsAll(tableNames);
+        boolean containsTableShardingColumns = 
!(tableRule.getTableShardingStrategyConfig() instanceof 
StandardShardingStrategyConfiguration) || 
tableJoinConditionTables.containsAll(tableNames);
+        return containsDatabaseShardingColumns && containsTableShardingColumns;
+    }
+    
+    private Collection<String> getJoinConditionTables(final 
ShardingSphereSchema schema, final SelectStatementContext select,
+                                                      final 
Collection<ExpressionSegment> predicates, final boolean 
isDatabaseJoinCondition) {
+        Collection<String> result = new LinkedList<>();
+        for (ExpressionSegment each : predicates) {
+            if (!isJoinConditionExpression(each)) {
+                continue;
+            }
+            ColumnProjection leftColumn = 
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression) 
each).getLeft());
+            ColumnProjection rightColumn = 
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression) 
each).getRight());
+            Map<String, String> columnTableNames = 
select.getTablesContext().findTableName(Arrays.asList(leftColumn, rightColumn), 
schema);
+            Optional<TableRule> leftTableRule = 
findTableRule(columnTableNames.get(leftColumn.getExpression()));
+            Optional<TableRule> rightTableRule = 
findTableRule(columnTableNames.get(rightColumn.getExpression()));
+            if (!leftTableRule.isPresent() || !rightTableRule.isPresent()) {
+                continue;
+            }
+            ShardingStrategyConfiguration leftConfiguration = 
isDatabaseJoinCondition
+                    ? 
getDatabaseShardingStrategyConfiguration(leftTableRule.get()) : 
getTableShardingStrategyConfiguration(leftTableRule.get());
+            ShardingStrategyConfiguration rightConfiguration = 
isDatabaseJoinCondition
+                    ? 
getDatabaseShardingStrategyConfiguration(rightTableRule.get()) : 
getTableShardingStrategyConfiguration(rightTableRule.get());
+            if (isShardingColumn(leftConfiguration, leftColumn.getName()) && 
isShardingColumn(rightConfiguration, rightColumn.getName())) {
+                result.add(columnTableNames.get(leftColumn.getExpression()));
+                result.add(columnTableNames.get(rightColumn.getExpression()));
+            }
+        }
+        return result;
+    }
+    
+    private ColumnProjection buildColumnProjection(final ColumnSegment 
segment) {
+        String owner = segment.getOwner().map(optional -> 
optional.getIdentifier().getValue()).orElse(null);
+        return new ColumnProjection(owner, segment.getIdentifier().getValue(), 
null);
+    }
+    
+    private boolean isJoinConditionExpression(final ExpressionSegment 
expression) {
+        if (!(expression instanceof BinaryOperationExpression)) {
+            return false;
+        }
+        BinaryOperationExpression binaryExpression = 
(BinaryOperationExpression) expression;
+        return binaryExpression.getLeft() instanceof ColumnSegment && 
binaryExpression.getRight() instanceof ColumnSegment && 
EQUAL.equals(binaryExpression.getOperator());
+    }
 }
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
index 1068bb6..7d12bcb 100644
--- 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
@@ -17,9 +17,14 @@
 
 package org.apache.shardingsphere.sharding.rule;
 
+import 
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
+import 
org.apache.shardingsphere.infra.binder.statement.dml.UpdateStatementContext;
 import 
org.apache.shardingsphere.infra.config.algorithm.ShardingSphereAlgorithmConfiguration;
 import 
org.apache.shardingsphere.infra.config.exception.ShardingSphereConfigurationException;
 import org.apache.shardingsphere.infra.datanode.DataNode;
+import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
 import 
org.apache.shardingsphere.sharding.algorithm.keygen.SnowflakeKeyGenerateAlgorithm;
 import 
org.apache.shardingsphere.sharding.algorithm.keygen.fixture.IncrementKeyGenerateAlgorithm;
 import 
org.apache.shardingsphere.sharding.algorithm.sharding.inline.InlineShardingAlgorithm;
@@ -30,12 +35,21 @@ import 
org.apache.shardingsphere.sharding.api.config.strategy.keygen.KeyGenerate
 import 
org.apache.shardingsphere.sharding.api.config.strategy.sharding.ComplexShardingStrategyConfiguration;
 import 
org.apache.shardingsphere.sharding.api.config.strategy.sharding.NoneShardingStrategyConfiguration;
 import 
org.apache.shardingsphere.sharding.api.config.strategy.sharding.StandardShardingStrategyConfiguration;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
 import org.junit.Test;
 
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashSet;
+import java.util.Map;
 import java.util.Properties;
 import java.util.TreeSet;
 
@@ -45,9 +59,16 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public final class ShardingRuleTest {
     
+    private static final String EQUAL = "=";
+    
+    private static final String AND = "AND";
+    
     @Test
     public void assertNewShardingRuleWithMaximumConfiguration() {
         ShardingRule actual = createMaximumShardingRule();
@@ -368,7 +389,10 @@ public final class ShardingRuleTest {
     }
     
     private ShardingTableRuleConfiguration createTableRuleConfiguration(final 
String logicTableName, final String actualDataNodes) {
-        return new ShardingTableRuleConfiguration(logicTableName, 
actualDataNodes);
+        ShardingTableRuleConfiguration result = new 
ShardingTableRuleConfiguration(logicTableName, actualDataNodes);
+        result.setDatabaseShardingStrategy(new 
StandardShardingStrategyConfiguration("user_id", "database_inline"));
+        result.setTableShardingStrategy(new 
StandardShardingStrategyConfiguration("order_id", "table_inline"));
+        return result;
     }
     
     private Collection<String> createDataSourceNames() {
@@ -407,4 +431,95 @@ public final class ShardingRuleTest {
         result.setTableShardingStrategy(new 
NoneShardingStrategyConfiguration());
         return result;
     }
+    
+    @Test
+    public void assertIsAllBindingTableWithUpdateStatementContext() {
+        SQLStatementContext<?> sqlStatementContext = 
mock(UpdateStatementContext.class);
+        
assertTrue(createMaximumShardingRule().isAllBindingTables(mock(ShardingSphereSchema.class),
 sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+    }
+    
+    @Test
+    public void assertIsAllBindingTableWithoutJoinQuery() {
+        SelectStatementContext sqlStatementContext = 
mock(SelectStatementContext.class);
+        when(sqlStatementContext.isContainsJoinQuery()).thenReturn(false);
+        
assertTrue(createMaximumShardingRule().isAllBindingTables(mock(ShardingSphereSchema.class),
 sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+    }
+    
+    @Test
+    public void assertIsAllBindingTableWithJoinQueryWithoutJoinCondition() {
+        // A join query whose (mocked) statement carries no stubbed join condition is expected not to count as all-binding.
+        MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
+        SelectStatementContext selectContext = mock(SelectStatementContext.class);
+        when(selectContext.isContainsJoinQuery()).thenReturn(true);
+        when(selectContext.getSqlStatement()).thenReturn(selectStatement);
+        ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+        assertFalse(createMaximumShardingRule().isAllBindingTables(schema, selectContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+    }
+    
+    @Test
+    public void assertIsAllBindingTableWithJoinQueryWithDatabaseJoinCondition() {
+        // The join condition equates only user_id across the two tables (order_id is not joined); expected not to count as all-binding.
+        ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+        ColumnSegment logicTableUserId = createColumnSegment("user_id", "logic_Table");
+        ColumnSegment subLogicTableUserId = createColumnSegment("user_id", "sub_Logic_Table");
+        BinaryOperationExpression userIdEquality = createBinaryOperationExpression(logicTableUserId, subLogicTableUserId, EQUAL);
+        JoinTableSegment joinTable = mock(JoinTableSegment.class);
+        when(joinTable.getCondition()).thenReturn(userIdEquality);
+        MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
+        when(selectStatement.getFrom()).thenReturn(joinTable);
+        SelectStatementContext selectContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        when(selectContext.isContainsJoinQuery()).thenReturn(true);
+        when(selectContext.getSqlStatement()).thenReturn(selectStatement);
+        when(selectContext.getTablesContext().findTableName(
+                Arrays.asList(buildColumnProjection(logicTableUserId), buildColumnProjection(subLogicTableUserId)), schema)).thenReturn(createColumnTableNameMap());
+        assertFalse(createMaximumShardingRule().isAllBindingTables(schema, selectContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+    }
+    
+    @Test
+    public void assertIsAllBindingTableWithJoinQueryWithDatabaseTableJoinCondition() {
+        // The join condition is (user_id = user_id) AND (order_id = order_id) across the two tables; expected to count as all-binding.
+        ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+        ColumnSegment logicTableUserId = createColumnSegment("user_id", "logic_Table");
+        ColumnSegment subLogicTableUserId = createColumnSegment("user_id", "sub_Logic_Table");
+        ColumnSegment logicTableOrderId = createColumnSegment("order_id", "logic_Table");
+        ColumnSegment subLogicTableOrderId = createColumnSegment("order_id", "sub_Logic_Table");
+        BinaryOperationExpression userIdEquality = createBinaryOperationExpression(logicTableUserId, subLogicTableUserId, EQUAL);
+        BinaryOperationExpression orderIdEquality = createBinaryOperationExpression(logicTableOrderId, subLogicTableOrderId, EQUAL);
+        BinaryOperationExpression joinCondition = createBinaryOperationExpression(userIdEquality, orderIdEquality, AND);
+        JoinTableSegment joinTable = mock(JoinTableSegment.class);
+        when(joinTable.getCondition()).thenReturn(joinCondition);
+        MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
+        when(selectStatement.getFrom()).thenReturn(joinTable);
+        SelectStatementContext selectContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        when(selectContext.isContainsJoinQuery()).thenReturn(true);
+        when(selectContext.getSqlStatement()).thenReturn(selectStatement);
+        when(selectContext.getTablesContext().findTableName(
+                Arrays.asList(buildColumnProjection(logicTableUserId), buildColumnProjection(subLogicTableUserId)), schema)).thenReturn(createColumnTableNameMap());
+        when(selectContext.getTablesContext().findTableName(
+                Arrays.asList(buildColumnProjection(logicTableOrderId), buildColumnProjection(subLogicTableOrderId)), schema)).thenReturn(createColumnTableNameMap());
+        assertTrue(createMaximumShardingRule().isAllBindingTables(schema, selectContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+    }
+    
+    // Returns a BinaryOperationExpression mock whose getLeft/getRight/getOperator return the given operands and operator.
+    private BinaryOperationExpression createBinaryOperationExpression(final ExpressionSegment left, final ExpressionSegment right, final String operator) {
+        BinaryOperationExpression result = mock(BinaryOperationExpression.class);
+        when(result.getOperator()).thenReturn(operator);
+        when(result.getLeft()).thenReturn(left);
+        when(result.getRight()).thenReturn(right);
+        return result;
+    }
+    
+    // Builds a column segment named columnName and qualified by owner; all start/stop indexes are 0.
+    private ColumnSegment createColumnSegment(final String columnName, final String owner) {
+        IdentifierValue columnIdentifier = new IdentifierValue(columnName);
+        OwnerSegment ownerSegment = new OwnerSegment(0, 0, new IdentifierValue(owner));
+        ColumnSegment result = new ColumnSegment(0, 0, columnIdentifier);
+        result.setOwner(ownerSegment);
+        return result;
+    }
+    
+    // Maps "<table>.<column>" to "<table>" for user_id and order_id of logic_Table and sub_Logic_Table.
+    private Map<String, String> createColumnTableNameMap() {
+        Map<String, String> result = new HashMap<>();
+        for (String each : Arrays.asList("user_id", "order_id")) {
+            result.put("logic_Table." + each, "logic_Table");
+            result.put("sub_Logic_Table." + each, "sub_Logic_Table");
+        }
+        return result;
+    }
+    
+    // Converts a column segment to a ColumnProjection with no alias; the owner is null when the segment has none.
+    private ColumnProjection buildColumnProjection(final ColumnSegment segment) {
+        String owner = segment.getOwner().map(each -> each.getIdentifier().getValue()).orElse(null);
+        String columnName = segment.getIdentifier().getValue();
+        return new ColumnProjection(owner, columnName, null);
+    }
 }
diff --git a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
index f0b0d22..7995f25 100644
--- a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
+++ b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
@@ -504,7 +504,6 @@
 
     <rewrite-assertion id="select_multi_nested_subquery_with_binding_tables" db-types="MySQL,PostgreSQL,openGauss,SQLServer,SQL92">
         <input sql="SELECT * FROM (SELECT id, content FROM t_user WHERE id = ? AND content IN (SELECT content FROM t_user_extend WHERE user_id = ?)) AS temp" parameters="1, 1"/>
-        <output sql="SELECT * FROM (SELECT id, content FROM t_user_0 WHERE id = ? AND content IN (SELECT content FROM t_user_extend_0 WHERE user_id = ?)) AS temp" parameters="1, 1"/>
         <output sql="SELECT * FROM (SELECT id, content FROM t_user_1 WHERE id = ? AND content IN (SELECT content FROM t_user_extend_1 WHERE user_id = ?)) AS temp" parameters="1, 1"/>
     </rewrite-assertion>
     

Reply via email to