This is an automated email from the ASF dual-hosted git repository.

manikumar pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new b4e1deb43a7 MINOR: Few cleanups
b4e1deb43a7 is described below

commit b4e1deb43a75ca84262d877c5f47bbf2b0dbc6c4
Author: Vikas Singh <[email protected]>
AuthorDate: Mon Sep 9 18:43:33 2024 +0530

    MINOR: Few cleanups
    
    Reviewers: Manikumar Reddy <[email protected]>
---
 .../security/scram/internals/ScramSaslServer.java  |  6 +-
 .../scram/internals/ScramSaslServerTest.java       | 68 ++++++++++++++++++++++
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
index cea3ddf71fd..e3a300f9a7b 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
@@ -149,6 +149,9 @@ public class ScramSaslServer implements SaslServer {
                 case RECEIVE_CLIENT_FINAL_MESSAGE:
                     try {
                         ClientFinalMessage clientFinalMessage = new ClientFinalMessage(response);
+                        if (!clientFinalMessage.nonce().equals(serverFirstMessage.nonce())) { // RFC 5802 s5.1: must match exactly
+                            throw new SaslException("Invalid client nonce in the final client message.");
+                        }
                         verifyClientProof(clientFinalMessage);
                         byte[] serverKey = scramCredential.serverKey();
                         byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
@@ -222,7 +225,8 @@ public class ScramSaslServer implements SaslServer {
         this.state = state;
     }
 
-    private void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException {
+    // Visible for testing
+    void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException {
         try {
             byte[] expectedStoredKey = scramCredential.storedKey();
             byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java
index 1393b26f87a..94b95b0cfdf 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java
@@ -20,14 +20,23 @@ package org.apache.kafka.common.security.scram.internals;
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
 import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage;
 import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
 
 import java.nio.charset.StandardCharsets;
+import java.util.Base64;
 import java.util.HashMap;
 
+import javax.security.sasl.SaslException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
@@ -67,10 +76,69 @@ public class ScramSaslServerTest {
         assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_B)));
     }
 
+    /**
+     * Validate that the server's first message carries a nonce prefixed by the client's nonce,
+     * and that a client final message echoing the server's full nonce is accepted.
+     * <br>
+     * verifyClientProof is stubbed because the helper builds the final message with a
+     * random proof; this test only exercises the nonce exchange.
+     */
+    @Test
+    public void validateNonceExchange() throws SaslException {
+        ScramSaslServer spySaslServer = Mockito.spy(saslServer);
+        byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A);
+        ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes);
+
+        byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
+        ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes);
+        assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()),
+            "Nonce in server message should start with client first message's nonce");
+
+        byte[] clientFinalMsgBytes = clientFinalMessage(serverFirstMessage.nonce());
+        Mockito.doNothing()
+            .when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class));
+        byte[] serverFinalMsgBytes = spySaslServer.evaluateResponse(clientFinalMsgBytes);
+        ServerFinalMessage serverFinalMessage = new ServerFinalMessage(serverFinalMsgBytes);
+        assertNull(serverFinalMessage.error(), "Server final message should not contain error");
+    }
+
+    @Test
+    public void validateFailedNonceExchange() throws SaslException {
+        ScramSaslServer spySaslServer = Mockito.spy(saslServer);
+        byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A);
+        ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes);
+
+        byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
+        ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes);
+        assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()),
+            "Nonce in server message should start with client first message's nonce");
+
+        // An unrelated random nonce must be rejected; the proof check is stubbed so only the nonce check can fail.
+        byte[] clientFinalMsgBytes = clientFinalMessage(formatter.secureRandomString());
+        Mockito.doNothing()
+            .when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class));
+        SaslException saslException = assertThrows(SaslException.class,
+            () -> spySaslServer.evaluateResponse(clientFinalMsgBytes));
+        assertEquals("Invalid client nonce in the final client message.",
+            saslException.getMessage());
+    }
+
     private byte[] clientFirstMessage(String userName, String authorizationId) {
         String nonce = formatter.secureRandomString();
         String authorizationField = authorizationId != null ? "a=" + authorizationId : "";
         String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce);
         return firstMessage.getBytes(StandardCharsets.UTF_8);
     }
+
+    private byte[] clientFinalMessage(String nonce) {
+        // Random channel binding and proof: callers stub verifyClientProof, so the proof is never checked.
+        String channelBinding = randomBytesAsString();
+        String proof = randomBytesAsString();
+        String message = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof);
+        return message.getBytes(StandardCharsets.UTF_8);
+    }
+
+    private String randomBytesAsString() {
+        return Base64.getEncoder().encodeToString(formatter.secureRandomBytes());
+    }
 }

Reply via email to