This is an automated email from the ASF dual-hosted git repository.
tuichenchuxin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new eb64e0bd176 Fix wrong rewrite result for binding tables when some logical table names match their actual table names and others do not (#22336)
eb64e0bd176 is described below
commit eb64e0bd176d69920b0ee95e11e473a63825efd0
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Tue Nov 22 15:43:06 2022 +0800
Fix wrong rewrite result for binding tables when some logical table names match their actual table names and others do not (#22336)

* Fix wrong rewrite result for binding tables when some logical table names match their actual table names and others do not
* Fix unit test
---
.../token/generator/impl/TableTokenGenerator.java | 7 ++-
.../ShardingSQLRewriteContextDecoratorTest.java | 3 ++
.../rewrite/token/TableTokenGeneratorTest.java | 61 ++++++++++++++++++----
3 files changed, 59 insertions(+), 12 deletions(-)
diff --git a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
index 234ad4194c9..99cd969b696 100644
--- a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
+++ b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
@@ -45,7 +45,12 @@ public final class TableTokenGenerator implements CollectionSQLTokenGenerator<SQ
@Override
public boolean isGenerateSQLToken(final SQLStatementContext<?>
sqlStatementContext) {
- return routeContext.containsTableSharding();
+ return isAllBindingTables(sqlStatementContext) ||
routeContext.containsTableSharding();
+ }
+
+ private boolean isAllBindingTables(final SQLStatementContext<?>
sqlStatementContext) {
+ Collection<String> shardingLogicTableNames =
shardingRule.getShardingLogicTableNames(sqlStatementContext.getTablesContext().getTableNames());
+ return shardingLogicTableNames.size() > 1 &&
shardingRule.isAllBindingTables(shardingLogicTableNames);
}
@Override
diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
index 509f7a33eb7..d9db3ca74c0 100644
--- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
+++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.sharding.rewrite.context;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.route.context.RouteContext;
@@ -26,6 +27,7 @@ import org.junit.Test;
import java.util.Collections;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -35,6 +37,7 @@ public final class ShardingSQLRewriteContextDecoratorTest {
public void assertDecorate() {
SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class);
when(sqlRewriteContext.getParameters()).thenReturn(Collections.singletonList(new
Object()));
+
when(sqlRewriteContext.getSqlStatementContext()).thenReturn(mock(SQLStatementContext.class,
RETURNS_DEEP_STUBS));
new
ShardingSQLRewriteContextDecorator().decorate(mock(ShardingRule.class),
mock(ConfigurationProperties.class), sqlRewriteContext,
mock(RouteContext.class));
assertTrue(sqlRewriteContext.getSqlTokens().isEmpty());
}
diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
index ee71b9b95f5..44573b7b354 100644
--- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
+++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
@@ -17,40 +17,79 @@
package org.apache.shardingsphere.sharding.rewrite.token;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.ddl.CreateDatabaseStatementContext;
import
org.apache.shardingsphere.infra.binder.statement.ddl.CreateTableStatementContext;
+import
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
+import org.apache.shardingsphere.infra.route.context.RouteContext;
import
org.apache.shardingsphere.sharding.rewrite.token.generator.impl.TableTokenGenerator;
+import org.apache.shardingsphere.sharding.rewrite.token.pojo.TableToken;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.rule.TableRule;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.CreateDatabaseStatement;
+import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import org.junit.Test;
-import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
import java.util.Optional;
+import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public final class TableTokenGeneratorTest {
@Test
- public void assertGenerateSQLToken() {
+ public void assertIsGenerateSQLTokenWhenConfigAllBindingTables() {
+ TableTokenGenerator generator = new TableTokenGenerator();
+ ShardingRule shardingRule = mock(ShardingRule.class);
+ Collection<String> logicTableNames = Arrays.asList("t_order",
"t_order_item");
+
when(shardingRule.getShardingLogicTableNames(logicTableNames)).thenReturn(logicTableNames);
+
when(shardingRule.isAllBindingTables(logicTableNames)).thenReturn(true);
+ generator.setShardingRule(shardingRule);
+ SQLStatementContext<SelectStatement> sqlStatementContext =
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(logicTableNames);
+ assertTrue(generator.isGenerateSQLToken(sqlStatementContext));
+ }
+
+ @Test
+ public void assertIsGenerateSQLTokenWhenContainsTableSharding() {
+ TableTokenGenerator generator = new TableTokenGenerator();
+ RouteContext routeContext = mock(RouteContext.class);
+ when(routeContext.containsTableSharding()).thenReturn(true);
+ generator.setShardingRule(mock(ShardingRule.class));
+ generator.setRouteContext(routeContext);
+ SQLStatementContext<SelectStatement> sqlStatementContext =
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+ assertTrue(generator.isGenerateSQLToken(sqlStatementContext));
+ }
+
+ @Test
+ public void assertGenerateSQLTokenWhenSQLStatementIsTableAvailable() {
ShardingRule shardingRule = mock(ShardingRule.class);
when(shardingRule.findTableRule(anyString())).thenReturn(Optional.of(mock(TableRule.class)));
- TableTokenGenerator tableTokenGenerator = new TableTokenGenerator();
- tableTokenGenerator.setShardingRule(shardingRule);
- CreateDatabaseStatementContext createDatabaseStatementContext =
mock(CreateDatabaseStatementContext.class);
-
assertThat(tableTokenGenerator.generateSQLTokens(createDatabaseStatementContext),
is(Collections.emptyList()));
- int testStartIndex = 3;
- TableNameSegment tableNameSegment = new
TableNameSegment(testStartIndex, 8, new IdentifierValue("test"));
- CreateTableStatementContext createTableStatementContext =
mock(CreateTableStatementContext.class);
-
when(createTableStatementContext.getAllTables()).thenReturn(Collections.singleton(new
SimpleTableSegment(tableNameSegment)));
- assertThat((new
ArrayList<>(tableTokenGenerator.generateSQLTokens(createTableStatementContext))).get(0).getStartIndex(),
is(testStartIndex));
+ TableTokenGenerator generator = new TableTokenGenerator();
+ generator.setShardingRule(shardingRule);
+ CreateTableStatementContext sqlStatementContext =
mock(CreateTableStatementContext.class);
+
when(sqlStatementContext.getAllTables()).thenReturn(Collections.singletonList(new
SimpleTableSegment(new TableNameSegment(0, 0, new
IdentifierValue("t_order")))));
+ Collection<TableToken> actual =
generator.generateSQLTokens(sqlStatementContext);
+ assertThat(actual.size(), is(1));
+ assertThat(actual.iterator().next(), instanceOf(TableToken.class));
+ }
+
+ @Test
+ public void assertGenerateSQLTokenWhenSQLStatementIsNotTableAvailable() {
+ TableTokenGenerator generator = new TableTokenGenerator();
+ SQLStatementContext<CreateDatabaseStatement> sqlStatementContext =
mock(CreateDatabaseStatementContext.class);
+ assertThat(generator.generateSQLTokens(sqlStatementContext),
is(Collections.emptyList()));
}
}