This is an automated email from the ASF dual-hosted git repository.

kenhuuu pushed a commit to branch sasl-initial-authn
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit f1f39dca56964d03d3cfef0c21cd0ab7bcb4c407
Author: Ken Hu <[email protected]>
AuthorDate: Mon Apr 8 14:28:11 2024 -0700

    Fix nits in the PR for TINKERPOP-3061 regarding SASL authentication.
---
 .../gremlin/driver/simple/AbstractClient.java      | 15 +++--
 .../server/handler/SaslAuthenticationHandler.java  | 56 +++++++++---------
 .../server/GremlinServerAuthIntegrateTest.java     | 69 +++++++++++++---------
 3 files changed, 78 insertions(+), 62 deletions(-)

diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
index dfbbcef60a..c353f45c07 100644
--- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
+++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
@@ -25,9 +25,12 @@ import io.netty.channel.nio.NioEventLoopGroup;
 import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.tinkerpop.gremlin.util.message.RequestMessage;
 import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
-import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
@@ -48,7 +51,7 @@ public abstract class AbstractClient implements SimpleClient {
 
     @Override
     public void submit(final RequestMessage requestMessage, final 
Consumer<ResponseMessage> callback) throws Exception {
-        callbackResponseHandler.callback.put(requestMessage.getRequestId(), 
callback);
+        
callbackResponseHandler.callbackByRequestId.put(requestMessage.getRequestId(), 
callback);
         writeAndFlush(requestMessage);
     }
 
@@ -64,7 +67,7 @@ public abstract class AbstractClient implements SimpleClient {
     public CompletableFuture<List<ResponseMessage>> submitAsync(final 
RequestMessage requestMessage) throws Exception {
         final List<ResponseMessage> results = new ArrayList<>();
         final CompletableFuture<List<ResponseMessage>> f = new 
CompletableFuture<>();
-        callbackResponseHandler.callback.put(requestMessage.getRequestId(), 
response -> {
+        
callbackResponseHandler.callbackByRequestId.put(requestMessage.getRequestId(), 
response -> {
             if (f.isDone())
                 throw new RuntimeException("A terminating message was already 
encountered - no more messages should have been received");
 
@@ -82,11 +85,11 @@ public abstract class AbstractClient implements 
SimpleClient {
     }
 
     static class CallbackResponseHandler extends 
SimpleChannelInboundHandler<ResponseMessage> {
-        public Map<UUID, Consumer<ResponseMessage>> callback = new HashMap<>();
+        public Map<UUID, Consumer<ResponseMessage>> callbackByRequestId = new 
HashMap<>();
 
         @Override
         protected void channelRead0(final ChannelHandlerContext 
channelHandlerContext, final ResponseMessage response) throws Exception {
-            callback.get(response.getRequestId()).accept(response);
+            callbackByRequestId.get(response.getRequestId()).accept(response);
         }
     }
 }
diff --git a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
index 4e4fdd8af6..34345ecf5e 100644
--- a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
+++ b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
@@ -108,16 +108,14 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
                         authenticator.getClass().getSimpleName()), ex);
 
                 respondWithError(
-                    requestMessage,
-                    builder -> builder.statusMessage("Authenticator is not 
ready to handle requests").code(ResponseStatusCode.SERVER_ERROR),
-                    ctx);
+                        requestMessage,
+                        builder -> builder.statusMessage("Authenticator is not 
ready to handle requests").code(ResponseStatusCode.SERVER_ERROR),
+                        ctx);
             }
 
             return;
-        }
-
-        // If authentication negotiation is pending, store subsequent 
non-authentication requests for later processing
-        if (negotiator.get() != null && 
!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)) {
+        } else if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)) {
+            // If authentication negotiation is pending, store subsequent 
non-authentication requests for later processing
             deferredRequests.setIfAbsent(new 
ImmutablePair<>(LocalDateTime.now(), new ArrayList<>()));
             deferredRequests.get().getValue().add(requestMessage);
 
@@ -125,20 +123,20 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
 
             if (deferredDuration.compareTo(MAX_REQUEST_DEFERRABLE_DURATION) > 
0) {
                 respondWithError(
-                    requestMessage,
-                    builder -> builder.statusMessage("Too many unauthenticated 
requests").code(ResponseStatusCode.TOO_MANY_REQUESTS),
-                    ctx);
+                        requestMessage,
+                        builder -> builder.statusMessage("Authentication did 
not finish in the allowed duration (" + MAX_REQUEST_DEFERRABLE_DURATION + "s).")
+                                    .code(ResponseStatusCode.UNAUTHORIZED),
+                        ctx);
                 return;
             }
 
             return;
-        }
-
-        if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION) || 
!requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) {
+        } else if (!requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) {
+            // This is an authentication request that is missing a "sasl" 
argument.
             respondWithError(
-                requestMessage,
-                builder -> builder.statusMessage("Failed to 
authenticate").code(ResponseStatusCode.UNAUTHORIZED),
-                ctx);
+                    requestMessage,
+                    builder -> builder.statusMessage("Failed to 
authenticate").code(ResponseStatusCode.UNAUTHORIZED),
+                    ctx);
             return;
         }
 
@@ -146,11 +144,11 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
 
         if (!(saslObject instanceof String)) {
             respondWithError(
-                requestMessage,
-                builder -> builder
-                    .statusMessage("Incorrect type for : " + Tokens.ARGS_SASL 
+ " - base64 encoded String is expected")
-                    .code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST),
-                ctx);
+                    requestMessage,
+                    builder -> builder
+                            .statusMessage("Incorrect type for : " + 
Tokens.ARGS_SASL + " - base64 encoded String is expected")
+                            
.code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST),
+                    ctx);
             return;
         }
 
@@ -191,9 +189,9 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
             }
         } catch (AuthenticationException ae) {
             respondWithError(
-                requestMessage,
-                builder -> 
builder.statusMessage(ae.getMessage()).code(ResponseStatusCode.UNAUTHORIZED),
-                ctx);
+                    requestMessage,
+                    builder -> 
builder.statusMessage(ae.getMessage()).code(ResponseStatusCode.UNAUTHORIZED),
+                    ctx);
         }
     }
 
@@ -211,11 +209,11 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
 
         if (deferredRequests.get() != null) {
             deferredRequests
-                .getAndSet(null).getValue().stream()
-                .map(ResponseMessage::build)
-                .map(buildResponse)
-                .map(ResponseMessage.Builder::create)
-                .forEach(ctx::write);
+                    .getAndSet(null).getValue().stream()
+                    .map(ResponseMessage::build)
+                    .map(buildResponse)
+                    .map(ResponseMessage.Builder::create)
+                    .forEach(ctx::write);
         }
 
         ctx.flush();
diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
index 15dcd041f7..01367ec49c 100644
--- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
+++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
@@ -30,6 +30,8 @@ import 
org.apache.tinkerpop.gremlin.server.handler.SaslAuthenticationHandler;
 import org.apache.tinkerpop.gremlin.structure.Property;
 import org.apache.tinkerpop.gremlin.structure.Vertex;
 import org.apache.tinkerpop.gremlin.util.ExceptionHelper;
+import org.apache.tinkerpop.gremlin.util.Tokens;
+import org.apache.tinkerpop.gremlin.util.message.RequestMessage;
 import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
 import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
 import org.apache.tinkerpop.gremlin.util.ser.Serializers;
@@ -43,6 +45,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.AnyOf.anyOf;
@@ -176,13 +179,13 @@ public class GremlinServerAuthIntegrateTest extends 
AbstractGremlinServerIntegra
         try (WebSocketClient client = 
TestClientFactory.createWebSocketClient()) {
             // First request will initiate the authentication handshake
             // Subsequent requests will be deferred
-            CompletableFuture<List<ResponseMessage>> future1 = 
client.submitAsync("");
-            CompletableFuture<List<ResponseMessage>> future2 = 
client.submitAsync("");
-            CompletableFuture<List<ResponseMessage>> future3 = 
client.submitAsync("");
+            CompletableFuture<List<ResponseMessage>> 
futureOfRequestWithinAuthDuration1  = client.submitAsync("");
+            CompletableFuture<List<ResponseMessage>> 
futureOfRequestWithinAuthDuration2  = client.submitAsync("");
+            CompletableFuture<List<ResponseMessage>> 
futureOfRequestWithinAuthDuration3  = client.submitAsync("");
 
             // After the maximum allowed deferred request duration,
             // any non-authenticated request will invalidate all requests with 
429 error
-            CompletableFuture<List<ResponseMessage>> future4 = 
CompletableFuture.runAsync(() -> {
+            CompletableFuture<List<ResponseMessage>> 
futureOfRequestSubmittedTooLate = CompletableFuture.runAsync(() -> {
                 try {
                     
Thread.sleep(SaslAuthenticationHandler.MAX_REQUEST_DEFERRABLE_DURATION.plus(Duration.ofSeconds(1)).toMillis());
                 } catch (InterruptedException e) {
@@ -196,18 +199,40 @@ public class GremlinServerAuthIntegrateTest extends 
AbstractGremlinServerIntegra
                 }
             });
 
-            List<ResponseMessage> responses = new ArrayList<>();
+            assertEquals(2, futureOfRequestWithinAuthDuration1.get().size());
+            assertEquals(1, futureOfRequestWithinAuthDuration2.get().size());
+            assertEquals(1, futureOfRequestWithinAuthDuration3.get().size());
+            assertEquals(1, futureOfRequestSubmittedTooLate.get().size());
 
-            responses.addAll(future1.join());
-            responses.addAll(future2.join());
-            responses.addAll(future3.join());
-            responses.addAll(future4.join());
+            assertEquals(ResponseStatusCode.AUTHENTICATE, 
futureOfRequestWithinAuthDuration1.get().get(0).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
futureOfRequestWithinAuthDuration1.get().get(1).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
futureOfRequestWithinAuthDuration2.get().get(0).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
futureOfRequestWithinAuthDuration3.get().get(0).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
futureOfRequestSubmittedTooLate.get().get(0).getStatus().getCode());
+        }
+    }
 
-            for (ResponseMessage response : responses) {
-                if (response.getStatus().getCode() != 
ResponseStatusCode.AUTHENTICATE) {
-                    assertEquals(ResponseStatusCode.TOO_MANY_REQUESTS, 
response.getStatus().getCode());
-                }
-            }
+    @Test
+    public void shouldFailAuthenticateWithIncorrectParallelRequests() throws 
Exception {
+        try (WebSocketClient client = 
TestClientFactory.createWebSocketClient()) {
+
+            CompletableFuture<List<ResponseMessage>> firstRequest = 
client.submitAsync("1");
+            CompletableFuture<List<ResponseMessage>> secondRequest  = 
client.submitAsync("2");
+            CompletableFuture<List<ResponseMessage>> thirdRequest  = 
client.submitAsync("3");
+
+            Thread.sleep(500);
+
+            // send some incorrect value for username password which should 
cause all requests to fail.
+            
client.submitAsync(RequestMessage.build(Tokens.OPS_AUTHENTICATION).addArg(Tokens.ARGS_SASL,
 "someincorrectvalue").create());
+
+            assertEquals(2, firstRequest.get().size());
+            assertEquals(1, secondRequest.get().size());
+            assertEquals(1, thirdRequest.get().size());
+
+            assertEquals(ResponseStatusCode.AUTHENTICATE, 
firstRequest.get().get(0).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
firstRequest.get().get(1).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
secondRequest.get().get(0).getStatus().getCode());
+            assertEquals(ResponseStatusCode.UNAUTHORIZED, 
thirdRequest.get().get(0).getStatus().getCode());
         }
     }
 
@@ -266,19 +291,9 @@ public class GremlinServerAuthIntegrateTest extends 
AbstractGremlinServerIntegra
 
     private static void assertConnection(final Cluster cluster, final Client 
client) throws InterruptedException, ExecutionException {
         try {
-            CompletableFuture<List<Result>> future1 = 
client.submitAsync("1+1").thenCompose(ResultSet::all);
-            CompletableFuture<List<Result>> future2 = 
client.submitAsync("1+2").thenCompose(ResultSet::all);
-            CompletableFuture<List<Result>> future3 = 
client.submitAsync("1+3").thenCompose(ResultSet::all);
-            CompletableFuture<List<Result>> future4 = 
client.submitAsync("1+4").thenCompose(ResultSet::all);
-            CompletableFuture<List<Result>> future5 = 
client.submitAsync("1+5").thenCompose(ResultSet::all);
-            CompletableFuture<List<Result>> future6 = 
client.submitAsync("1+6").thenCompose(ResultSet::all);
-
-            assertEquals(2, future1.join().get(0).getInt());
-            assertEquals(3, future2.join().get(0).getInt());
-            assertEquals(4, future3.join().get(0).getInt());
-            assertEquals(5, future4.join().get(0).getInt());
-            assertEquals(6, future5.join().get(0).getInt());
-            assertEquals(7, future6.join().get(0).getInt());
+            assertEquals(2, client.submit("1+1").all().get().get(0).getInt());
+            assertEquals(3, client.submit("1+2").all().get().get(0).getInt());
+            assertEquals(4, client.submit("1+3").all().get().get(0).getInt());
         } finally {
             cluster.close();
         }

Reply via email to