This is an automated email from the ASF dual-hosted git repository.
zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 8bc38e3ca3c Add more test cases on MySQLAuthenticationEngineTest (#37910)
8bc38e3ca3c is described below
commit 8bc38e3ca3cf84e8fe0918c5b59517e74d4bf97b
Author: Liang Zhang <[email protected]>
AuthorDate: Fri Jan 30 20:51:36 2026 +0800
Add more test cases on MySQLAuthenticationEngineTest (#37910)
* Refactor MySQLAuthenticationEngine
* Add more test cases on MySQLAuthenticationEngineTest
* Add more test cases on MySQLAuthenticationEngineTest
---
.../MySQLAuthenticationEngineTest.java | 77 +++++++++++++++++++++-
1 file changed, 76 insertions(+), 1 deletion(-)
diff --git a/proxy/frontend/dialect/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java b/proxy/frontend/dialect/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
index 3fe2dc9b0d7..ee9ec009991 100644
--- a/proxy/frontend/dialect/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
+++ b/proxy/frontend/dialect/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngineTest.java
@@ -21,11 +21,14 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
+import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.util.Attribute;
import lombok.SneakyThrows;
import org.apache.shardingsphere.authentication.Authenticator;
import org.apache.shardingsphere.authentication.AuthenticatorFactory;
+import org.apache.shardingsphere.authentication.result.AuthenticationResult;
import
org.apache.shardingsphere.authentication.result.AuthenticationResultBuilder;
+import org.apache.shardingsphere.authority.checker.AuthorityChecker;
import org.apache.shardingsphere.authority.model.ShardingSpherePrivileges;
import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
@@ -35,6 +38,7 @@ import
org.apache.shardingsphere.database.exception.mysql.exception.DatabaseAcce
import
org.apache.shardingsphere.database.exception.mysql.exception.HandshakeException;
import
org.apache.shardingsphere.database.exception.mysql.vendor.MySQLVendorError;
import org.apache.shardingsphere.database.protocol.constant.CommonConstants;
+import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLAuthenticationMethod;
import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCapabilityFlag;
import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConnectionPhase;
import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConstants;
@@ -74,6 +78,7 @@ import java.util.Optional;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
@@ -155,6 +160,25 @@ class MySQLAuthenticationEngineTest {
assertThat(getConnectionPhase(),
is(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH));
}
+ @Test
+ void assertAuthenticationMethodMismatchWithEmptyAuthResponse() {
+ setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
+ AuthorityRule rule = mock(AuthorityRule.class);
+ when(rule.getAuthenticatorType(any())).thenReturn("");
+ when(rule.findUser(any(Grantee.class))).thenReturn(Optional.empty());
+ ContextManager contextManager = mockContextManager(rule);
+
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+ ChannelHandlerContext context = mockChannelHandlerContext();
+ MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
+ when(payload.readInt4()).thenReturn(0);
+ when(payload.readInt1()).thenReturn(1);
+ when(payload.readStringNul()).thenReturn("root");
+ when(payload.readStringNulByBytes()).thenReturn(new byte[0]);
+ AuthenticationResult actual =
authenticationEngine.authenticate(context, payload);
+ assertFalse(actual.isFinished());
+ assertThat(getConnectionPhase(),
is(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH));
+ }
+
@Test
void assertAuthenticationSwitchResponse() {
setConnectionPhase(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH);
@@ -173,7 +197,14 @@ class MySQLAuthenticationEngineTest {
when(rule.getAuthenticatorType(any())).thenReturn("");
ContextManager contextManager = mockContextManager(rule);
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
- authenticationEngine.authenticate(channelHandlerContext, payload);
+ Authenticator authenticator = mock(Authenticator.class);
+ when(authenticator.authenticate(eq(user),
any(Object[].class))).thenReturn(true);
+
when(authenticator.getAuthenticationMethodName()).thenReturn(MySQLAuthenticationMethod.NATIVE.getMethodName());
+ try (
+ MockedConstruction<AuthenticatorFactory> ignored =
mockConstruction(AuthenticatorFactory.class,
+ (mock, mockContext) ->
when(mock.newInstance(user)).thenReturn(authenticator))) {
+ authenticationEngine.authenticate(channelHandlerContext, payload);
+ }
assertThat(getAuthResponse(), is(authResponse));
}
@@ -278,6 +309,33 @@ class MySQLAuthenticationEngineTest {
verify(context).writeAndFlush(any(MySQLOKPacket.class));
}
+ @Test
+ void assertAuthenticateWithMismatchedPhaseOnDomainSocket() {
+
setConnectionPhase(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH);
+ setAuthenticationResult();
+ MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
+ when(payload.readStringEOFByBytes()).thenReturn(authResponse);
+ ChannelHandlerContext context =
mockDomainSocketChannelHandlerContext();
+ AuthorityRule rule = mock(AuthorityRule.class);
+ ShardingSphereUser user = new ShardingSphereUser("root", "",
"local_host");
+ when(rule.findUser(user.getGrantee())).thenReturn(Optional.of(user));
+ when(rule.getAuthenticatorType(any())).thenReturn("");
+ ContextManager contextManager = mockContextManager(rule);
+
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+ Authenticator authenticator = mock(Authenticator.class);
+ when(authenticator.authenticate(eq(user),
any(Object[].class))).thenReturn(true);
+
when(authenticator.getAuthenticationMethodName()).thenReturn(MySQLAuthenticationMethod.NATIVE.getMethodName());
+ try (
+ MockedConstruction<AuthenticatorFactory> ignoredFactory =
mockConstruction(AuthenticatorFactory.class,
+ (mock, mockContext) ->
when(mock.newInstance(user)).thenReturn(authenticator));
+ MockedConstruction<AuthorityChecker> ignoredChecker =
mockConstruction(AuthorityChecker.class,
+ (mock, mockContext) ->
when(mock.isAuthorized("foo_db")).thenReturn(true))) {
+ AuthenticationResult actual =
authenticationEngine.authenticate(context, payload);
+ assertTrue(actual.isFinished());
+ assertThat(actual.getHostname(), is("local_host"));
+ }
+ }
+
private ContextManager mockContextManager(final AuthorityRule rule) {
ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
ShardingSphereDatabase database = new ShardingSphereDatabase("foo_db",
TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(), mock(),
Collections.emptyList());
@@ -302,6 +360,23 @@ class MySQLAuthenticationEngineTest {
return result;
}
+ @SuppressWarnings("unchecked")
+ private ChannelHandlerContext mockDomainSocketChannelHandlerContext() {
+ ChannelHandlerContext result = mock(ChannelHandlerContext.class);
+ EpollDomainSocketChannel channel =
mock(EpollDomainSocketChannel.class, RETURNS_DEEP_STUBS);
+ Channel parentChannel = mock(Channel.class);
+ SocketAddress socketAddress = mock(SocketAddress.class);
+ when(socketAddress.toString()).thenReturn("local_host");
+ when(parentChannel.localAddress()).thenReturn(socketAddress);
+ when(channel.parent()).thenReturn(parentChannel);
+
when(channel.attr(CommonConstants.CHARSET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
+
when(channel.attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
+
when(channel.attr(MySQLConstants.SEQUENCE_ID_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
+
when(channel.attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)).thenReturn(mock(Attribute.class));
+ doReturn(channel).when(result).channel();
+ return result;
+ }
+
@SuppressWarnings("unchecked")
private Channel getChannel() {
Channel result = mock(Channel.class);