This is an automated email from the ASF dual-hosted git repository.

tuichenchuxin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new eb64e0bd176 Fix wrong rewrite result when some logical table names of
binding tables match their actual table names and others do not (#22336)
eb64e0bd176 is described below

commit eb64e0bd176d69920b0ee95e11e473a63825efd0
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Tue Nov 22 15:43:06 2022 +0800

    Fix wrong rewrite result when some logical table names of binding tables
    match their actual table names and others do not (#22336)
    
    * Fix wrong rewrite result when some logical table names of binding tables
      match their actual table names and others do not
    
    * fix unit test
---
 .../token/generator/impl/TableTokenGenerator.java  |  7 ++-
 .../ShardingSQLRewriteContextDecoratorTest.java    |  3 ++
 .../rewrite/token/TableTokenGeneratorTest.java     | 61 ++++++++++++++++++----
 3 files changed, 59 insertions(+), 12 deletions(-)

diff --git 
a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
 
b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
index 234ad4194c9..99cd969b696 100644
--- 
a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
+++ 
b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/TableTokenGenerator.java
@@ -45,7 +45,12 @@ public final class TableTokenGenerator implements 
CollectionSQLTokenGenerator<SQ
     
     @Override
     public boolean isGenerateSQLToken(final SQLStatementContext<?> 
sqlStatementContext) {
-        return routeContext.containsTableSharding();
+        return isAllBindingTables(sqlStatementContext) || 
routeContext.containsTableSharding();
+    }
+    
+    private boolean isAllBindingTables(final SQLStatementContext<?> 
sqlStatementContext) {
+        Collection<String> shardingLogicTableNames = 
shardingRule.getShardingLogicTableNames(sqlStatementContext.getTablesContext().getTableNames());
+        return shardingLogicTableNames.size() > 1 && 
shardingRule.isAllBindingTables(shardingLogicTableNames);
     }
     
     @Override
diff --git 
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
 
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
index 509f7a33eb7..d9db3ca74c0 100644
--- 
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
+++ 
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.sharding.rewrite.context;
 
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
 import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
 import org.apache.shardingsphere.infra.route.context.RouteContext;
@@ -26,6 +27,7 @@ import org.junit.Test;
 import java.util.Collections;
 
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -35,6 +37,7 @@ public final class ShardingSQLRewriteContextDecoratorTest {
     public void assertDecorate() {
         SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class);
         
when(sqlRewriteContext.getParameters()).thenReturn(Collections.singletonList(new
 Object()));
+        
when(sqlRewriteContext.getSqlStatementContext()).thenReturn(mock(SQLStatementContext.class,
 RETURNS_DEEP_STUBS));
         new 
ShardingSQLRewriteContextDecorator().decorate(mock(ShardingRule.class), 
mock(ConfigurationProperties.class), sqlRewriteContext, 
mock(RouteContext.class));
         assertTrue(sqlRewriteContext.getSqlTokens().isEmpty());
     }
diff --git 
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
 
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
index ee71b9b95f5..44573b7b354 100644
--- 
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
+++ 
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/TableTokenGeneratorTest.java
@@ -17,40 +17,79 @@
 
 package org.apache.shardingsphere.sharding.rewrite.token;
 
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import 
org.apache.shardingsphere.infra.binder.statement.ddl.CreateDatabaseStatementContext;
 import 
org.apache.shardingsphere.infra.binder.statement.ddl.CreateTableStatementContext;
+import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
+import org.apache.shardingsphere.infra.route.context.RouteContext;
 import 
org.apache.shardingsphere.sharding.rewrite.token.generator.impl.TableTokenGenerator;
+import org.apache.shardingsphere.sharding.rewrite.token.pojo.TableToken;
 import org.apache.shardingsphere.sharding.rule.ShardingRule;
 import org.apache.shardingsphere.sharding.rule.TableRule;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.CreateDatabaseStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
 import org.junit.Test;
 
-import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Optional;
 
+import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 public final class TableTokenGeneratorTest {
     
     @Test
-    public void assertGenerateSQLToken() {
+    public void assertIsGenerateSQLTokenWhenConfigAllBindingTables() {
+        TableTokenGenerator generator = new TableTokenGenerator();
+        ShardingRule shardingRule = mock(ShardingRule.class);
+        Collection<String> logicTableNames = Arrays.asList("t_order", 
"t_order_item");
+        
when(shardingRule.getShardingLogicTableNames(logicTableNames)).thenReturn(logicTableNames);
+        
when(shardingRule.isAllBindingTables(logicTableNames)).thenReturn(true);
+        generator.setShardingRule(shardingRule);
+        SQLStatementContext<SelectStatement> sqlStatementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(logicTableNames);
+        assertTrue(generator.isGenerateSQLToken(sqlStatementContext));
+    }
+    
+    @Test
+    public void assertIsGenerateSQLTokenWhenContainsTableSharding() {
+        TableTokenGenerator generator = new TableTokenGenerator();
+        RouteContext routeContext = mock(RouteContext.class);
+        when(routeContext.containsTableSharding()).thenReturn(true);
+        generator.setShardingRule(mock(ShardingRule.class));
+        generator.setRouteContext(routeContext);
+        SQLStatementContext<SelectStatement> sqlStatementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        assertTrue(generator.isGenerateSQLToken(sqlStatementContext));
+    }
+    
+    @Test
+    public void assertGenerateSQLTokenWhenSQLStatementIsTableAvailable() {
         ShardingRule shardingRule = mock(ShardingRule.class);
         
when(shardingRule.findTableRule(anyString())).thenReturn(Optional.of(mock(TableRule.class)));
-        TableTokenGenerator tableTokenGenerator = new TableTokenGenerator();
-        tableTokenGenerator.setShardingRule(shardingRule);
-        CreateDatabaseStatementContext createDatabaseStatementContext = 
mock(CreateDatabaseStatementContext.class);
-        
assertThat(tableTokenGenerator.generateSQLTokens(createDatabaseStatementContext),
 is(Collections.emptyList()));
-        int testStartIndex = 3;
-        TableNameSegment tableNameSegment = new 
TableNameSegment(testStartIndex, 8, new IdentifierValue("test"));
-        CreateTableStatementContext createTableStatementContext = 
mock(CreateTableStatementContext.class);
-        
when(createTableStatementContext.getAllTables()).thenReturn(Collections.singleton(new
 SimpleTableSegment(tableNameSegment)));
-        assertThat((new 
ArrayList<>(tableTokenGenerator.generateSQLTokens(createTableStatementContext))).get(0).getStartIndex(),
 is(testStartIndex));
+        TableTokenGenerator generator = new TableTokenGenerator();
+        generator.setShardingRule(shardingRule);
+        CreateTableStatementContext sqlStatementContext = 
mock(CreateTableStatementContext.class);
+        
when(sqlStatementContext.getAllTables()).thenReturn(Collections.singletonList(new
 SimpleTableSegment(new TableNameSegment(0, 0, new 
IdentifierValue("t_order")))));
+        Collection<TableToken> actual = 
generator.generateSQLTokens(sqlStatementContext);
+        assertThat(actual.size(), is(1));
+        assertThat(actual.iterator().next(), instanceOf(TableToken.class));
+    }
+    
+    @Test
+    public void assertGenerateSQLTokenWhenSQLStatementIsNotTableAvailable() {
+        TableTokenGenerator generator = new TableTokenGenerator();
+        SQLStatementContext<CreateDatabaseStatement> sqlStatementContext = 
mock(CreateDatabaseStatementContext.class);
+        assertThat(generator.generateSQLTokens(sqlStatementContext), 
is(Collections.emptyList()));
     }
 }

Reply via email to