This is an automated email from the ASF dual-hosted git repository.

tuichenchuxin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new e5c93a9299c Merge MySQLAuthenticationHandler and MySQLAuthenticationEngine (#24148)
e5c93a9299c is described below

commit e5c93a9299c554eeb5d0a8bbf53a3a71aa557484
Author: Liang Zhang <[email protected]>
AuthorDate: Tue Feb 14 10:12:09 2023 +0800

    Merge MySQLAuthenticationHandler and MySQLAuthenticationEngine (#24148)
    
    * Merge MySQLAuthenticationHandler and MySQLAuthenticationEngine
    
    * Fix checkstyle
    
    * Fix checkstyle
    
    * Fix checkstyle
    
    * Fix checkstyle
---
 .../authentication/MySQLAuthenticationEngine.java  |  59 +++++---
 .../authentication/MySQLAuthenticationHandler.java |  60 ---------
 .../MySQLAuthenticationEngineTest.java             | 140 ++++++++++++-------
 .../MySQLAuthenticationHandlerTest.java            | 150 ---------------------
 4 files changed, 131 insertions(+), 278 deletions(-)

diff --git 
a/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
 
b/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
index 9233f4d6d17..fe596de4365 100644
--- 
a/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
+++ 
b/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.java
@@ -19,6 +19,7 @@ package 
org.apache.shardingsphere.proxy.frontend.mysql.authentication;
 
 import com.google.common.base.Strings;
 import io.netty.channel.ChannelHandlerContext;
+import org.apache.shardingsphere.authority.checker.AuthorityChecker;
 import org.apache.shardingsphere.authority.rule.AuthorityRule;
 import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
 import 
org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCapabilityFlag;
@@ -28,6 +29,7 @@ import 
org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
 import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLStatusFlag;
 import 
org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLErrPacket;
 import 
org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLOKPacket;
+import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthPluginData;
 import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthSwitchRequestPacket;
 import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthSwitchResponsePacket;
 import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakePacket;
@@ -35,6 +37,7 @@ import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandsha
 import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
 import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
 import org.apache.shardingsphere.dialect.mysql.vendor.MySQLVendorError;
+import org.apache.shardingsphere.infra.metadata.user.Grantee;
 import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
@@ -57,7 +60,7 @@ public final class MySQLAuthenticationEngine implements 
AuthenticationEngine {
     
     private static final int DEFAULT_STATUS_FLAG = 
MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue();
     
-    private final MySQLAuthenticationHandler authenticationHandler = new 
MySQLAuthenticationHandler();
+    private final MySQLAuthPluginData authPluginData = new 
MySQLAuthPluginData();
     
     private MySQLConnectionPhase connectionPhase = 
MySQLConnectionPhase.INITIAL_HANDSHAKE;
     
@@ -69,47 +72,51 @@ public final class MySQLAuthenticationEngine implements 
AuthenticationEngine {
     public int handshake(final ChannelHandlerContext context) {
         int result = ConnectionIdGenerator.getInstance().nextId();
         connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
-        context.writeAndFlush(new MySQLHandshakePacket(result, 
authenticationHandler.getAuthPluginData()));
+        context.writeAndFlush(new MySQLHandshakePacket(result, 
authPluginData));
         MySQLStatementIDGenerator.getInstance().registerConnection(result);
         return result;
     }
     
     @Override
     public AuthenticationResult authenticate(final ChannelHandlerContext 
context, final PacketPayload payload) {
+        AuthorityRule rule = 
ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(AuthorityRule.class);
         if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) {
-            currentAuthResult = authPhaseFastPath(context, payload);
+            currentAuthResult = authPhaseFastPath(context, payload, rule);
             if (!currentAuthResult.isFinished()) {
                 return currentAuthResult;
             }
         } else if (MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH == 
connectionPhase) {
             authenticationMethodMismatch((MySQLPacketPayload) payload);
         }
-        Optional<MySQLVendorError> vendorError = 
authenticationHandler.login(currentAuthResult.getUsername(), 
getHostAddress(context), authResponse, currentAuthResult.getDatabase());
-        if (vendorError.isPresent()) {
-            context.writeAndFlush(createErrorPacket(vendorError.get(), 
context));
-            context.close();
+        Grantee grantee = new Grantee(currentAuthResult.getUsername(), 
getHostAddress(context));
+        if (!login(rule, grantee, authResponse)) {
+            writeErrorPacket(context,
+                    new 
MySQLErrPacket(MySQLVendorError.ER_ACCESS_DENIED_ERROR, 
currentAuthResult.getUsername(), getHostAddress(context), 0 == 
authResponse.length ? "NO" : "YES"));
             return AuthenticationResultBuilder.continued();
         }
-        context.writeAndFlush(new MySQLOKPacket(DEFAULT_STATUS_FLAG));
-        return 
AuthenticationResultBuilder.finished(currentAuthResult.getUsername(), 
getHostAddress(context), currentAuthResult.getDatabase());
+        if (!authorizeDatabase(rule, grantee, 
currentAuthResult.getDatabase())) {
+            writeErrorPacket(context,
+                    new 
MySQLErrPacket(MySQLVendorError.ER_DBACCESS_DENIED_ERROR, 
currentAuthResult.getUsername(), getHostAddress(context), 
currentAuthResult.getDatabase()));
+            return AuthenticationResultBuilder.continued();
+        }
+        writeOKPacket(context);
+        return AuthenticationResultBuilder.finished(grantee.getUsername(), 
grantee.getHostname(), currentAuthResult.getDatabase());
     }
     
-    private AuthenticationResult authPhaseFastPath(final ChannelHandlerContext 
context, final PacketPayload payload) {
+    private AuthenticationResult authPhaseFastPath(final ChannelHandlerContext 
context, final PacketPayload payload, final AuthorityRule rule) {
         MySQLHandshakeResponse41Packet packet = new 
MySQLHandshakeResponse41Packet((MySQLPacketPayload) payload);
         authResponse = packet.getAuthResponse();
         MySQLCharacterSet characterSet = 
MySQLCharacterSet.findById(packet.getCharacterSet());
         
context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(characterSet.getCharset());
         
context.channel().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
         if (!Strings.isNullOrEmpty(packet.getDatabase()) && 
!ProxyContext.getInstance().databaseExists(packet.getDatabase())) {
-            context.writeAndFlush(new 
MySQLErrPacket(MySQLVendorError.ER_BAD_DB_ERROR, packet.getDatabase()));
-            context.close();
+            writeErrorPacket(context, new 
MySQLErrPacket(MySQLVendorError.ER_BAD_DB_ERROR, packet.getDatabase()));
             return AuthenticationResultBuilder.continued();
         }
-        AuthorityRule rule = 
ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(AuthorityRule.class);
         Authenticator authenticator = new 
AuthenticatorFactory<>(MySQLAuthenticatorType.class, rule).newInstance(new 
ShardingSphereUser(packet.getUsername(), "", getHostAddress(context)));
         if (isClientPluginAuth(packet) && 
!authenticator.getAuthenticationMethodName().equals(packet.getAuthPluginName()))
 {
             connectionPhase = 
MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
-            context.writeAndFlush(new 
MySQLAuthSwitchRequestPacket(authenticator.getAuthenticationMethodName(), 
authenticationHandler.getAuthPluginData()));
+            context.writeAndFlush(new 
MySQLAuthSwitchRequestPacket(authenticator.getAuthenticationMethodName(), 
authPluginData));
             return AuthenticationResultBuilder.continued(packet.getUsername(), 
getHostAddress(context), packet.getDatabase());
         }
         return AuthenticationResultBuilder.finished(packet.getUsername(), 
getHostAddress(context), packet.getDatabase());
@@ -120,22 +127,30 @@ public final class MySQLAuthenticationEngine implements 
AuthenticationEngine {
     }
     
     private void authenticationMethodMismatch(final MySQLPacketPayload 
payload) {
-        MySQLAuthSwitchResponsePacket packet = new 
MySQLAuthSwitchResponsePacket(payload);
-        authResponse = packet.getAuthPluginResponse();
+        authResponse = new 
MySQLAuthSwitchResponsePacket(payload).getAuthPluginResponse();
     }
     
-    private MySQLErrPacket createErrorPacket(final MySQLVendorError 
vendorError, final ChannelHandlerContext context) {
-        return MySQLVendorError.ER_DBACCESS_DENIED_ERROR == vendorError
-                ? new 
MySQLErrPacket(MySQLVendorError.ER_DBACCESS_DENIED_ERROR, 
currentAuthResult.getUsername(), getHostAddress(context), 
currentAuthResult.getDatabase())
-                : new MySQLErrPacket(MySQLVendorError.ER_ACCESS_DENIED_ERROR, 
currentAuthResult.getUsername(), getHostAddress(context), getErrorMessage());
+    private boolean login(final AuthorityRule rule, final Grantee grantee, 
final byte[] authenticationResponse) {
+        Optional<ShardingSphereUser> user = rule.findUser(grantee);
+        return user.isPresent()
+                && new AuthenticatorFactory<>(MySQLAuthenticatorType.class, 
rule).newInstance(user.get()).authenticate(user.get(), new 
Object[]{authenticationResponse, authPluginData});
     }
     
-    private String getErrorMessage() {
-        return 0 == authResponse.length ? "NO" : "YES";
+    private boolean authorizeDatabase(final AuthorityRule rule, final Grantee 
grantee, final String databaseName) {
+        return null == databaseName || new AuthorityChecker(rule, 
grantee).isAuthorized(databaseName);
     }
     
     private String getHostAddress(final ChannelHandlerContext context) {
         SocketAddress socketAddress = context.channel().remoteAddress();
         return socketAddress instanceof InetSocketAddress ? 
((InetSocketAddress) socketAddress).getAddress().getHostAddress() : 
socketAddress.toString();
     }
+    
+    private void writeErrorPacket(final ChannelHandlerContext context, final 
MySQLErrPacket errPacket) {
+        context.writeAndFlush(errPacket);
+        context.close();
+    }
+    
+    private void writeOKPacket(final ChannelHandlerContext context) {
+        context.writeAndFlush(new MySQLOKPacket(DEFAULT_STATUS_FLAG));
+    }
 }
diff --git 
a/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationHandler.java
 
b/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationHandler.java
deleted file mode 100644
index 16a322ddb30..00000000000
--- 
a/proxy/frontend/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationHandler.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shardingsphere.proxy.frontend.mysql.authentication;
-
-import lombok.Getter;
-import org.apache.shardingsphere.authority.checker.AuthorityChecker;
-import org.apache.shardingsphere.authority.rule.AuthorityRule;
-import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthPluginData;
-import org.apache.shardingsphere.dialect.mysql.vendor.MySQLVendorError;
-import org.apache.shardingsphere.infra.metadata.user.Grantee;
-import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
-import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
-import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFactory;
-import 
org.apache.shardingsphere.proxy.frontend.mysql.authentication.authenticator.MySQLAuthenticatorType;
-
-import java.util.Optional;
-
-/**
- * Authentication handler for MySQL.
- */
-@Getter
-public final class MySQLAuthenticationHandler {
-    
-    private final MySQLAuthPluginData authPluginData = new 
MySQLAuthPluginData();
-    
-    /**
-     * Login.
-     *
-     * @param username username
-     * @param hostname hostname
-     * @param authenticationResponse authentication response
-     * @param databaseName database name
-     * @return login success or failure
-     */
-    public Optional<MySQLVendorError> login(final String username, final 
String hostname, final byte[] authenticationResponse, final String 
databaseName) {
-        AuthorityRule rule = 
ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(AuthorityRule.class);
-        Grantee grantee = new Grantee(username, hostname);
-        Optional<ShardingSphereUser> user = rule.findUser(grantee);
-        if (!user.isPresent()
-                || !new AuthenticatorFactory<>(MySQLAuthenticatorType.class, 
rule).newInstance(user.get()).authenticate(user.get(), new 
Object[]{authenticationResponse, authPluginData})) {
-            return Optional.of(MySQLVendorError.ER_ACCESS_DENIED_ERROR);
-        }
-        return null == databaseName || new AuthorityChecker(rule, 
grantee).isAuthorized(databaseName) ? Optional.empty() : 
Optional.of(MySQLVendorError.ER_DBACCESS_DENIED_ERROR);
-    }
-}
diff --git 
a/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
 
b/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
index 83d4f2d2025..b8edbf32c20 100644
--- 
a/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
+++ 
b/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
@@ -35,20 +35,25 @@ import 
org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
 import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import 
org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
+import org.apache.shardingsphere.infra.metadata.user.Grantee;
+import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
 import org.apache.shardingsphere.mode.manager.ContextManager;
 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
 import org.apache.shardingsphere.mode.metadata.persist.MetaDataPersistService;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResultBuilder;
+import org.apache.shardingsphere.proxy.frontend.authentication.Authenticator;
+import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFactory;
 import org.apache.shardingsphere.proxy.frontend.mysql.ProxyContextRestorer;
-import org.junit.Before;
 import org.junit.Test;
+import org.mockito.MockedConstruction;
 import org.mockito.internal.configuration.plugins.Plugins;
 
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.Collections;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Properties;
@@ -57,34 +62,22 @@ import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockConstruction;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public final class MySQLAuthenticationEngineTest extends ProxyContextRestorer {
     
-    private final MySQLAuthenticationHandler authenticationHandler = 
mock(MySQLAuthenticationHandler.class);
-    
     private final MySQLAuthenticationEngine authenticationEngine = new 
MySQLAuthenticationEngine();
     
     private final byte[] authResponse = {-27, 89, -20, -27, 65, -120, -64, 
-101, 86, -100, -108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
     
-    @Before
-    public void setUp() {
-        initAuthenticationHandlerForAuthenticationEngine();
-    }
-    
-    @SneakyThrows(ReflectiveOperationException.class)
-    private void initAuthenticationHandlerForAuthenticationEngine() {
-        
Plugins.getMemberAccessor().set(MySQLAuthenticationEngine.class.getDeclaredField("authenticationHandler"),
 authenticationEngine, authenticationHandler);
-    }
-    
     @Test
     public void assertHandshake() {
-        ChannelHandlerContext context = getContext();
+        ChannelHandlerContext context = mockChannelHandlerContext();
         assertTrue(authenticationEngine.handshake(context) > 0);
         verify(context).writeAndFlush(any(MySQLHandshakePacket.class));
     }
@@ -92,7 +85,9 @@ public final class MySQLAuthenticationEngineTest extends 
ProxyContextRestorer {
     @SuppressWarnings("unchecked")
     @Test
     public void assertAuthenticationMethodMismatch() {
-        setMetaDataContexts();
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        setMetaDataContexts(rule);
         setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
         MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
         ChannelHandlerContext channelHandlerContext = 
mock(ChannelHandlerContext.class);
@@ -110,7 +105,7 @@ public final class MySQLAuthenticationEngineTest extends 
ProxyContextRestorer {
     }
     
     @Test
-    public void assertAuthSwitchResponse() {
+    public void assertAuthenticationSwitchResponse() {
         
setConnectionPhase(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH);
         MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
         Channel channel = mock(Channel.class);
@@ -119,6 +114,9 @@ public final class MySQLAuthenticationEngineTest extends 
ProxyContextRestorer {
         when(channel.remoteAddress()).thenReturn(new 
InetSocketAddress("localhost", 3307));
         when(channelHandlerContext.channel()).thenReturn(channel);
         setAuthenticationResult();
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        setMetaDataContexts(rule);
         authenticationEngine.authenticate(channelHandlerContext, payload);
         assertThat(getAuthResponse(), is(authResponse));
     }
@@ -129,53 +127,103 @@ public final class MySQLAuthenticationEngineTest extends 
ProxyContextRestorer {
     }
     
     @Test
-    public void assertAuthWithLoginFail() {
+    public void assertAuthenticateFailedWithAbsentUser() {
+        setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        when(rule.findUser(new Grantee("root", 
"127.0.0.1"))).thenReturn(Optional.empty());
+        setMetaDataContexts(rule);
+        ChannelHandlerContext context = mockChannelHandlerContext();
+        try (MockedConstruction<MySQLErrPacket> ignored = 
mockConstruction(MySQLErrPacket.class, (mock, mockContext) -> 
assertAuthenticationErrorPacket(mockContext.arguments()))) {
+            authenticationEngine.authenticate(context, getPayload("root", 
"sharding_db", authResponse));
+            verify(context).writeAndFlush(any(MySQLErrPacket.class));
+            verify(context).close();
+        }
+    }
+    
+    @SuppressWarnings({"rawtypes", "unused"})
+    @Test
+    public void assertAuthenticateFailedWithUnAuthenticatedUser() {
         setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
-        ChannelHandlerContext context = getContext();
-        setMetaDataContexts();
-        when(authenticationHandler.login(anyString(), any(), any(), 
anyString())).thenReturn(Optional.of(MySQLVendorError.ER_ACCESS_DENIED_ERROR));
-        authenticationEngine.authenticate(context, getPayload("root", 
"sharding_db", authResponse));
-        verify(context).writeAndFlush(any(MySQLErrPacket.class));
-        verify(context).close();
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        ShardingSphereUser user = new ShardingSphereUser("root", "", 
"127.0.0.1");
+        when(rule.findUser(user.getGrantee())).thenReturn(Optional.of(user));
+        setMetaDataContexts(rule);
+        ChannelHandlerContext context = mockChannelHandlerContext();
+        try (
+                MockedConstruction<AuthenticatorFactory> 
mockedAuthenticatorFactory = mockConstruction(AuthenticatorFactory.class,
+                        (mock, mockContext) -> 
when(mock.newInstance(user)).thenReturn(mock(Authenticator.class))); 
+                MockedConstruction<MySQLErrPacket> mockedErrPacket = 
mockConstruction(MySQLErrPacket.class, (mock, mockContext) -> 
assertAuthenticationErrorPacket(mockContext.arguments()))
+        ) {
+            authenticationEngine.authenticate(context, getPayload("root", 
"sharding_db", authResponse));
+            verify(context).writeAndFlush(any(MySQLErrPacket.class));
+            verify(context).close();
+        }
+    }
+    
+    private void assertAuthenticationErrorPacket(final List<?> arguments) {
+        assertThat(arguments.get(0), 
is(MySQLVendorError.ER_ACCESS_DENIED_ERROR));
+        assertThat(arguments.get(1), is(new String[] {"root", "127.0.0.1", 
"YES"}));
     }
     
     @Test
-    public void assertAuthWithDatabaseAccessDenied() {
+    public void assertAuthenticateFailedWithDatabaseAccessDenied() {
         setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
-        ChannelHandlerContext context = getContext();
-        setMetaDataContexts();
-        when(authenticationHandler.login(anyString(), any(), any(), 
anyString())).thenReturn(Optional.of(MySQLVendorError.ER_DBACCESS_DENIED_ERROR));
-        authenticationEngine.authenticate(context, getPayload("root", 
"sharding_db", authResponse));
-        verify(context).writeAndFlush(any(MySQLErrPacket.class));
-        verify(context).close();
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        ShardingSphereUser user = new ShardingSphereUser("root", "", 
"127.0.0.1");
+        when(rule.findUser(user.getGrantee())).thenReturn(Optional.of(user));
+        setMetaDataContexts(rule);
+        ChannelHandlerContext context = mockChannelHandlerContext();
+        try (MockedConstruction<MySQLErrPacket> ignored = 
mockConstruction(MySQLErrPacket.class, (mock, mockContext) -> 
assertDatabaseAccessDeniedErrorPacket(mockContext.arguments()))) {
+            authenticationEngine.authenticate(context, getPayload("root", 
"sharding_db", authResponse));
+            verify(context).writeAndFlush(any(MySQLErrPacket.class));
+            verify(context).close();
+        }
+    }
+    
+    private void assertDatabaseAccessDeniedErrorPacket(final List<?> 
arguments) {
+        assertThat(arguments.get(0), 
is(MySQLVendorError.ER_DBACCESS_DENIED_ERROR));
+        assertThat(arguments.get(1), is(new String[] {"root", "127.0.0.1", 
"sharding_db"}));
     }
     
     @Test
-    public void assertAuthWithAbsentDatabase() {
-        ChannelHandlerContext context = getContext();
-        setMetaDataContexts();
+    public void assertAuthenticateFailedWithInvalidDatabase() {
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        setMetaDataContexts(rule);
         setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
-        authenticationEngine.authenticate(context, getPayload("root", "ABSENT 
DATABASE", authResponse));
-        verify(context).writeAndFlush(any(MySQLErrPacket.class));
-        verify(context).close();
+        ChannelHandlerContext context = mockChannelHandlerContext();
+        try (MockedConstruction<MySQLErrPacket> ignored = 
mockConstruction(MySQLErrPacket.class, (mock, mockContext) -> 
assertInvalidDatabaseErrorPacket(mockContext.arguments()))) {
+            authenticationEngine.authenticate(context, getPayload("root", 
"invalid_db", authResponse));
+            verify(context).writeAndFlush(any(MySQLErrPacket.class));
+            verify(context).close();
+        }
+    }
+    
+    private void assertInvalidDatabaseErrorPacket(final List<?> arguments) {
+        assertThat(arguments.get(0), is(MySQLVendorError.ER_BAD_DB_ERROR));
+        assertThat(arguments.get(1), is(new String[] {"invalid_db"}));
     }
     
     @Test
-    public void assertAuth() {
+    public void assertAuthenticateSuccess() {
         setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
-        ChannelHandlerContext context = getContext();
-        when(authenticationHandler.login(anyString(), any(), any(), 
anyString())).thenReturn(Optional.empty());
-        setMetaDataContexts();
-        authenticationEngine.authenticate(context, getPayload("root", 
"sharding_db", authResponse));
+        AuthorityRule rule = mock(AuthorityRule.class);
+        when(rule.getAuthenticatorType(any())).thenReturn("");
+        ShardingSphereUser user = new ShardingSphereUser("root", "", 
"127.0.0.1");
+        when(rule.findUser(user.getGrantee())).thenReturn(Optional.of(user));
+        setMetaDataContexts(rule);
+        ChannelHandlerContext context = mockChannelHandlerContext();
+        authenticationEngine.authenticate(context, getPayload("root", null, 
authResponse));
         verify(context).writeAndFlush(any(MySQLOKPacket.class));
     }
     
-    private void setMetaDataContexts() {
+    private void setMetaDataContexts(final AuthorityRule rule) {
         ContextManager contextManager = mock(ContextManager.class, 
RETURNS_DEEP_STUBS);
         Map<String, ShardingSphereDatabase> databases = new LinkedHashMap<>(1, 
1);
         databases.put("sharding_db", mock(ShardingSphereDatabase.class));
-        AuthorityRule rule = mock(AuthorityRule.class);
-        when(rule.getAuthenticatorType(any())).thenReturn("");
         MetaDataContexts metaDataContexts = new 
MetaDataContexts(mock(MetaDataPersistService.class), new 
ShardingSphereMetaData(databases,
                 new ShardingSphereRuleMetaData(Collections.singleton(rule)), 
new ConfigurationProperties(new Properties())));
         
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
@@ -191,7 +239,7 @@ public final class MySQLAuthenticationEngineTest extends 
ProxyContextRestorer {
         return result;
     }
     
-    private ChannelHandlerContext getContext() {
+    private ChannelHandlerContext mockChannelHandlerContext() {
         ChannelHandlerContext result = mock(ChannelHandlerContext.class);
         doReturn(getChannel()).when(result).channel();
         return result;
diff --git 
a/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationHandlerTest.java
 
b/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationHandlerTest.java
deleted file mode 100644
index 6c53aa23d14..00000000000
--- 
a/proxy/frontend/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationHandlerTest.java
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shardingsphere.proxy.frontend.mysql.authentication;
-
-import com.google.common.primitives.Bytes;
-import lombok.SneakyThrows;
-import org.apache.shardingsphere.authority.config.AuthorityRuleConfiguration;
-import org.apache.shardingsphere.authority.model.AuthorityRegistry;
-import org.apache.shardingsphere.authority.rule.AuthorityRule;
-import org.apache.shardingsphere.authority.rule.builder.AuthorityRuleBuilder;
-import 
org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthPluginData;
-import org.apache.shardingsphere.dialect.mysql.vendor.MySQLVendorError;
-import org.apache.shardingsphere.infra.config.algorithm.AlgorithmConfiguration;
-import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
-import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
-import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
-import 
org.apache.shardingsphere.infra.metadata.database.rule.ShardingSphereRuleMetaData;
-import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
-import org.apache.shardingsphere.mode.manager.ContextManager;
-import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.mode.metadata.persist.MetaDataPersistService;
-import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
-import org.apache.shardingsphere.proxy.frontend.mysql.ProxyContextRestorer;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.internal.configuration.plugins.Plugins;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Properties;
-
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertFalse;
-import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-public final class MySQLAuthenticationHandlerTest extends ProxyContextRestorer 
{
-    
-    private static final String SCHEMA_PATTERN = "db%s";
-    
-    private final MySQLAuthenticationHandler authenticationHandler = new 
MySQLAuthenticationHandler();
-    
-    private final byte[] part1 = {84, 85, 115, 77, 68, 116, 85, 78};
-    
-    private final byte[] part2 = {83, 121, 75, 81, 87, 56, 120, 112, 73, 109, 
77, 69};
-    
-    @Before
-    public void setUp() {
-        initAuthPluginDataForAuthenticationHandler();
-    }
-    
-    @SneakyThrows(ReflectiveOperationException.class)
-    private void initAuthPluginDataForAuthenticationHandler() {
-        MySQLAuthPluginData authPluginData = new MySQLAuthPluginData(part1, 
part2);
-        
Plugins.getMemberAccessor().set(MySQLAuthenticationHandler.class.getDeclaredField("authPluginData"),
 authenticationHandler, authPluginData);
-    }
-    
-    @Test
-    public void assertLoginWithPassword() {
-        initProxyContext(new ShardingSphereUser("root", "root", ""), true);
-        byte[] authResponse = {-27, 89, -20, -27, 65, -120, -64, -101, 86, 
-100, -108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
-        assertFalse(authenticationHandler.login("root", "", authResponse, 
"db1").isPresent());
-    }
-    
-    @Test
-    public void assertLoginWithAbsentUser() {
-        initProxyContext(new ShardingSphereUser("root", "root", ""), true);
-        byte[] authResponse = {-27, 89, -20, -27, 65, -120, -64, -101, 86, 
-100, -108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
-        assertThat(authenticationHandler.login("root1", "", authResponse, 
"db1").orElse(null), is(MySQLVendorError.ER_ACCESS_DENIED_ERROR));
-    }
-    
-    @Test
-    public void assertLoginWithIncorrectPassword() {
-        initProxyContext(new ShardingSphereUser("root", "root", ""), true);
-        byte[] authResponse = {0, 89, -20, -27, 65, -120, -64, -101, 86, -100, 
-108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
-        assertThat(authenticationHandler.login("root", "", authResponse, 
"db1").orElse(null), is(MySQLVendorError.ER_ACCESS_DENIED_ERROR));
-    }
-    
-    @Test
-    public void assertLoginWithoutPassword() {
-        initProxyContext(new ShardingSphereUser("root", null, ""), true);
-        byte[] authResponse = {};
-        assertFalse(authenticationHandler.login("root", "", authResponse, 
"db1").isPresent());
-    }
-    
-    @Test
-    public void assertLoginWithUnauthorizedSchema() {
-        initProxyContext(new ShardingSphereUser("root", "root", ""), false);
-        byte[] authResponse = {-27, 89, -20, -27, 65, -120, -64, -101, 86, 
-100, -108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
-        assertThat(authenticationHandler.login("root", "", authResponse, 
"db11").orElse(null), is(MySQLVendorError.ER_DBACCESS_DENIED_ERROR));
-    }
-    
-    @Test
-    public void assertGetAuthPluginData() {
-        
assertThat(authenticationHandler.getAuthPluginData().getAuthenticationPluginData(),
 is(Bytes.concat(part1, part2)));
-    }
-    
-    @SneakyThrows(ReflectiveOperationException.class)
-    private void initProxyContext(final ShardingSphereUser user, final boolean 
isNeedSuper) {
-        ContextManager contextManager = mock(ContextManager.class, 
RETURNS_DEEP_STUBS);
-        MetaDataContexts metaDataContexts = getMetaDataContexts(user, 
isNeedSuper);
-        
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
-        ProxyContext.init(contextManager);
-    }
-    
-    private MetaDataContexts getMetaDataContexts(final ShardingSphereUser 
user, final boolean isNeedSuper) throws ReflectiveOperationException {
-        return new MetaDataContexts(mock(MetaDataPersistService.class),
-                new ShardingSphereMetaData(getDatabases(), 
buildGlobalRuleMetaData(user, isNeedSuper), new ConfigurationProperties(new 
Properties())));
-    }
-    
-    private Map<String, ShardingSphereDatabase> getDatabases() {
-        Map<String, ShardingSphereDatabase> result = new HashMap<>(10, 1);
-        for (int i = 0; i < 10; i++) {
-            ShardingSphereDatabase database = 
mock(ShardingSphereDatabase.class);
-            when(database.getRuleMetaData()).thenReturn(new 
ShardingSphereRuleMetaData(Collections.emptyList()));
-            result.put(String.format(SCHEMA_PATTERN, i), database);
-        }
-        return result;
-    }
-    
-    private ShardingSphereRuleMetaData buildGlobalRuleMetaData(final 
ShardingSphereUser user, final boolean isNeedSuper) throws 
ReflectiveOperationException {
-        AuthorityRuleConfiguration ruleConfig = new 
AuthorityRuleConfiguration(Collections.singletonList(user), new 
AlgorithmConfiguration("ALL_PERMITTED", new Properties()), null);
-        AuthorityRule rule = new AuthorityRuleBuilder().build(ruleConfig, 
Collections.emptyMap(), mock(ConfigurationProperties.class));
-        if (!isNeedSuper) {
-            AuthorityRegistry authorityRegistry = 
mock(AuthorityRegistry.class);
-            
when(authorityRegistry.findPrivileges(user.getGrantee())).thenReturn(Optional.empty());
-            
Plugins.getMemberAccessor().set(AuthorityRule.class.getDeclaredField("authorityRegistry"),
 rule, authorityRegistry);
-        }
-        return new ShardingSphereRuleMetaData(Collections.singletonList(rule));
-    }
-}


Reply via email to