This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new bee8fa9ab0f Refactor SQLRewriteContext global field to obtain more database metadata (#27334)
bee8fa9ab0f is described below

commit bee8fa9ab0f191765e601ef1f7027beea41b4e71
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Thu Jul 20 19:27:40 2023 +0800

    Refactor SQLRewriteContext global field to obtain more database metadata (#27334)
    
    * Refactor SQLRewriteContext global field to obtain more database metadata
    
    * fix unit test
---
 .../context/EncryptSQLRewriteContextDecorator.java | 12 ++++-----
 .../ShardingSQLRewriteContextDecorator.java        |  4 +--
 .../ShardingSQLRewriteContextDecoratorTest.java    |  2 ++
 .../infra/rewrite/SQLRewriteEntry.java             |  2 +-
 .../infra/rewrite/context/SQLRewriteContext.java   | 16 +++++------
 .../rewrite/context/SQLRewriteContextTest.java     | 22 +++++++++------
 .../engine/GenericSQLRewriteEngineTest.java        | 21 ++++++++++-----
 .../rewrite/engine/RouteSQLRewriteEngineTest.java  | 31 +++++++++++++---------
 8 files changed, 65 insertions(+), 45 deletions(-)

diff --git 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java
 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java
index fea40fadd4c..9589fab6166 100644
--- 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java
+++ 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java
@@ -49,13 +49,13 @@ public final class EncryptSQLRewriteContextDecorator 
implements SQLRewriteContex
             return;
         }
         Collection<EncryptCondition> encryptConditions = 
createEncryptConditions(encryptRule, sqlRewriteContext);
+        String databaseName = sqlRewriteContext.getDatabase().getName();
         if (!sqlRewriteContext.getParameters().isEmpty()) {
-            Collection<ParameterRewriter> parameterRewriters = new 
EncryptParameterRewriterBuilder(encryptRule,
-                    sqlRewriteContext.getDatabaseName(), 
sqlRewriteContext.getSchemas(), sqlStatementContext, 
encryptConditions).getParameterRewriters();
+            Collection<ParameterRewriter> parameterRewriters =
+                    new EncryptParameterRewriterBuilder(encryptRule, 
databaseName, sqlRewriteContext.getDatabase().getSchemas(), 
sqlStatementContext, encryptConditions).getParameterRewriters();
             rewriteParameters(sqlRewriteContext, parameterRewriters);
         }
-        Collection<SQLTokenGenerator> sqlTokenGenerators = new 
EncryptTokenGenerateBuilder(encryptRule,
-                sqlStatementContext, encryptConditions, 
sqlRewriteContext.getDatabaseName()).getSQLTokenGenerators();
+        Collection<SQLTokenGenerator> sqlTokenGenerators = new 
EncryptTokenGenerateBuilder(encryptRule, sqlStatementContext, 
encryptConditions, databaseName).getSQLTokenGenerators();
         sqlRewriteContext.addSQLTokenGenerators(sqlTokenGenerators);
     }
     
@@ -66,8 +66,8 @@ public final class EncryptSQLRewriteContextDecorator 
implements SQLRewriteContex
         }
         Collection<WhereSegment> whereSegments = ((WhereAvailable) 
sqlStatementContext).getWhereSegments();
         Collection<ColumnSegment> columnSegments = ((WhereAvailable) 
sqlStatementContext).getColumnSegments();
-        return new EncryptConditionEngine(encryptRule, 
sqlRewriteContext.getSchemas())
-                .createEncryptConditions(whereSegments, columnSegments, 
sqlStatementContext, sqlRewriteContext.getDatabaseName());
+        return new EncryptConditionEngine(encryptRule, 
sqlRewriteContext.getDatabase().getSchemas()).createEncryptConditions(whereSegments,
 columnSegments, sqlStatementContext,
+                sqlRewriteContext.getDatabase().getName());
     }
     
     private boolean containsEncryptTable(final EncryptRule encryptRule, final 
SQLStatementContext sqlStatementContext) {
diff --git 
a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecorator.java
 
b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecorator.java
index bf006f191f4..ff116672c93 100644
--- 
a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecorator.java
+++ 
b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecorator.java
@@ -39,8 +39,8 @@ public final class ShardingSQLRewriteContextDecorator 
implements SQLRewriteConte
     @Override
     public void decorate(final ShardingRule shardingRule, final 
ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final 
RouteContext routeContext) {
         if (!sqlRewriteContext.getParameters().isEmpty()) {
-            Collection<ParameterRewriter> parameterRewriters = new 
ShardingParameterRewriterBuilder(shardingRule,
-                    routeContext, sqlRewriteContext.getSchemas(), 
sqlRewriteContext.getSqlStatementContext()).getParameterRewriters();
+            Collection<ParameterRewriter> parameterRewriters =
+                    new ShardingParameterRewriterBuilder(shardingRule, 
routeContext, sqlRewriteContext.getDatabase().getSchemas(), 
sqlRewriteContext.getSqlStatementContext()).getParameterRewriters();
             rewriteParameters(sqlRewriteContext, parameterRewriters);
         }
         sqlRewriteContext.addSQLTokenGenerators(new 
ShardingTokenGenerateBuilder(shardingRule, routeContext, 
sqlRewriteContext.getSqlStatementContext()).getSQLTokenGenerators());
diff --git 
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
 
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
index da1d9a38e04..df17e31e854 100644
--- 
a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
+++ 
b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/context/ShardingSQLRewriteContextDecoratorTest.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.sharding.rewrite.context;
 
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
 import org.apache.shardingsphere.infra.route.context.RouteContext;
 import org.apache.shardingsphere.sharding.rule.ShardingRule;
@@ -36,6 +37,7 @@ class ShardingSQLRewriteContextDecoratorTest {
     @Test
     void assertDecorate() {
         SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class);
+        
when(sqlRewriteContext.getDatabase()).thenReturn(mock(ShardingSphereDatabase.class));
         
when(sqlRewriteContext.getParameters()).thenReturn(Collections.singletonList(new
 Object()));
         
when(sqlRewriteContext.getSqlStatementContext()).thenReturn(mock(SQLStatementContext.class,
 RETURNS_DEEP_STUBS));
         new 
ShardingSQLRewriteContextDecorator().decorate(mock(ShardingRule.class), 
mock(ConfigurationProperties.class), sqlRewriteContext, 
mock(RouteContext.class));
diff --git 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
index 1a1955df107..bae4bee2cbd 100644
--- 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
+++ 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
@@ -84,7 +84,7 @@ public final class SQLRewriteEntry {
     
     private SQLRewriteContext createSQLRewriteContext(final String sql, final 
List<Object> params, final SQLStatementContext sqlStatementContext,
                                                       final RouteContext 
routeContext, final ConnectionContext connectionContext, final HintValueContext 
hintValueContext) {
-        SQLRewriteContext result = new SQLRewriteContext(database.getName(), 
database.getSchemas(), sqlStatementContext, sql, params, connectionContext, 
hintValueContext);
+        SQLRewriteContext result = new SQLRewriteContext(database, 
sqlStatementContext, sql, params, connectionContext, hintValueContext);
         decorate(decorators, result, routeContext, hintValueContext);
         result.generateSQLTokens();
         return result;
diff --git 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
index af081b4468c..9bca6c5d977 100644
--- 
a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
+++ 
b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java
@@ -22,7 +22,7 @@ import lombok.Getter;
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
 import org.apache.shardingsphere.infra.hint.HintValueContext;
-import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
@@ -35,7 +35,6 @@ import 
org.apache.shardingsphere.infra.session.connection.ConnectionContext;
 import java.util.Collection;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.Map;
 
 /**
  * SQL rewrite context.
@@ -43,9 +42,7 @@ import java.util.Map;
 @Getter
 public final class SQLRewriteContext {
     
-    private final String databaseName;
-    
-    private final Map<String, ShardingSphereSchema> schemas;
+    private final ShardingSphereDatabase database;
     
     private final SQLStatementContext sqlStatementContext;
     
@@ -62,10 +59,9 @@ public final class SQLRewriteContext {
     
     private final ConnectionContext connectionContext;
     
-    public SQLRewriteContext(final String databaseName, final Map<String, 
ShardingSphereSchema> schemas, final SQLStatementContext sqlStatementContext,
-                             final String sql, final List<Object> params, 
final ConnectionContext connectionContext, final HintValueContext 
hintValueContext) {
-        this.databaseName = databaseName;
-        this.schemas = schemas;
+    public SQLRewriteContext(final ShardingSphereDatabase database, final 
SQLStatementContext sqlStatementContext, final String sql, final List<Object> 
params,
+                             final ConnectionContext connectionContext, final 
HintValueContext hintValueContext) {
+        this.database = database;
         this.sqlStatementContext = sqlStatementContext;
         this.sql = sql;
         parameters = params;
@@ -92,6 +88,6 @@ public final class SQLRewriteContext {
      * Generate SQL tokens.
      */
     public void generateSQLTokens() {
-        sqlTokens.addAll(sqlTokenGenerators.generateSQLTokens(databaseName, 
schemas, sqlStatementContext, parameters, connectionContext));
+        
sqlTokens.addAll(sqlTokenGenerators.generateSQLTokens(database.getName(), 
database.getSchemas(), sqlStatementContext, parameters, connectionContext));
     }
 }
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
index 22bf0d6a6df..b678b7afa6a 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java
@@ -23,6 +23,7 @@ import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementConte
 import org.apache.shardingsphere.infra.binder.type.TableAvailable;
 import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
 import org.apache.shardingsphere.infra.hint.HintValueContext;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
 import 
org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
@@ -66,11 +67,16 @@ class SQLRewriteContextTest {
     @Mock
     private HintValueContext hintValueContext;
     
+    @Mock
+    private ShardingSphereDatabase database;
+    
     @SuppressWarnings("unchecked")
     @BeforeEach
     void setUp() {
         
when(optionalSQLTokenGenerator.generateSQLToken(sqlStatementContext)).thenReturn(sqlToken);
         
when(collectionSQLTokenGenerator.generateSQLTokens(sqlStatementContext)).thenReturn(Collections.singleton(sqlToken));
+        when(database.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
+        
when(database.getSchemas()).thenReturn(Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)));
     }
     
     @Test
@@ -78,8 +84,8 @@ class SQLRewriteContextTest {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
         when(statementContext.getInsertSelectContext()).thenReturn(null);
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(database, statementContext, "INSERT INTO 
tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), 
hintValueContext);
         assertThat(sqlRewriteContext.getParameterBuilder(), 
instanceOf(GroupedParameterBuilder.class));
     }
     
@@ -87,15 +93,15 @@ class SQLRewriteContextTest {
     void assertNotInsertStatementContext() {
         SelectStatementContext statementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "SELECT * FROM tbl WHERE id = ?", 
Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(database, statementContext, "SELECT * 
FROM tbl WHERE id = ?", Collections.singletonList(1), 
mock(ConnectionContext.class), hintValueContext);
         assertThat(sqlRewriteContext.getParameterBuilder(), 
instanceOf(StandardParameterBuilder.class));
     }
     
     @Test
     void assertGenerateOptionalSQLToken() {
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                sqlStatementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(database, sqlStatementContext, "INSERT 
INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), hintValueContext);
         
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(optionalSQLTokenGenerator));
         sqlRewriteContext.generateSQLTokens();
         assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
@@ -104,8 +110,8 @@ class SQLRewriteContextTest {
     
     @Test
     void assertGenerateCollectionSQLToken() {
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                sqlStatementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(database, sqlStatementContext, "INSERT 
INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), hintValueContext);
         
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(collectionSQLTokenGenerator));
         sqlRewriteContext.generateSQLTokens();
         assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
index 3dfbeb8354c..2f8a4ff64c1 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java
@@ -21,6 +21,7 @@ import 
org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContex
 import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
 import org.apache.shardingsphere.infra.database.spi.DatabaseType;
 import org.apache.shardingsphere.infra.hint.HintValueContext;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
 import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
 import 
org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
@@ -34,6 +35,7 @@ import java.util.Collections;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 class GenericSQLRewriteEngineTest {
     
@@ -41,9 +43,9 @@ class GenericSQLRewriteEngineTest {
     void assertRewrite() {
         DatabaseType databaseType = mock(DatabaseType.class);
         SQLTranslatorRule rule = new SQLTranslatorRule(new 
SQLTranslatorRuleConfiguration());
-        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
databaseType, Collections.singletonMap("ds_0", databaseType)).rewrite(new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
-                Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), 
"SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
-                new HintValueContext()));
+        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
databaseType, Collections.singletonMap("ds_0", databaseType))
+                .rewrite(new SQLRewriteContext(mockDatabase(), 
mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), 
mock(ConnectionContext.class),
+                        new HintValueContext()));
         assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
         assertThat(actual.getSqlRewriteUnit().getParameters(), 
is(Collections.emptyList()));
     }
@@ -51,10 +53,17 @@ class GenericSQLRewriteEngineTest {
     @Test
     void assertRewriteStorageTypeIsEmpty() {
         SQLTranslatorRule rule = new SQLTranslatorRule(new 
SQLTranslatorRuleConfiguration());
-        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
mock(DatabaseType.class), Collections.emptyMap()).rewrite(new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME,
-                Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)), mock(CommonSQLStatementContext.class), 
"SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
-                new HintValueContext()));
+        GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, 
mock(DatabaseType.class), Collections.emptyMap())
+                .rewrite(new SQLRewriteContext(mockDatabase(), 
mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), 
mock(ConnectionContext.class),
+                        new HintValueContext()));
         assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
         assertThat(actual.getSqlRewriteUnit().getParameters(), 
is(Collections.emptyList()));
     }
+    
+    private ShardingSphereDatabase mockDatabase() {
+        ShardingSphereDatabase result = mock(ShardingSphereDatabase.class);
+        when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
+        when(result.getSchemas()).thenReturn(Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)));
+        return result;
+    }
 }
diff --git 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
index 756051f2157..e58f3c0ebd7 100644
--- 
a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
+++ 
b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java
@@ -25,6 +25,7 @@ import 
org.apache.shardingsphere.infra.database.core.DefaultDatabase;
 import org.apache.shardingsphere.infra.database.spi.DatabaseType;
 import org.apache.shardingsphere.infra.datanode.DataNode;
 import org.apache.shardingsphere.infra.hint.HintValueContext;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
 import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
 import 
org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
@@ -50,8 +51,8 @@ class RouteSQLRewriteEngineTest {
     
     @Test
     void assertRewriteWithStandardParameterBuilder() {
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                mock(CommonSQLStatementContext.class), "SELECT ?", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(mockDatabase(), 
mock(CommonSQLStatementContext.class), "SELECT ?", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
@@ -68,8 +69,7 @@ class RouteSQLRewriteEngineTest {
         SelectStatementContext statementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
         
when(statementContext.getOrderByContext().getItems()).thenReturn(Collections.emptyList());
         
when(statementContext.getPaginationContext().isHasPagination()).thenReturn(false);
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "SELECT ?", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
+        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(mockDatabase(), statementContext, "SELECT ?", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
         RouteContext routeContext = new RouteContext();
         RouteUnit firstRouteUnit = new RouteUnit(new RouteMapper("ds", 
"ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteUnit secondRouteUnit = new RouteUnit(new RouteMapper("ds", 
"ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_1")));
@@ -88,8 +88,8 @@ class RouteSQLRewriteEngineTest {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
@@ -106,8 +106,8 @@ class RouteSQLRewriteEngineTest {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
@@ -126,8 +126,8 @@ class RouteSQLRewriteEngineTest {
         InsertStatementContext statementContext = 
mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
         when(((TableAvailable) 
statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
@@ -147,8 +147,8 @@ class RouteSQLRewriteEngineTest {
         when(statementContext.getInsertSelectContext()).thenReturn(null);
         
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
         
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
-        SQLRewriteContext sqlRewriteContext = new 
SQLRewriteContext(DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)),
-                statementContext, "INSERT INTO tbl VALUES (?)", 
Collections.singletonList(1), mock(ConnectionContext.class), new 
HintValueContext());
+        SQLRewriteContext sqlRewriteContext =
+                new SQLRewriteContext(mockDatabase(), statementContext, 
"INSERT INTO tbl VALUES (?)", Collections.singletonList(1), 
mock(ConnectionContext.class), new HintValueContext());
         RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), 
Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
         RouteContext routeContext = new RouteContext();
         routeContext.getRouteUnits().add(routeUnit);
@@ -160,4 +160,11 @@ class RouteSQLRewriteEngineTest {
         assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), 
is("INSERT INTO tbl VALUES (?)"));
         
assertTrue(actual.getSqlRewriteUnits().get(routeUnit).getParameters().isEmpty());
     }
+    
+    private ShardingSphereDatabase mockDatabase() {
+        ShardingSphereDatabase result = mock(ShardingSphereDatabase.class);
+        when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
+        when(result.getSchemas()).thenReturn(Collections.singletonMap("test", 
mock(ShardingSphereSchema.class)));
+        return result;
+    }
 }

Reply via email to