This is an automated email from the ASF dual-hosted git repository.

chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ba71b1ff6f Refactor SQLRewriteEntry to remove too many parameters (#33462)
2ba71b1ff6f is described below

commit 2ba71b1ff6fc20ad104d2224e1ccc2856afdd575
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Wed Oct 30 11:10:52 2024 +0800

    Refactor SQLRewriteEntry to remove too many parameters (#33462)
---
 .../infra/rewrite/SQLRewriteEntry.java             |  6 +--
 .../infra/rewrite/context/SQLRewriteContext.java   | 17 ++++-----
 .../rewrite/context/SQLRewriteContextTest.java     | 34 ++++++++++++-----
 .../engine/GenericSQLRewriteEngineTest.java        | 23 ++++++-----
 .../rewrite/engine/RouteSQLRewriteEngineTest.java  | 44 +++++++++++-----------
 5 files changed, 70 insertions(+), 54 deletions(-)

diff --git 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
index a445b985165..8d1ee716729 100644
--- 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
+++ 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
@@ -69,16 +69,16 @@ public final class SQLRewriteEntry {
      * @return route unit and SQL rewrite result map
      */
     public SQLRewriteResult rewrite(final QueryContext queryContext, final 
RouteContext routeContext, final ConnectionContext connectionContext) {
-        SQLRewriteContext sqlRewriteContext = 
createSQLRewriteContext(queryContext, routeContext, connectionContext);
+        SQLRewriteContext sqlRewriteContext = 
createSQLRewriteContext(queryContext, routeContext);
         SQLTranslatorRule rule = 
globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
         return routeContext.getRouteUnits().isEmpty()
                 ? new GenericSQLRewriteEngine(rule, database, 
globalRuleMetaData).rewrite(sqlRewriteContext, queryContext)
                 : new RouteSQLRewriteEngine(rule, database, 
globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext);
     }
     
-    private SQLRewriteContext createSQLRewriteContext(final QueryContext 
queryContext, final RouteContext routeContext, final ConnectionContext 
connectionContext) {
+    private SQLRewriteContext createSQLRewriteContext(final QueryContext 
queryContext, final RouteContext routeContext) {
         HintValueContext hintValueContext = queryContext.getHintValueContext();
-        SQLRewriteContext result = new SQLRewriteContext(database, 
queryContext.getSqlStatementContext(), queryContext.getSql(), 
queryContext.getParameters(), connectionContext, hintValueContext);
+        SQLRewriteContext result = new SQLRewriteContext(database, 
queryContext);
         decorate(result, routeContext, hintValueContext);
         result.generateSQLTokens();
         return result;
diff --git 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
index 6df38915018..f7c7616422d 100644
--- 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
+++ 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
@@ -21,7 +21,6 @@ import lombok.AccessLevel;
 import lombok.Getter;
 import 
org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
 import 
org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
-import org.apache.shardingsphere.infra.hint.HintValueContext;
 import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
@@ -31,6 +30,7 @@ import 
org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.SQLTok
 import 
org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.DefaultTokenGeneratorBuilder;
 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
 import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
 
 import java.util.Collection;
 import java.util.LinkedList;
@@ -59,19 +59,18 @@ public final class SQLRewriteContext {
     
     private final ConnectionContext connectionContext;
     
-    public SQLRewriteContext(final ShardingSphereDatabase database, final 
SQLStatementContext sqlStatementContext, final String sql, final List<Object> 
params,
-                             final ConnectionContext connectionContext, final 
HintValueContext hintValueContext) {
+    public SQLRewriteContext(final ShardingSphereDatabase database, final 
QueryContext queryContext) {
         this.database = database;
-        this.sqlStatementContext = sqlStatementContext;
-        this.sql = sql;
-        parameters = params;
-        this.connectionContext = connectionContext;
-        if (!hintValueContext.isSkipSQLRewrite()) {
+        sqlStatementContext = queryContext.getSqlStatementContext();
+        sql = queryContext.getSql();
+        parameters = queryContext.getParameters();
+        connectionContext = queryContext.getConnectionContext();
+        if (!queryContext.getHintValueContext().isSkipSQLRewrite()) {
             addSQLTokenGenerators(new 
DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
         }
         parameterBuilder = containsInsertValues(sqlStatementContext)
                 ? new GroupedParameterBuilder(((InsertStatementContext) 
sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) 
sqlStatementContext).getOnDuplicateKeyUpdateParameters())
-                : new StandardParameterBuilder(params);
+                : new StandardParameterBuilder(parameters);
     }
     
     private boolean containsInsertValues(final SQLStatementContext 
sqlStatementContext) {
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
index 14e4d6eda3e..2734ef2468c 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
@@ -30,7 +30,7 @@ import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardPa
 import 
org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
 import 
org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
@@ -84,8 +84,12 @@ class SQLRewriteContextTest {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
         when(statementContext.getInsertSelectContext()).thenReturn(null);
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, statementContext, "INSERT INTO 
tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), 
hintValueContext);
+        QueryContext queryContext = mock(QueryContext.class, 
RETURNS_DEEP_STUBS);
+        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
+        when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)");
+        
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
+        when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         assertThat(sqlRewriteContext.getParameterBuilder(), 
instanceOf(GroupedParameterBuilder.class));
     }
     
@@ -93,15 +97,23 @@ class SQLRewriteContextTest {
     void assertNotInsertStatementContext() {
         SelectStatementContext statementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, statementContext, "SELECT * 
FROM tbl WHERE id = ?", Collections.singletonList(1), 
mock(ConnectionContext.class), hintValueContext);
+        QueryContext queryContext = mock(QueryContext.class, 
RETURNS_DEEP_STUBS);
+        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
+        when(queryContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id = 
?");
+        
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
+        when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         assertThat(sqlRewriteContext.getParameterBuilder(), 
instanceOf(StandardParameterBuilder.class));
     }
     
     @Test
     void assertGenerateOptionalSQLToken() {
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, sqlStatementContext, "INSERT 
INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), hintValueContext);
+        QueryContext queryContext = mock(QueryContext.class, 
RETURNS_DEEP_STUBS);
+        
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
+        when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)");
+        
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
+        when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(optionalSQLTokenGenerator));
         sqlRewriteContext.generateSQLTokens();
         assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
@@ -110,8 +122,12 @@ class SQLRewriteContextTest {
     
     @Test
     void assertGenerateCollectionSQLToken() {
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, sqlStatementContext, "INSERT 
INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), hintValueContext);
+        QueryContext queryContext = mock(QueryContext.class, 
RETURNS_DEEP_STUBS);
+        
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
+        when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)");
+        
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
+        when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(collectionSQLTokenGenerator));
         sqlRewriteContext.generateSQLTokens();
         assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
index 4c3f2f437d0..af9d7feae50 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
@@ -27,7 +27,6 @@ import 
org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
 import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
 import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
 import 
org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
 import org.apache.shardingsphere.infra.session.query.QueryContext;
 import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
 import 
org.apache.shardingsphere.sqltranslator.rule.builder.DefaultSQLTranslatorRuleConfigurationBuilder;
@@ -54,15 +53,21 @@ class GenericSQLRewriteEngineTest {
         
when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
         CommonSQLStatementContext sqlStatementContext = 
mock(CommonSQLStatementContext.class);
         when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType);
-        QueryContext queryContext = mock(QueryContext.class, 
RETURNS_DEEP_STUBS);
-        
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
-        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
database, mock(RuleMetaData.class))
-                .rewrite(new SQLRewriteContext(database, sqlStatementContext, 
"SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
-                        new HintValueContext()), queryContext);
+        QueryContext queryContext = mockQueryContext(sqlStatementContext);
+        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
database, mock(RuleMetaData.class)).rewrite(new SQLRewriteContext(database, 
queryContext), queryContext);
         assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
         assertThat(actual.getSqlRewriteUnit().getParameters(), 
is(Collections.emptyList()));
     }
     
+    private QueryContext mockQueryContext(final CommonSQLStatementContext 
sqlStatementContext) {
+        QueryContext result = mock(QueryContext.class, RETURNS_DEEP_STUBS);
+        when(result.getSqlStatementContext()).thenReturn(sqlStatementContext);
+        when(result.getSql()).thenReturn("SELECT 1");
+        when(result.getParameters()).thenReturn(Collections.emptyList());
+        when(result.getHintValueContext()).thenReturn(new HintValueContext());
+        return result;
+    }
+    
     @Test
     void assertRewriteStorageTypeIsEmpty() {
         SQLTranslatorRule rule = new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build());
@@ -73,10 +78,8 @@ class GenericSQLRewriteEngineTest {
         CommonSQLStatementContext sqlStatementContext = 
mock(CommonSQLStatementContext.class);
         DatabaseType databaseType = mock(DatabaseType.class);
         when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType);
-        QueryContext queryContext = mock(QueryContext.class, 
RETURNS_DEEP_STUBS);
-        
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
-        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
database, mock(RuleMetaData.class))
-                .rewrite(new SQLRewriteContext(database, sqlStatementContext, 
"SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), new 
HintValueContext()), queryContext);
+        QueryContext queryContext = mockQueryContext(sqlStatementContext);
+        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
database, mock(RuleMetaData.class)).rewrite(new SQLRewriteContext(database, 
queryContext), queryContext);
         assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
         assertThat(actual.getSqlRewriteUnit().getParameters(), 
is(Collections.emptyList()));
     }
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
index 39fed6e18be..07f0121ee2f 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
@@ -34,7 +34,6 @@ import 
org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResu
 import org.apache.shardingsphere.infra.route.context.RouteContext;
 import org.apache.shardingsphere.infra.route.context.RouteMapper;
 import org.apache.shardingsphere.infra.route.context.RouteUnit;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
 import org.apache.shardingsphere.infra.session.query.QueryContext;
 import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
 import 
org.apache.shardingsphere.sqltranslator.rule.builder.DefaultSQLTranslatorRuleConfigurationBuilder;
@@ -59,12 +58,11 @@ class RouteSQLRewriteEngineTest {
         ShardingSphereDatabase database = mockDatabase(databaseType);
         CommonSQLStatementContext sqlStatementContext = 
mock(CommonSQLStatementContext.class);
         when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType);
-        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
sqlStatementContext, "SELECT ?", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
+        QueryContext queryContext = mockQueryContext(sqlStatementContext, 
"SELECT ?");
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
-        QueryContext queryContext = mock(QueryContext.class);
-        
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
         RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
                 new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, 
mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, 
queryContext);
         assertThat(actual.getSqlRewriteUnits().size(), is(1));
@@ -72,6 +70,15 @@ class RouteSQLRewriteEngineTest {
         assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), 
is(Collections.singletonList(1)));
     }
     
+    private QueryContext mockQueryContext(final CommonSQLStatementContext 
sqlStatementContext, final String sql) {
+        QueryContext result = mock(QueryContext.class, RETURNS_DEEP_STUBS);
+        when(result.getSqlStatementContext()).thenReturn(sqlStatementContext);
+        when(result.getSql()).thenReturn(sql);
+        when(result.getParameters()).thenReturn(Collections.singletonList(1));
+        when(result.getHintValueContext()).thenReturn(new HintValueContext());
+        return result;
+    }
+    
     private ShardingSphereDatabase mockDatabase(final DatabaseType 
databaseType) {
         ShardingSphereDatabase result = mock(ShardingSphereDatabase.class, 
RETURNS_DEEP_STUBS);
         when(result.getProtocolType()).thenReturn(databaseType);
@@ -90,14 +97,13 @@ class RouteSQLRewriteEngineTest {
         DatabaseType databaseType = mock(DatabaseType.class);
         when(statementContext.getDatabaseType()).thenReturn(databaseType);
         ShardingSphereDatabase database = mockDatabase(databaseType);
-        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
statementContext, "SELECT ?", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
+        QueryContext queryContext = mockQueryContext(statementContext, "SELECT 
?");
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         RouteContext routeContext = new RouteContext();
         RouteUnit firstRouteUnit = new RouteUnit(new RouteMapper("ds", 
"ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteUnit secondRouteUnit = new RouteUnit(new RouteMapper("ds", 
"ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_1")));
         routeContext.getRouteUnits().add(firstRouteUnit);
         routeContext.getRouteUnits().add(secondRouteUnit);
-        QueryContext queryContext = mock(QueryContext.class);
-        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
         RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
                 new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, 
mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, 
queryContext);
         assertThat(actual.getSqlRewriteUnits().size(), is(1));
@@ -115,13 +121,11 @@ class RouteSQLRewriteEngineTest {
         DatabaseType databaseType = mock(DatabaseType.class);
         when(statementContext.getDatabaseType()).thenReturn(databaseType);
         ShardingSphereDatabase database = mockDatabase(databaseType);
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, statementContext, "INSERT INTO 
tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), 
new HintValueContext());
+        QueryContext queryContext = mockQueryContext(statementContext, "INSERT 
INTO tbl VALUES (?)");
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
-        QueryContext queryContext = mock(QueryContext.class);
-        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
         RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
                 new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, 
mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, 
queryContext);
         assertThat(actual.getSqlRewriteUnits().size(), is(1));
@@ -139,15 +143,13 @@ class RouteSQLRewriteEngineTest {
         DatabaseType databaseType = mock(DatabaseType.class);
         when(statementContext.getDatabaseType()).thenReturn(databaseType);
         ShardingSphereDatabase database = mockDatabase(databaseType);
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, statementContext, "INSERT INTO 
tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), 
new HintValueContext());
+        QueryContext queryContext = mockQueryContext(statementContext, "INSERT 
INTO tbl VALUES (?)");
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
         // TODO check why data node is "ds.tbl_0", not "ds_0.tbl_0"
         routeContext.getOriginalDataNodes().add(Collections.singletonList(new 
DataNode("ds.tbl_0")));
-        QueryContext queryContext = mock(QueryContext.class);
-        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
         RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
                 new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, 
mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, 
queryContext);
         assertThat(actual.getSqlRewriteUnits().size(), is(1));
@@ -165,14 +167,12 @@ class RouteSQLRewriteEngineTest {
         DatabaseType databaseType = mock(DatabaseType.class);
         when(statementContext.getDatabaseType()).thenReturn(databaseType);
         ShardingSphereDatabase database = mockDatabase(databaseType);
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, statementContext, "INSERT INTO 
tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), 
new HintValueContext());
+        QueryContext queryContext = mockQueryContext(statementContext, "INSERT 
INTO tbl VALUES (?)");
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
         routeContext.getOriginalDataNodes().add(Collections.emptyList());
-        QueryContext queryContext = mock(QueryContext.class);
-        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
         RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
                 new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, 
mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, 
queryContext);
         assertThat(actual.getSqlRewriteUnits().size(), is(1));
@@ -190,14 +190,12 @@ class RouteSQLRewriteEngineTest {
         DatabaseType databaseType = mock(DatabaseType.class);
         when(statementContext.getDatabaseType()).thenReturn(databaseType);
         ShardingSphereDatabase database = mockDatabase(databaseType);
-        SQLRewriteContext sqlRewriteContext =
-                new SQLRewriteContext(database, statementContext, "INSERT INTO 
tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), 
new HintValueContext());
+        QueryContext queryContext = mockQueryContext(statementContext, "INSERT 
INTO tbl VALUES (?)");
+        SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, 
queryContext);
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
         routeContext.getOriginalDataNodes().add(Collections.singletonList(new 
DataNode("ds_1.tbl_1")));
-        QueryContext queryContext = mock(QueryContext.class);
-        
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
         RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
                 new SQLTranslatorRule(new 
DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, 
mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, 
queryContext);
         assertThat(actual.getSqlRewriteUnits().size(), is(1));

Reply via email to