This is an automated email from the ASF dual-hosted git repository.

sunnianjun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 3df0fa4cc74 Refactor ReplayedSessionVariablesProvider (#27359)
3df0fa4cc74 is described below

commit 3df0fa4cc746e20681830fed78c5c3ca55d56431
Author: Liang Zhang <[email protected]>
AuthorDate: Fri Jul 21 22:06:21 2023 +0800

    Refactor ReplayedSessionVariablesProvider (#27359)
---
 .../compiler/operator/physical/EnumerableScan.java          |  2 +-
 .../admin/executor/DefaultSessionVariableHandler.java       | 13 ++++++++-----
 .../admin/executor/ReplayedSessionVariablesProvider.java    |  4 ++--
 .../executor/DefaultMySQLSessionVariableHandlerTest.java    |  7 +++++--
 .../admin/PostgreSQLDefaultSessionVariableHandlerTest.java  |  7 +++++--
 5 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/kernel/sql-federation/core/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/operator/physical/EnumerableScan.java b/kernel/sql-federation/core/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/operator/physical/EnumerableScan.java
index 38584be7cc7..bcd660d99ca 100644
--- a/kernel/sql-federation/core/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/operator/physical/EnumerableScan.java
+++ b/kernel/sql-federation/core/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/operator/physical/EnumerableScan.java
@@ -91,7 +91,7 @@ public final class EnumerableScan extends TableScan implements EnumerableRel {
     }
     
     private SqlString createSQLString(final RelNode scanContext, final String databaseType) {
-        final SqlDialect sqlDialect = SQLDialectFactory.getSQLDialect(databaseType);
+        SqlDialect sqlDialect = SQLDialectFactory.getSQLDialect(databaseType);
         return new RelToSqlConverter(sqlDialect).visitRoot(scanContext).asStatement().toSqlString(sqlDialect);
     }
     
diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/DefaultSessionVariableHandler.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/DefaultSessionVariableHandler.java
index 8b23c519ede..68e7215ad92 100644
--- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/DefaultSessionVariableHandler.java
+++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/DefaultSessionVariableHandler.java
@@ -17,9 +17,9 @@
 
 package org.apache.shardingsphere.proxy.backend.handler.admin.executor;
 
-import lombok.AccessLevel;
-import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
+import org.apache.shardingsphere.infra.database.spi.DatabaseType;
 import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 
@@ -29,14 +29,17 @@ import java.util.Collections;
  * Default session variable handler.
  */
 @Slf4j
-@RequiredArgsConstructor(access = AccessLevel.PROTECTED)
 public abstract class DefaultSessionVariableHandler implements SessionVariableHandler {
     
-    private final String databaseType;
+    private final DatabaseType databaseType;
+    
+    protected DefaultSessionVariableHandler(final String databaseType) {
+        this.databaseType = TypedSPILoader.getService(DatabaseType.class, databaseType);
+    }
     
     @Override
     public final void handle(final ConnectionSession connectionSession, final String variableName, final String assignValue) {
-        if (TypedSPILoader.findService(ReplayedSessionVariablesProvider.class, databaseType).map(ReplayedSessionVariablesProvider::getVariables).orElseGet(Collections::emptySet)
+        if (DatabaseTypedSPILoader.findService(ReplayedSessionVariablesProvider.class, databaseType).map(ReplayedSessionVariablesProvider::getVariables).orElseGet(Collections::emptySet)
                 .contains(variableName) || isNeedHandle(variableName)) {
             connectionSession.getRequiredSessionVariableRecorder().setVariable(variableName, assignValue);
         } else {
diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/ReplayedSessionVariablesProvider.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/ReplayedSessionVariablesProvider.java
index da6ca33b9d3..a17f9f43478 100644
--- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/ReplayedSessionVariablesProvider.java
+++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/admin/executor/ReplayedSessionVariablesProvider.java
@@ -17,8 +17,8 @@
 
 package org.apache.shardingsphere.proxy.backend.handler.admin.executor;
 
+import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPI;
 import org.apache.shardingsphere.infra.util.spi.annotation.SingletonSPI;
-import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPI;
 
 import java.util.Collection;
 
@@ -26,7 +26,7 @@ import java.util.Collection;
  * Provide session variables need to be replayed on session connected.
  */
 @SingletonSPI
-public interface ReplayedSessionVariablesProvider extends TypedSPI {
+public interface ReplayedSessionVariablesProvider extends DatabaseTypedSPI {
     
     /**
      * Get need to be replayed session variables.
diff --git a/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/DefaultMySQLSessionVariableHandlerTest.java b/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/DefaultMySQLSessionVariableHandlerTest.java
index 1136bc9cbaa..fa82abccef5 100644
--- a/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/DefaultMySQLSessionVariableHandlerTest.java
+++ b/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/DefaultMySQLSessionVariableHandlerTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.shardingsphere.proxy.backend.mysql.handler.admin.executor;
 
+import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
+import org.apache.shardingsphere.infra.database.spi.DatabaseType;
 import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
 import org.apache.shardingsphere.proxy.backend.handler.admin.executor.ReplayedSessionVariablesProvider;
 import org.apache.shardingsphere.proxy.backend.mysql.handler.admin.MySQLDefaultSessionVariableHandler;
@@ -48,10 +50,11 @@ class DefaultMySQLSessionVariableHandlerTest {
     void assertHandleRecord() {
         ConnectionSession connectionSession = mock(ConnectionSession.class);
         when(connectionSession.getRequiredSessionVariableRecorder()).thenReturn(mock(RequiredSessionVariableRecorder.class));
-        try (MockedStatic<TypedSPILoader> typedSPILoader = mockStatic(TypedSPILoader.class)) {
+        try (MockedStatic<DatabaseTypedSPILoader> databaseTypedSPILoader = mockStatic(DatabaseTypedSPILoader.class)) {
             ReplayedSessionVariablesProvider variablesProvider = mock(ReplayedSessionVariablesProvider.class);
             when(variablesProvider.getVariables()).thenReturn(Collections.singleton("sql_mode"));
-            typedSPILoader.when(() -> TypedSPILoader.findService(ReplayedSessionVariablesProvider.class, "MySQL")).thenReturn(Optional.of(variablesProvider));
+            databaseTypedSPILoader.when(() -> DatabaseTypedSPILoader.findService(
+                    ReplayedSessionVariablesProvider.class, TypedSPILoader.getService(DatabaseType.class, "MySQL"))).thenReturn(Optional.of(variablesProvider));
             final MySQLDefaultSessionVariableHandler defaultSessionVariableHandler = new MySQLDefaultSessionVariableHandler();
             defaultSessionVariableHandler.handle(connectionSession, "sql_mode", "''");
             verify(connectionSession.getRequiredSessionVariableRecorder()).setVariable("sql_mode", "''");
diff --git a/proxy/backend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/backend/postgresql/handler/admin/PostgreSQLDefaultSessionVariableHandlerTest.java b/proxy/backend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/backend/postgresql/handler/admin/PostgreSQLDefaultSessionVariableHandlerTest.java
index e3467a9cfd4..12ac997fa2c 100644
--- a/proxy/backend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/backend/postgresql/handler/admin/PostgreSQLDefaultSessionVariableHandlerTest.java
+++ b/proxy/backend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/backend/postgresql/handler/admin/PostgreSQLDefaultSessionVariableHandlerTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.shardingsphere.proxy.backend.postgresql.handler.admin;
 
+import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
+import org.apache.shardingsphere.infra.database.spi.DatabaseType;
 import org.apache.shardingsphere.infra.util.spi.type.typed.TypedSPILoader;
 import org.apache.shardingsphere.proxy.backend.handler.admin.executor.ReplayedSessionVariablesProvider;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
@@ -46,10 +48,11 @@ class PostgreSQLDefaultSessionVariableHandlerTest {
     void assertHandleRecord() {
         ConnectionSession connectionSession = mock(ConnectionSession.class);
         when(connectionSession.getRequiredSessionVariableRecorder()).thenReturn(mock(RequiredSessionVariableRecorder.class));
-        try (MockedStatic<TypedSPILoader> typedSPILoader = mockStatic(TypedSPILoader.class)) {
+        try (MockedStatic<DatabaseTypedSPILoader> databaseTypedSPILoader = mockStatic(DatabaseTypedSPILoader.class)) {
             ReplayedSessionVariablesProvider variablesProvider = mock(ReplayedSessionVariablesProvider.class);
             when(variablesProvider.getVariables()).thenReturn(Collections.singleton("datestyle"));
-            typedSPILoader.when(() -> TypedSPILoader.findService(ReplayedSessionVariablesProvider.class, "PostgreSQL")).thenReturn(Optional.of(variablesProvider));
+            databaseTypedSPILoader.when(() -> DatabaseTypedSPILoader.findService(
+                    ReplayedSessionVariablesProvider.class, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"))).thenReturn(Optional.of(variablesProvider));
             new PostgreSQLDefaultSessionVariableHandler().handle(connectionSession, "datestyle", "postgres");
             verify(connectionSession.getRequiredSessionVariableRecorder()).setVariable("datestyle", "postgres");
         }

Reply via email to