This is an automated email from the ASF dual-hosted git repository.
jianglongtao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new f4742d1 optimize binding table route logic (#13892)
f4742d1 is described below
commit f4742d1b396c64ba09f59b6bf5ad774d7c45b242
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Thu Dec 2 14:03:53 2021 +0800
optimize binding table route logic (#13892)
* optimize binding table route logic
* fix rewrite test
* update binding table doc
* refactor unit test
---
.../content/features/sharding/concept/table.cn.md | 5 +-
.../content/features/sharding/concept/table.en.md | 5 +-
.../engine/type/ShardingRouteEngineFactory.java | 89 +---------------
.../shardingsphere/sharding/rule/ShardingRule.java | 106 ++++++++++++++++---
.../sharding/rule/ShardingRuleTest.java | 117 ++++++++++++++++++++-
.../resources/scenario/sharding/case/select.xml | 1 -
6 files changed, 217 insertions(+), 106 deletions(-)
diff --git a/docs/document/content/features/sharding/concept/table.cn.md
b/docs/document/content/features/sharding/concept/table.cn.md
index ae1337a..f0c8639 100644
--- a/docs/document/content/features/sharding/concept/table.cn.md
+++ b/docs/document/content/features/sharding/concept/table.cn.md
@@ -19,7 +19,8 @@ Apache ShardingSphere 通过提供多样化的表类型,适配不同场景下
## 绑定表
指分片规则一致的主表和子表。
-例如:`t_order` 表和 `t_order_item` 表,均按照 `order_id` 分片,则此两张表互为绑定表关系。
+使用绑定表进行多表关联查询时,必须使用分片键进行关联,否则会出现笛卡尔积关联或跨库关联,从而影响查询效率。
+例如:`t_order` 表和 `t_order_item` 表,均按照 `order_id` 分片,并且使用 `order_id` 进行关联,则此两张表互为绑定表关系。
绑定表之间的多表关联查询不会出现笛卡尔积关联,关联查询效率将大大提升。
举例说明,如果 SQL 为:
@@ -39,7 +40,7 @@ SELECT i.* FROM t_order_1 o JOIN t_order_item_0 i ON
o.order_id=i.order_id WHERE
SELECT i.* FROM t_order_1 o JOIN t_order_item_1 i ON o.order_id=i.order_id
WHERE o.order_id in (10, 11);
```
-在配置绑定表关系后,路由的 SQL 应该为 2 条:
+在配置绑定表关系,并且使用 `order_id` 进行关联后,路由的 SQL 应该为 2 条:
```sql
SELECT i.* FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id=i.order_id
WHERE o.order_id in (10, 11);
diff --git a/docs/document/content/features/sharding/concept/table.en.md
b/docs/document/content/features/sharding/concept/table.en.md
index 406cb24..7357d70 100644
--- a/docs/document/content/features/sharding/concept/table.en.md
+++ b/docs/document/content/features/sharding/concept/table.en.md
@@ -18,7 +18,8 @@ The physical table that really exists in the horizontal
sharding database, i.e.,
## Binding Table
It refers to the primary table and the joiner table with the same sharding
rules.
-for example, `t_order` and `t_order_item` are both sharded by `order_id`, so
they are binding tables with each other.
+When binding tables are used in a multi-table correlating query, the sharding key must be used for correlation; otherwise, Cartesian product correlation or cross-database correlation will occur, which reduces query efficiency.
+For example, `t_order` and `t_order_item` are both sharded by `order_id` and are correlated by `order_id`, so they are binding tables with each other.
Cartesian product correlation will not appear in the multi-tables correlating
query, so the query efficiency will increase greatly.
Take this one for example, if SQL is:
@@ -38,7 +39,7 @@ SELECT i.* FROM t_order_1 o JOIN t_order_item_0 i ON
o.order_id=i.order_id WHERE
SELECT i.* FROM t_order_1 o JOIN t_order_item_1 i ON o.order_id=i.order_id
WHERE o.order_id in (10, 11);
```
-With binding table configuration, there should be 2 SQLs after routing:
+With binding tables configured and correlated by `order_id`, there should be 2 SQLs after routing:
```sql
SELECT i.* FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id=i.order_id
WHERE o.order_id in (10, 11);
diff --git
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
index cc9e2a9..1e46cd3 100644
---
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
+++
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/route/engine/type/ShardingRouteEngineFactory.java
@@ -19,7 +19,6 @@ package org.apache.shardingsphere.sharding.route.engine.type;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
-import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.type.TableAvailable;
@@ -27,8 +26,6 @@ import
org.apache.shardingsphere.infra.config.properties.ConfigurationProperties
import
org.apache.shardingsphere.infra.config.properties.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
-import
org.apache.shardingsphere.sharding.api.config.strategy.sharding.ShardingStrategyConfiguration;
-import
org.apache.shardingsphere.sharding.api.config.strategy.sharding.StandardShardingStrategyConfiguration;
import
org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import
org.apache.shardingsphere.sharding.route.engine.condition.ShardingConditions;
import
org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
@@ -42,12 +39,6 @@ import
org.apache.shardingsphere.sharding.route.engine.type.ignore.ShardingIgnor
import
org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine;
import
org.apache.shardingsphere.sharding.route.engine.type.unicast.ShardingUnicastRoutingEngine;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
-import org.apache.shardingsphere.sharding.rule.TableRule;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dal.AnalyzeTableStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dal.DALStatement;
@@ -68,20 +59,13 @@ import
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.DropTablesp
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.tcl.TCLStatement;
-import
org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtil;
-import org.apache.shardingsphere.sql.parser.sql.common.util.WhereExtractUtil;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLCreateResourceGroupStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLOptimizeTableStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLSetResourceGroupStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLShowDatabasesStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dal.MySQLUseStatement;
-import java.util.Arrays;
import java.util.Collection;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.Map;
-import java.util.Optional;
import java.util.stream.Collectors;
/**
@@ -90,8 +74,6 @@ import java.util.stream.Collectors;
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ShardingRouteEngineFactory {
- private static final String EQUAL = "=";
-
/**
* Create new instance of routing engine.
*
@@ -212,7 +194,7 @@ public final class ShardingRouteEngineFactory {
private static ShardingRouteEngine getDQLRouteEngineForShardingTable(final
ShardingRule shardingRule, final ShardingSphereSchema schema, final
SQLStatementContext<?> sqlStatementContext,
final
ShardingConditions shardingConditions, final ConfigurationProperties props,
final Collection<String> tableNames) {
- boolean allBindingTables = tableNames.size() > 1 &&
isAllBindingTables(shardingRule, schema, sqlStatementContext, tableNames);
+ boolean allBindingTables = tableNames.size() > 1 &&
shardingRule.isAllBindingTables(schema, sqlStatementContext, tableNames);
if (isShardingFederatedQuery(shardingRule, sqlStatementContext,
shardingConditions, props, tableNames, allBindingTables)) {
return new ShardingFederatedRoutingEngine(tableNames);
}
@@ -254,73 +236,4 @@ public final class ShardingRouteEngineFactory {
}
return tableNames.size() > 1 && !allBindingTables;
}
-
- private static boolean isAllBindingTables(final ShardingRule shardingRule,
final ShardingSphereSchema schema,
- final SQLStatementContext<?>
sqlStatementContext, final Collection<String> tableNames) {
- if (!(sqlStatementContext instanceof SelectStatementContext)) {
- return shardingRule.isAllBindingTables(tableNames);
- }
- return shardingRule.isAllBindingTables(tableNames) &&
isJoinConditionContainsShardingColumns(shardingRule, schema,
(SelectStatementContext) sqlStatementContext, tableNames);
- }
-
- private static boolean isJoinConditionContainsShardingColumns(final
ShardingRule shardingRule, final ShardingSphereSchema schema,
- final
SelectStatementContext select, final Collection<String> tableNames) {
- Collection<String> databaseJoinConditionTables = new
HashSet<>(tableNames.size());
- Collection<String> tableJoinConditionTables = new
HashSet<>(tableNames.size());
- for (WhereSegment each :
WhereExtractUtil.getJoinWhereSegments(select.getSqlStatement())) {
- Collection<AndPredicate> andPredicates =
ExpressionExtractUtil.getAndPredicates(each.getExpr());
- if (andPredicates.size() > 1) {
- return false;
- }
- for (AndPredicate andPredicate : andPredicates) {
-
databaseJoinConditionTables.addAll(getJoinConditionTables(shardingRule, schema,
select, andPredicate.getPredicates(), true));
-
tableJoinConditionTables.addAll(getJoinConditionTables(shardingRule, schema,
select, andPredicate.getPredicates(), false));
- }
- }
- TableRule tableRule =
shardingRule.getTableRule(tableNames.iterator().next());
- boolean containsDatabaseShardingColumns =
!(tableRule.getDatabaseShardingStrategyConfig() instanceof
StandardShardingStrategyConfiguration)
- || databaseJoinConditionTables.containsAll(tableNames);
- boolean containsTableShardingColumns =
!(tableRule.getTableShardingStrategyConfig() instanceof
StandardShardingStrategyConfiguration) ||
tableJoinConditionTables.containsAll(tableNames);
- return containsDatabaseShardingColumns && containsTableShardingColumns;
- }
-
- private static Collection<String> getJoinConditionTables(final
ShardingRule shardingRule, final ShardingSphereSchema schema, final
SelectStatementContext select,
- final
Collection<ExpressionSegment> predicates, final boolean
isDatabaseJoinCondition) {
- Collection<String> result = new LinkedList<>();
- for (ExpressionSegment expression : predicates) {
- if (!isJoinTableConditionExpression(expression)) {
- continue;
- }
- ColumnProjection leftColumn =
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression)
expression).getLeft());
- ColumnProjection rightColumn =
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression)
expression).getRight());
- Map<String, String> columnTableNames =
select.getTablesContext().findTableName(Arrays.asList(leftColumn, rightColumn),
schema);
- Optional<TableRule> leftTableRule =
shardingRule.findTableRule(columnTableNames.get(leftColumn.getExpression()));
- Optional<TableRule> rightTableRule =
shardingRule.findTableRule(columnTableNames.get(rightColumn.getExpression()));
- if (!leftTableRule.isPresent() || !rightTableRule.isPresent()) {
- continue;
- }
- ShardingStrategyConfiguration leftConfiguration =
isDatabaseJoinCondition
- ?
shardingRule.getDatabaseShardingStrategyConfiguration(leftTableRule.get()) :
shardingRule.getTableShardingStrategyConfiguration(leftTableRule.get());
- ShardingStrategyConfiguration rightConfiguration =
isDatabaseJoinCondition
- ?
shardingRule.getDatabaseShardingStrategyConfiguration(rightTableRule.get()) :
shardingRule.getTableShardingStrategyConfiguration(rightTableRule.get());
- if (shardingRule.isShardingColumn(leftConfiguration,
leftColumn.getName()) && shardingRule.isShardingColumn(rightConfiguration,
rightColumn.getName())) {
- result.add(columnTableNames.get(leftColumn.getExpression()));
- result.add(columnTableNames.get(rightColumn.getExpression()));
- }
- }
- return result;
- }
-
- private static ColumnProjection buildColumnProjection(final ColumnSegment
segment) {
- String owner = segment.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
- return new ColumnProjection(owner, segment.getIdentifier().getValue(),
null);
- }
-
- private static boolean isJoinTableConditionExpression(final
ExpressionSegment expression) {
- if (!(expression instanceof BinaryOperationExpression)) {
- return false;
- }
- BinaryOperationExpression binaryExpression =
(BinaryOperationExpression) expression;
- return binaryExpression.getLeft() instanceof ColumnSegment &&
binaryExpression.getRight() instanceof ColumnSegment &&
EQUAL.equals(binaryExpression.getOperator());
- }
}
diff --git
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
index ff4b9b9..700b065 100644
---
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
+++
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/main/java/org/apache/shardingsphere/sharding/rule/ShardingRule.java
@@ -20,9 +20,13 @@ package org.apache.shardingsphere.sharding.rule;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import lombok.Getter;
+import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import
org.apache.shardingsphere.infra.config.algorithm.ShardingSphereAlgorithmFactory;
import
org.apache.shardingsphere.infra.config.exception.ShardingSphereConfigurationException;
import org.apache.shardingsphere.infra.datanode.DataNode;
+import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rule.identifier.scope.SchemaRule;
import
org.apache.shardingsphere.infra.rule.identifier.type.DataNodeContainedRule;
import org.apache.shardingsphere.infra.rule.identifier.type.TableContainedRule;
@@ -41,7 +45,15 @@ import
org.apache.shardingsphere.sharding.spi.ShardingAlgorithm;
import org.apache.shardingsphere.sharding.support.InlineExpressionParser;
import org.apache.shardingsphere.spi.ShardingSphereServiceLoader;
import org.apache.shardingsphere.spi.required.RequiredSPIRegistry;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtil;
+import org.apache.shardingsphere.sql.parser.sql.common.util.WhereExtractUtil;
+import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
@@ -61,6 +73,8 @@ import java.util.stream.Collectors;
@Getter
public final class ShardingRule implements SchemaRule, DataNodeContainedRule,
TableContainedRule {
+ private static final String EQUAL = "=";
+
static {
ShardingSphereServiceLoader.register(ShardingAlgorithm.class);
ShardingSphereServiceLoader.register(KeyGenerateAlgorithm.class);
@@ -244,10 +258,10 @@ public final class ShardingRule implements SchemaRule,
DataNodeContainedRule, Ta
}
/**
- * Judge whether logic table is all binding encryptors or not.
+ * Judge whether logic table is all binding tables or not.
*
* @param logicTableNames logic table names
- * @return whether logic table is all binding encryptors or not
+ * @return whether logic table is all binding tables or not
*/
public boolean isAllBindingTables(final Collection<String>
logicTableNames) {
if (logicTableNames.isEmpty()) {
@@ -262,6 +276,21 @@ public final class ShardingRule implements SchemaRule,
DataNodeContainedRule, Ta
return !result.isEmpty() && result.containsAll(logicTableNames);
}
+ /**
+ * Judge whether logic table is all binding tables.
+ *
+ * @param schema schema
+ * @param sqlStatementContext sqlStatementContext
+ * @param logicTableNames logic table names
+ * @return whether logic table is all binding tables.
+ */
+ public boolean isAllBindingTables(final ShardingSphereSchema schema, final
SQLStatementContext<?> sqlStatementContext, final Collection<String>
logicTableNames) {
+ if (!(sqlStatementContext instanceof SelectStatementContext &&
((SelectStatementContext) sqlStatementContext).isContainsJoinQuery())) {
+ return isAllBindingTables(logicTableNames);
+ }
+ return isAllBindingTables(logicTableNames) &&
isJoinConditionContainsShardingColumns(schema, (SelectStatementContext)
sqlStatementContext, logicTableNames);
+ }
+
private Optional<BindingTableRule> findBindingTableRule(final
Collection<String> logicTableNames) {
return
logicTableNames.stream().map(this::findBindingTableRule).filter(Optional::isPresent).findFirst().orElse(Optional.empty());
}
@@ -282,10 +311,10 @@ public final class ShardingRule implements SchemaRule,
DataNodeContainedRule, Ta
}
/**
- * Judge whether logic table is all broadcast encryptors or not.
+ * Judge whether logic table is all broadcast tables or not.
*
* @param logicTableNames logic table names
- * @return whether logic table is all broadcast encryptors or not
+ * @return whether logic table is all broadcast tables or not
*/
public boolean isAllBroadcastTables(final Collection<String>
logicTableNames) {
return !logicTableNames.isEmpty() &&
broadcastTables.containsAll(logicTableNames);
@@ -366,14 +395,7 @@ public final class ShardingRule implements SchemaRule,
DataNodeContainedRule, Ta
return
isShardingColumn(getDatabaseShardingStrategyConfiguration(tableRule),
columnName) ||
isShardingColumn(getTableShardingStrategyConfiguration(tableRule), columnName);
}
- /**
- * Judge whether given logic table column is sharding column or not.
- *
- * @param shardingStrategyConfig sharding strategy config
- * @param columnName column name
- * @return whether given logic table column is sharding column or not
- */
- public boolean isShardingColumn(final ShardingStrategyConfiguration
shardingStrategyConfig, final String columnName) {
+ private boolean isShardingColumn(final ShardingStrategyConfiguration
shardingStrategyConfig, final String columnName) {
if (shardingStrategyConfig instanceof
StandardShardingStrategyConfiguration) {
String shardingColumn = null ==
((StandardShardingStrategyConfiguration)
shardingStrategyConfig).getShardingColumn()
? defaultShardingColumn :
((StandardShardingStrategyConfiguration)
shardingStrategyConfig).getShardingColumn();
@@ -531,4 +553,64 @@ public final class ShardingRule implements SchemaRule,
DataNodeContainedRule, Ta
public String getType() {
return ShardingRule.class.getSimpleName();
}
+
+ private boolean isJoinConditionContainsShardingColumns(final
ShardingSphereSchema schema, final SelectStatementContext select, final
Collection<String> tableNames) {
+ Collection<String> databaseJoinConditionTables = new
HashSet<>(tableNames.size());
+ Collection<String> tableJoinConditionTables = new
HashSet<>(tableNames.size());
+ for (WhereSegment each :
WhereExtractUtil.getJoinWhereSegments(select.getSqlStatement())) {
+ Collection<AndPredicate> andPredicates =
ExpressionExtractUtil.getAndPredicates(each.getExpr());
+ if (andPredicates.size() > 1) {
+ return false;
+ }
+ for (AndPredicate andPredicate : andPredicates) {
+
databaseJoinConditionTables.addAll(getJoinConditionTables(schema, select,
andPredicate.getPredicates(), true));
+ tableJoinConditionTables.addAll(getJoinConditionTables(schema,
select, andPredicate.getPredicates(), false));
+ }
+ }
+ TableRule tableRule = getTableRule(tableNames.iterator().next());
+ boolean containsDatabaseShardingColumns =
!(tableRule.getDatabaseShardingStrategyConfig() instanceof
StandardShardingStrategyConfiguration)
+ || databaseJoinConditionTables.containsAll(tableNames);
+ boolean containsTableShardingColumns =
!(tableRule.getTableShardingStrategyConfig() instanceof
StandardShardingStrategyConfiguration) ||
tableJoinConditionTables.containsAll(tableNames);
+ return containsDatabaseShardingColumns && containsTableShardingColumns;
+ }
+
+ private Collection<String> getJoinConditionTables(final
ShardingSphereSchema schema, final SelectStatementContext select,
+ final
Collection<ExpressionSegment> predicates, final boolean
isDatabaseJoinCondition) {
+ Collection<String> result = new LinkedList<>();
+ for (ExpressionSegment each : predicates) {
+ if (!isJoinConditionExpression(each)) {
+ continue;
+ }
+ ColumnProjection leftColumn =
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression)
each).getLeft());
+ ColumnProjection rightColumn =
buildColumnProjection((ColumnSegment) ((BinaryOperationExpression)
each).getRight());
+ Map<String, String> columnTableNames =
select.getTablesContext().findTableName(Arrays.asList(leftColumn, rightColumn),
schema);
+ Optional<TableRule> leftTableRule =
findTableRule(columnTableNames.get(leftColumn.getExpression()));
+ Optional<TableRule> rightTableRule =
findTableRule(columnTableNames.get(rightColumn.getExpression()));
+ if (!leftTableRule.isPresent() || !rightTableRule.isPresent()) {
+ continue;
+ }
+ ShardingStrategyConfiguration leftConfiguration =
isDatabaseJoinCondition
+ ?
getDatabaseShardingStrategyConfiguration(leftTableRule.get()) :
getTableShardingStrategyConfiguration(leftTableRule.get());
+ ShardingStrategyConfiguration rightConfiguration =
isDatabaseJoinCondition
+ ?
getDatabaseShardingStrategyConfiguration(rightTableRule.get()) :
getTableShardingStrategyConfiguration(rightTableRule.get());
+ if (isShardingColumn(leftConfiguration, leftColumn.getName()) &&
isShardingColumn(rightConfiguration, rightColumn.getName())) {
+ result.add(columnTableNames.get(leftColumn.getExpression()));
+ result.add(columnTableNames.get(rightColumn.getExpression()));
+ }
+ }
+ return result;
+ }
+
+ private ColumnProjection buildColumnProjection(final ColumnSegment
segment) {
+ String owner = segment.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
+ return new ColumnProjection(owner, segment.getIdentifier().getValue(),
null);
+ }
+
+ private boolean isJoinConditionExpression(final ExpressionSegment
expression) {
+ if (!(expression instanceof BinaryOperationExpression)) {
+ return false;
+ }
+ BinaryOperationExpression binaryExpression =
(BinaryOperationExpression) expression;
+ return binaryExpression.getLeft() instanceof ColumnSegment &&
binaryExpression.getRight() instanceof ColumnSegment &&
EQUAL.equals(binaryExpression.getOperator());
+ }
}
diff --git
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
index 1068bb6..7d12bcb 100644
---
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
+++
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java
@@ -17,9 +17,14 @@
package org.apache.shardingsphere.sharding.rule;
+import
org.apache.shardingsphere.infra.binder.segment.select.projection.impl.ColumnProjection;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
+import
org.apache.shardingsphere.infra.binder.statement.dml.UpdateStatementContext;
import
org.apache.shardingsphere.infra.config.algorithm.ShardingSphereAlgorithmConfiguration;
import
org.apache.shardingsphere.infra.config.exception.ShardingSphereConfigurationException;
import org.apache.shardingsphere.infra.datanode.DataNode;
+import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import
org.apache.shardingsphere.sharding.algorithm.keygen.SnowflakeKeyGenerateAlgorithm;
import
org.apache.shardingsphere.sharding.algorithm.keygen.fixture.IncrementKeyGenerateAlgorithm;
import
org.apache.shardingsphere.sharding.algorithm.sharding.inline.InlineShardingAlgorithm;
@@ -30,12 +35,21 @@ import
org.apache.shardingsphere.sharding.api.config.strategy.keygen.KeyGenerate
import
org.apache.shardingsphere.sharding.api.config.strategy.sharding.ComplexShardingStrategyConfiguration;
import
org.apache.shardingsphere.sharding.api.config.strategy.sharding.NoneShardingStrategyConfiguration;
import
org.apache.shardingsphere.sharding.api.config.strategy.sharding.StandardShardingStrategyConfiguration;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.LinkedHashSet;
+import java.util.Map;
import java.util.Properties;
import java.util.TreeSet;
@@ -45,9 +59,16 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
public final class ShardingRuleTest {
+ private static final String EQUAL = "=";
+
+ private static final String AND = "AND";
+
@Test
public void assertNewShardingRuleWithMaximumConfiguration() {
ShardingRule actual = createMaximumShardingRule();
@@ -368,7 +389,10 @@ public final class ShardingRuleTest {
}
private ShardingTableRuleConfiguration createTableRuleConfiguration(final
String logicTableName, final String actualDataNodes) {
- return new ShardingTableRuleConfiguration(logicTableName,
actualDataNodes);
+ ShardingTableRuleConfiguration result = new
ShardingTableRuleConfiguration(logicTableName, actualDataNodes);
+ result.setDatabaseShardingStrategy(new
StandardShardingStrategyConfiguration("user_id", "database_inline"));
+ result.setTableShardingStrategy(new
StandardShardingStrategyConfiguration("order_id", "table_inline"));
+ return result;
}
private Collection<String> createDataSourceNames() {
@@ -407,4 +431,95 @@ public final class ShardingRuleTest {
result.setTableShardingStrategy(new
NoneShardingStrategyConfiguration());
return result;
}
+
+ @Test
+ public void assertIsAllBindingTableWithUpdateStatementContext() {
+ SQLStatementContext<?> sqlStatementContext =
mock(UpdateStatementContext.class);
+
assertTrue(createMaximumShardingRule().isAllBindingTables(mock(ShardingSphereSchema.class),
sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+ }
+
+ @Test
+ public void assertIsAllBindingTableWithoutJoinQuery() {
+ SelectStatementContext sqlStatementContext =
mock(SelectStatementContext.class);
+ when(sqlStatementContext.isContainsJoinQuery()).thenReturn(false);
+
assertTrue(createMaximumShardingRule().isAllBindingTables(mock(ShardingSphereSchema.class),
sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+ }
+
+ @Test
+ public void assertIsAllBindingTableWithJoinQueryWithoutJoinCondition() {
+ SelectStatementContext sqlStatementContext =
mock(SelectStatementContext.class);
+ when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true);
+
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(MySQLSelectStatement.class));
+
assertFalse(createMaximumShardingRule().isAllBindingTables(mock(ShardingSphereSchema.class),
sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+ }
+
+ @Test
+ public void
assertIsAllBindingTableWithJoinQueryWithDatabaseJoinCondition() {
+ ColumnSegment leftDatabaseJoin = createColumnSegment("user_id",
"logic_Table");
+ ColumnSegment rightDatabaseJoin = createColumnSegment("user_id",
"sub_Logic_Table");
+ BinaryOperationExpression condition =
createBinaryOperationExpression(leftDatabaseJoin, rightDatabaseJoin, EQUAL);
+ JoinTableSegment joinTable = mock(JoinTableSegment.class);
+ when(joinTable.getCondition()).thenReturn(condition);
+ MySQLSelectStatement selectStatement =
mock(MySQLSelectStatement.class);
+ when(selectStatement.getFrom()).thenReturn(joinTable);
+ SelectStatementContext sqlStatementContext =
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
+ when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true);
+ ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+
when(sqlStatementContext.getTablesContext().findTableName(Arrays.asList(buildColumnProjection(leftDatabaseJoin),
+ buildColumnProjection(rightDatabaseJoin)),
schema)).thenReturn(createColumnTableNameMap());
+ assertFalse(createMaximumShardingRule().isAllBindingTables(schema,
sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+ }
+
+ @Test
+ public void
assertIsAllBindingTableWithJoinQueryWithDatabaseTableJoinCondition() {
+ ColumnSegment leftDatabaseJoin = createColumnSegment("user_id",
"logic_Table");
+ ColumnSegment rightDatabaseJoin = createColumnSegment("user_id",
"sub_Logic_Table");
+ BinaryOperationExpression databaseJoin =
createBinaryOperationExpression(leftDatabaseJoin, rightDatabaseJoin, EQUAL);
+ ColumnSegment leftTableJoin = createColumnSegment("order_id",
"logic_Table");
+ ColumnSegment rightTableJoin = createColumnSegment("order_id",
"sub_Logic_Table");
+ BinaryOperationExpression tableJoin =
createBinaryOperationExpression(leftTableJoin, rightTableJoin, EQUAL);
+ JoinTableSegment joinTable = mock(JoinTableSegment.class);
+ BinaryOperationExpression condition =
createBinaryOperationExpression(databaseJoin, tableJoin, AND);
+ when(joinTable.getCondition()).thenReturn(condition);
+ MySQLSelectStatement selectStatement =
mock(MySQLSelectStatement.class);
+ when(selectStatement.getFrom()).thenReturn(joinTable);
+ SelectStatementContext sqlStatementContext =
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
+ when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true);
+ ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+
when(sqlStatementContext.getTablesContext().findTableName(Arrays.asList(buildColumnProjection(leftDatabaseJoin),
+ buildColumnProjection(rightDatabaseJoin)),
schema)).thenReturn(createColumnTableNameMap());
+
when(sqlStatementContext.getTablesContext().findTableName(Arrays.asList(buildColumnProjection(leftTableJoin),
+ buildColumnProjection(rightTableJoin)),
schema)).thenReturn(createColumnTableNameMap());
+ assertTrue(createMaximumShardingRule().isAllBindingTables(schema,
sqlStatementContext, Arrays.asList("logic_Table", "sub_Logic_Table")));
+ }
+
+ private BinaryOperationExpression createBinaryOperationExpression(final
ExpressionSegment left, final ExpressionSegment right, final String operator) {
+ BinaryOperationExpression result =
mock(BinaryOperationExpression.class);
+ when(result.getLeft()).thenReturn(left);
+ when(result.getRight()).thenReturn(right);
+ when(result.getOperator()).thenReturn(operator);
+ return result;
+ }
+
+ private ColumnSegment createColumnSegment(final String columnName, final
String owner) {
+ ColumnSegment result = new ColumnSegment(0, 0, new
IdentifierValue(columnName));
+ result.setOwner(new OwnerSegment(0, 0, new IdentifierValue(owner)));
+ return result;
+ }
+
+ private Map<String, String> createColumnTableNameMap() {
+ Map<String, String> result = new HashMap<>();
+ result.put("logic_Table.user_id", "logic_Table");
+ result.put("sub_Logic_Table.user_id", "sub_Logic_Table");
+ result.put("logic_Table.order_id", "logic_Table");
+ result.put("sub_Logic_Table.order_id", "sub_Logic_Table");
+ return result;
+ }
+
+ private ColumnProjection buildColumnProjection(final ColumnSegment
segment) {
+ String owner = segment.getOwner().map(optional ->
optional.getIdentifier().getValue()).orElse(null);
+ return new ColumnProjection(owner, segment.getIdentifier().getValue(),
null);
+ }
}
diff --git
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
index f0b0d22..7995f25 100644
---
a/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
+++
b/shardingsphere-test/shardingsphere-rewrite-test/src/test/resources/scenario/sharding/case/select.xml
@@ -504,7 +504,6 @@
<rewrite-assertion id="select_multi_nested_subquery_with_binding_tables"
db-types="MySQL,PostgreSQL,openGauss,SQLServer,SQL92">
<input sql="SELECT * FROM (SELECT id, content FROM t_user WHERE id = ?
AND content IN (SELECT content FROM t_user_extend WHERE user_id = ?)) AS temp"
parameters="1, 1"/>
- <output sql="SELECT * FROM (SELECT id, content FROM t_user_0 WHERE id
= ? AND content IN (SELECT content FROM t_user_extend_0 WHERE user_id = ?)) AS
temp" parameters="1, 1"/>
<output sql="SELECT * FROM (SELECT id, content FROM t_user_1 WHERE id
= ? AND content IN (SELECT content FROM t_user_extend_1 WHERE user_id = ?)) AS
temp" parameters="1, 1"/>
</rewrite-assertion>