This is an automated email from the ASF dual-hosted git repository.

kenhuuu pushed a commit to branch sasl-initial-authn
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit 543eee70b32bf5778f31cc8138caee0165540703
Author: Tiến Nguyễn Khắc <[email protected]>
AuthorDate: Sun Mar 17 00:00:07 2024 +1300

    fix: failing authentication when multiple initial requests are executed concurrently
---
 CHANGELOG.asciidoc                                 |   1 +
 .../gremlin/driver/simple/AbstractClient.java      |  13 +-
 .../test/integration/sasl-authentication-tests.js  |  13 ++
 .../server/handler/SaslAuthenticationHandler.java  | 259 +++++++++++++--------
 .../tinkerpop/gremlin/server/handler/StateKey.java |   9 +
 .../server/GremlinServerAuthIntegrateTest.java     |  72 +++++-
 6 files changed, 252 insertions(+), 115 deletions(-)

diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc
index ec23b0fc92..652a0d7215 100644
--- a/CHANGELOG.asciidoc
+++ b/CHANGELOG.asciidoc
@@ -25,6 +25,7 @@ 
image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
 
 * Deprecated `ltrim()` and `rTrim()` in favor of `l_trim()` and `r_trim` in 
Python.
 * Fixed bug in `onCreate` for `mergeV()` where use of the `Cardinality` 
functions was not properly handled.
+* Fixed bug where multiple concurrent initial requests caused SASL authentication to fail.
 
 [[release-3-7-1]]
 === TinkerPop 3.7.1 (November 20, 2023)
diff --git 
a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
 
b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
index 00a67b42eb..dfbbcef60a 100644
--- 
a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
+++ 
b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java
@@ -27,8 +27,7 @@ import 
org.apache.tinkerpop.gremlin.util.message.RequestMessage;
 import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
 import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
 
-import java.util.ArrayList;
-import java.util.List;
+import java.util.*;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
@@ -49,7 +48,7 @@ public abstract class AbstractClient implements SimpleClient {
 
     @Override
     public void submit(final RequestMessage requestMessage, final 
Consumer<ResponseMessage> callback) throws Exception {
-        callbackResponseHandler.callback = callback;
+        callbackResponseHandler.callback.put(requestMessage.getRequestId(), 
callback);
         writeAndFlush(requestMessage);
     }
 
@@ -65,7 +64,7 @@ public abstract class AbstractClient implements SimpleClient {
     public CompletableFuture<List<ResponseMessage>> submitAsync(final 
RequestMessage requestMessage) throws Exception {
         final List<ResponseMessage> results = new ArrayList<>();
         final CompletableFuture<List<ResponseMessage>> f = new 
CompletableFuture<>();
-        callbackResponseHandler.callback = response -> {
+        callbackResponseHandler.callback.put(requestMessage.getRequestId(), 
response -> {
             if (f.isDone())
                 throw new RuntimeException("A terminating message was already 
encountered - no more messages should have been received");
 
@@ -75,7 +74,7 @@ public abstract class AbstractClient implements SimpleClient {
             if (response.getStatus().getCode().isFinalResponse()) {
                 f.complete(results);
             }
-        };
+        });
 
         writeAndFlush(requestMessage);
 
@@ -83,11 +82,11 @@ public abstract class AbstractClient implements 
SimpleClient {
     }
 
     static class CallbackResponseHandler extends SimpleChannelInboundHandler<ResponseMessage> {
-        public Consumer<ResponseMessage> callback;
+        /**
+         * Response callbacks keyed by request id. Entries are added by the thread calling submit/submitAsync and
+         * are read and removed in channelRead0 on the channel's event loop, so the map must be thread-safe.
+         */
+        public Map<UUID, Consumer<ResponseMessage>> callback = new java.util.concurrent.ConcurrentHashMap<>();
 
         @Override
         protected void channelRead0(final ChannelHandlerContext channelHandlerContext, final ResponseMessage response) throws Exception {
-            callback.accept(response);
+            final UUID requestId = response.getRequestId();
+
+            // ConcurrentHashMap rejects null keys - guard in case a response arrives without a request id
+            final Consumer<ResponseMessage> consumer = null == requestId ? null : callback.get(requestId);
+            if (null == consumer)
+                throw new IllegalStateException(String.format("No callback registered for response to request %s", requestId));
+
+            consumer.accept(response);
+
+            // the request is finished - drop its callback so the map does not grow for the lifetime of the client
+            if (response.getStatus().getCode().isFinalResponse())
+                callback.remove(requestId);
         }
     }
 }
diff --git 
a/gremlin-javascript/src/main/javascript/gremlin-javascript/test/integration/sasl-authentication-tests.js
 
b/gremlin-javascript/src/main/javascript/gremlin-javascript/test/integration/sasl-authentication-tests.js
index 1fa5850cc2..3d2b937666 100644
--- 
a/gremlin-javascript/src/main/javascript/gremlin-javascript/test/integration/sasl-authentication-tests.js
+++ 
b/gremlin-javascript/src/main/javascript/gremlin-javascript/test/integration/sasl-authentication-tests.js
@@ -54,6 +54,19 @@ describe('DriverRemoteConnection', function () {
           });
       });
 
+      it('should be able to send multiple requests concurrently with valid 
credentials and parse the response', async function () {
+        connection = 
helper.getSecureConnectionWithPlainTextSaslAuthenticator(null, 'stephen', 
'password');
+
+        const submissions = await Promise.all(
+          Array.from({ length: 10 }).map(() => connection.submit(new 
Bytecode().addStep('V', []).addStep('tail', []))),
+        );
+
+        submissions.forEach((response) => {
+          assert.ok(response);
+          assert.ok(response.traversers);
+        });
+      });
+
       it('should send the request with invalid credentials and parse the 
response error', function () {
         connection = 
helper.getSecureConnectionWithPlainTextSaslAuthenticator(null, 'Bob', 
'password');
 
diff --git 
a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
 
b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
index 07faaa6f51..4e4fdd8af6 100644
--- 
a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
+++ 
b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java
@@ -22,40 +22,44 @@ import io.netty.channel.Channel;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.util.Attribute;
-
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.SocketAddress;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.Map;
-
-import io.netty.util.AttributeMap;
-import org.apache.tinkerpop.gremlin.util.Tokens;
-import org.apache.tinkerpop.gremlin.util.message.RequestMessage;
-import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
-import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.tinkerpop.gremlin.server.GremlinServer;
 import org.apache.tinkerpop.gremlin.server.Settings;
-import org.apache.tinkerpop.gremlin.server.auth.AuthenticatedUser;
 import org.apache.tinkerpop.gremlin.server.auth.AuthenticationException;
 import org.apache.tinkerpop.gremlin.server.auth.Authenticator;
 import org.apache.tinkerpop.gremlin.server.authz.Authorizer;
 import org.apache.tinkerpop.gremlin.server.channel.WebSocketChannelizer;
+import org.apache.tinkerpop.gremlin.util.Tokens;
+import org.apache.tinkerpop.gremlin.util.message.RequestMessage;
+import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
+import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.time.Duration;
+import java.time.LocalDateTime;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.List;
+import java.util.function.Function;
+
 /**
  * A SASL authentication handler that allows the {@link Authenticator} to be 
plugged into it. This handler is meant
  * to be used with protocols that process a {@link RequestMessage} such as the 
{@link WebSocketChannelizer}
  *
- * @author Stephen Mallette (http://stephen.genoprime.com)
+ * @author Stephen Mallette (<a 
href="http://stephen.genoprime.com";>http://stephen.genoprime.com</a>)
  */
 @ChannelHandler.Sharable
 public class SaslAuthenticationHandler extends AbstractAuthenticationHandler {
     private static final Logger logger = 
LoggerFactory.getLogger(SaslAuthenticationHandler.class);
     private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder();
     private static final Base64.Encoder BASE64_ENCODER = Base64.getEncoder();
+    public static final Duration MAX_REQUEST_DEFERRABLE_DURATION = 
Duration.ofSeconds(5);
     private static final Logger auditLogger = 
LoggerFactory.getLogger(GremlinServer.AUDIT_LOGGER_NAME);
 
     protected final Settings settings;
@@ -75,96 +79,149 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
 
     @Override
     public void channelRead(final ChannelHandlerContext ctx, final Object msg) 
throws Exception {
-        if (msg instanceof RequestMessage){
-            final RequestMessage requestMessage = (RequestMessage) msg;
-
-            final Attribute<Authenticator.SaslNegotiator> negotiator = 
((AttributeMap) ctx).attr(StateKey.NEGOTIATOR);
-            final Attribute<RequestMessage> request = ((AttributeMap) 
ctx).attr(StateKey.REQUEST_MESSAGE);
-            if (negotiator.get() == null) {
-                try {
-                    // First time through so save the request and send an 
AUTHENTICATE challenge with no data
-                    
negotiator.set(authenticator.newSaslNegotiator(getRemoteInetAddress(ctx)));
-                    request.set(requestMessage);
-                    final ResponseMessage authenticate = 
ResponseMessage.build(requestMessage)
-                            .code(ResponseStatusCode.AUTHENTICATE).create();
-                    ctx.writeAndFlush(authenticate);
-                } catch (Exception ex) {
-                    // newSaslNegotiator can cause troubles - if we don't 
catch and respond nicely the driver seems
-                    // to hang until timeout which isn't so nice. treating 
this like a server error as it means that
-                    // the Authenticator isn't really ready to deal with 
requests for some reason.
-                    logger.error(String.format("%s is not ready to handle 
requests - check its configuration or related services",
-                            authenticator.getClass().getSimpleName()), ex);
-
-                    final ResponseMessage error = 
ResponseMessage.build(requestMessage)
-                            .statusMessage("Authenticator is not ready to 
handle requests")
-                            .code(ResponseStatusCode.SERVER_ERROR).create();
-                    ctx.writeAndFlush(error);
-                }
-            } else {
-                if (requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION) 
&& requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) {
-                    
-                    final Object saslObject = 
requestMessage.getArgs().get(Tokens.ARGS_SASL);
-                    final byte[] saslResponse;
-                    
-                    if(saslObject instanceof String) {
-                        saslResponse = BASE64_DECODER.decode((String) 
saslObject);
-                    } else {
-                        final ResponseMessage error = 
ResponseMessage.build(request.get())
-                                .statusMessage("Incorrect type for : " + 
Tokens.ARGS_SASL + " - base64 encoded String is expected")
-                                
.code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST).create();
-                        ctx.writeAndFlush(error);
-                        return;
-                    }
-
-                    try {
-                        final byte[] saslMessage = 
negotiator.get().evaluateResponse(saslResponse);
-                        if (negotiator.get().isComplete()) {
-                            final AuthenticatedUser user = 
negotiator.get().getAuthenticatedUser();
-                            
ctx.channel().attr(StateKey.AUTHENTICATED_USER).set(user);
-                            // User name logged with the remote socket address 
and authenticator classname for audit logging
-                            if (settings.enableAuditLog) {
-                                String address = 
ctx.channel().remoteAddress().toString();
-                                if (address.startsWith("/") && 
address.length() > 1) address = address.substring(1);
-                                final String[] authClassParts = 
authenticator.getClass().toString().split("[.]");
-                                auditLogger.info("User {} with address {} 
authenticated by {}",
-                                        user.getName(), address, 
authClassParts[authClassParts.length - 1]);
-                            }
-                            // If we have got here we are authenticated so 
remove the handler and pass
-                            // the original message down the pipeline for 
processing
-                            ctx.pipeline().remove(this);
-                            final RequestMessage original = request.get();
-                            ctx.fireChannelRead(original);
-                        } else {
-                            // not done here - send back the sasl message for 
next challenge.
-                            final Map<String,Object> metadata = new 
HashMap<>();
-                            metadata.put(Tokens.ARGS_SASL, 
BASE64_ENCODER.encodeToString(saslMessage));
-                            final ResponseMessage authenticate = 
ResponseMessage.build(requestMessage)
-                                    .statusAttributes(metadata)
-                                    
.code(ResponseStatusCode.AUTHENTICATE).create();
-                            ctx.writeAndFlush(authenticate);
-                        }
-                    } catch (AuthenticationException ae) {
-                        final ResponseMessage error = 
ResponseMessage.build(request.get())
-                                .statusMessage(ae.getMessage())
-                                
.code(ResponseStatusCode.UNAUTHORIZED).create();
-                        ctx.writeAndFlush(error);
-                    }
-                } else {
-                    final ResponseMessage error = 
ResponseMessage.build(requestMessage)
-                            .statusMessage("Failed to authenticate")
-                            .code(ResponseStatusCode.UNAUTHORIZED).create();
-                    ctx.writeAndFlush(error);
-                }
-            }
-        } else {
+        if (!(msg instanceof RequestMessage)) {
             logger.warn("{} only processes RequestMessage instances - received 
{} - channel closing",
                     this.getClass().getSimpleName(), msg.getClass());
             ctx.close();
+            return;
+        }
+
+        final RequestMessage requestMessage = (RequestMessage) msg;
+
+        final Attribute<Authenticator.SaslNegotiator> negotiator = 
ctx.channel().attr(StateKey.NEGOTIATOR);
+        final Attribute<RequestMessage> request = 
ctx.channel().attr(StateKey.REQUEST_MESSAGE);
+        final Attribute<Pair<LocalDateTime, List<RequestMessage>>> 
deferredRequests = ctx.channel().attr(StateKey.DEFERRED_REQUEST_MESSAGES);
+
+        if (negotiator.get() == null) {
+            try {
+                // First time through so save the request and send an 
AUTHENTICATE challenge with no data
+                
negotiator.set(authenticator.newSaslNegotiator(getRemoteInetAddress(ctx)));
+                request.set(requestMessage);
+                final ResponseMessage authenticate = 
ResponseMessage.build(requestMessage)
+                        .code(ResponseStatusCode.AUTHENTICATE).create();
+                ctx.writeAndFlush(authenticate);
+            } catch (Exception ex) {
+                // newSaslNegotiator can cause troubles - if we don't catch 
and respond nicely the driver seems
+                // to hang until timeout which isn't so nice. treating this 
like a server error as it means that
+                // the Authenticator isn't really ready to deal with requests 
for some reason.
+                logger.error(String.format("%s is not ready to handle requests 
- check its configuration or related services",
+                        authenticator.getClass().getSimpleName()), ex);
+
+                respondWithError(
+                    requestMessage,
+                    builder -> builder.statusMessage("Authenticator is not 
ready to handle requests").code(ResponseStatusCode.SERVER_ERROR),
+                    ctx);
+            }
+
+            return;
         }
+
+        // Authentication negotiation is pending (the negotiator was created by an earlier request - the branch above
+        // always returns when it is null), so hold on to any non-authentication request and replay it down the
+        // pipeline once the handshake completes.
+        if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)) {
+            deferredRequests.setIfAbsent(new ImmutablePair<>(LocalDateTime.now(), new ArrayList<>()));
+
+            final Duration deferredDuration = Duration.between(deferredRequests.get().getKey(), LocalDateTime.now());
+
+            if (deferredDuration.compareTo(MAX_REQUEST_DEFERRABLE_DURATION) > 0) {
+                // respondWithError() answers this request, the original request and every deferred request, so this
+                // request must NOT be added to the deferred list first or it would receive two error responses.
+                // NOTE(review): the window is only evaluated when another message arrives on the channel; if none
+                // does, deferred requests wait until authentication succeeds or fails. LocalDateTime is wall-clock
+                // time, so clock adjustments can stretch or shrink the window - a monotonic source would avoid that.
+                respondWithError(
+                    requestMessage,
+                    builder -> builder.statusMessage("Too many unauthenticated requests").code(ResponseStatusCode.TOO_MANY_REQUESTS),
+                    ctx);
+                return;
+            }
+
+            deferredRequests.get().getValue().add(requestMessage);
+            return;
+        }
+
+        if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION) || 
!requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) {
+            respondWithError(
+                requestMessage,
+                builder -> builder.statusMessage("Failed to 
authenticate").code(ResponseStatusCode.UNAUTHORIZED),
+                ctx);
+            return;
+        }
+
+        final Object saslObject = 
requestMessage.getArgs().get(Tokens.ARGS_SASL);
+
+        if (!(saslObject instanceof String)) {
+            respondWithError(
+                requestMessage,
+                builder -> builder
+                    .statusMessage("Incorrect type for : " + Tokens.ARGS_SASL 
+ " - base64 encoded String is expected")
+                    .code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST),
+                ctx);
+            return;
+        }
+
+        try {
+            final byte[] saslResponse = BASE64_DECODER.decode((String) 
saslObject);
+            final byte[] saslMessage = 
negotiator.get().evaluateResponse(saslResponse);
+
+            if (!negotiator.get().isComplete()) {
+                // not done here - send back the sasl message for next 
challenge.
+                final HashMap<String, Object> metadata = new HashMap<>();
+                metadata.put(Tokens.ARGS_SASL, 
BASE64_ENCODER.encodeToString(saslMessage));
+                final ResponseMessage authenticate = 
ResponseMessage.build(requestMessage)
+                        .statusAttributes(metadata)
+                        .code(ResponseStatusCode.AUTHENTICATE).create();
+                ctx.writeAndFlush(authenticate);
+                return;
+            }
+
+            final org.apache.tinkerpop.gremlin.server.auth.AuthenticatedUser 
user = negotiator.get().getAuthenticatedUser();
+            ctx.channel().attr(StateKey.AUTHENTICATED_USER).set(user);
+            // User name logged with the remote socket address and 
authenticator classname for audit logging
+            if (settings.enableAuditLog) {
+                String address = ctx.channel().remoteAddress().toString();
+                if (address.startsWith("/") && address.length() > 1) address = 
address.substring(1);
+                final String[] authClassParts = 
authenticator.getClass().toString().split("[.]");
+                auditLogger.info("User {} with address {} authenticated by {}",
+                        user.getName(), address, 
authClassParts[authClassParts.length - 1]);
+            }
+            // If we have got here we are authenticated so remove the handler 
and pass
+            // the original message down the pipeline for processing
+            ctx.pipeline().remove(this);
+            final RequestMessage original = request.get();
+            ctx.fireChannelRead(original);
+
+            // Also send deferred requests if there are any down the pipeline 
for processing
+            if (deferredRequests.get() != null) {
+                
deferredRequests.getAndSet(null).getValue().forEach(ctx::fireChannelRead);
+            }
+        } catch (AuthenticationException ae) {
+            respondWithError(
+                requestMessage,
+                builder -> 
builder.statusMessage(ae.getMessage()).code(ResponseStatusCode.UNAUTHORIZED),
+                ctx);
+        }
+    }
+
+    /**
+     * Answers every request that is waiting on this channel's authentication with the same error, then flushes once.
+     * <p>
+     * A response is written for {@code requestMessage} only when it is not an authentication request. A response is
+     * also written for the request stored under {@link StateKey#REQUEST_MESSAGE} (if present) and for every deferred
+     * request, after which the deferred list is cleared.
+     * <p>
+     * NOTE(review): the stored original request is not cleared here, so a later error on this channel answers it
+     * again; and if {@code requestMessage} is already in the deferred list it receives two responses.
+     *
+     * @param requestMessage the request currently being handled
+     * @param buildResponse applies the status message and code to each response builder
+     * @param ctx the channel handler context the responses are written to
+     */
+    private void respondWithError(final RequestMessage requestMessage, final 
Function<ResponseMessage.Builder, ResponseMessage.Builder> buildResponse, final 
ChannelHandlerContext ctx) {
+        final Attribute<RequestMessage> originalRequest = 
ctx.channel().attr(StateKey.REQUEST_MESSAGE);
+        final Attribute<Pair<LocalDateTime, List<RequestMessage>>> 
deferredRequests = ctx.channel().attr(StateKey.DEFERRED_REQUEST_MESSAGES);
+
+        // authentication requests reuse the original request's id, which is answered below
+        if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)) {
+            
ctx.write(buildResponse.apply(ResponseMessage.build(requestMessage)).create());
+        }
+
+        if (originalRequest.get() != null) {
+            
ctx.write(buildResponse.apply(ResponseMessage.build(originalRequest.get())).create());
+        }
+
+        // getAndSet(null) both drains and resets the deferred list so a new deferral window can start
+        if (deferredRequests.get() != null) {
+            deferredRequests
+                .getAndSet(null).getValue().stream()
+                .map(ResponseMessage::build)
+                .map(buildResponse)
+                .map(ResponseMessage.Builder::create)
+                .forEach(ctx::write);
+        }
+
+        // single flush for all the writes above
+        ctx.flush();
     }
 
-    private InetAddress getRemoteInetAddress(final ChannelHandlerContext ctx)
-    {
+    private InetAddress getRemoteInetAddress(final ChannelHandlerContext ctx) {
         final Channel channel = ctx.channel();
 
         if (null == channel)
@@ -172,9 +229,9 @@ public class SaslAuthenticationHandler extends 
AbstractAuthenticationHandler {
 
         final SocketAddress genericSocketAddr = channel.remoteAddress();
 
-        if (null == genericSocketAddr || !(genericSocketAddr instanceof 
InetSocketAddress))
+        if (!(genericSocketAddr instanceof InetSocketAddress))
             return null;
 
-        return ((InetSocketAddress)genericSocketAddr).getAddress();
+        return ((InetSocketAddress) genericSocketAddr).getAddress();
     }
 }
diff --git 
a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/StateKey.java
 
b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/StateKey.java
index cfb2e320ad..4f693942e1 100644
--- 
a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/StateKey.java
+++ 
b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/StateKey.java
@@ -18,6 +18,7 @@
  */
 package org.apache.tinkerpop.gremlin.server.handler;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.tinkerpop.gremlin.util.MessageSerializer;
 import org.apache.tinkerpop.gremlin.util.message.RequestMessage;
 import org.apache.tinkerpop.gremlin.server.auth.AuthenticatedUser;
@@ -25,6 +26,9 @@ import org.apache.tinkerpop.gremlin.server.auth.Authenticator;
 import org.apache.tinkerpop.gremlin.server.op.session.Session;
 import io.netty.util.AttributeKey;
 
+import java.time.LocalDateTime;
+import java.util.List;
+
 /**
  * Keys used in the various handlers to store state in the pipeline.
  *
@@ -59,6 +63,11 @@ public final class StateKey {
      */
     public static final AttributeKey<RequestMessage> REQUEST_MESSAGE = 
AttributeKey.valueOf("request");
 
+    /**
+     * The key for requests received while SASL authentication is still in progress. The pair holds the time at
+     * which the first such request was deferred and the deferred requests in the order they arrived.
+     */
+    public static final AttributeKey<Pair<LocalDateTime, 
List<RequestMessage>>> DEFERRED_REQUEST_MESSAGES = 
AttributeKey.valueOf("deferredRequests");
+
     /**
      * The key for the current {@link AuthenticatedUser}.
      */
diff --git 
a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
 
b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
index e39d8c2b44..15dcd041f7 100644
--- 
a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
+++ 
b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
@@ -18,24 +18,32 @@
  */
 package org.apache.tinkerpop.gremlin.server;
 
-import org.apache.tinkerpop.gremlin.util.ExceptionHelper;
 import org.apache.tinkerpop.gremlin.driver.Client;
 import org.apache.tinkerpop.gremlin.driver.Cluster;
+import org.apache.tinkerpop.gremlin.driver.Result;
+import org.apache.tinkerpop.gremlin.driver.ResultSet;
 import org.apache.tinkerpop.gremlin.driver.exception.NoHostAvailableException;
 import org.apache.tinkerpop.gremlin.driver.exception.ResponseException;
+import org.apache.tinkerpop.gremlin.driver.simple.WebSocketClient;
 import org.apache.tinkerpop.gremlin.server.auth.SimpleAuthenticator;
-import org.ietf.jgss.GSSException;
+import org.apache.tinkerpop.gremlin.server.handler.SaslAuthenticationHandler;
 import org.apache.tinkerpop.gremlin.structure.Property;
 import org.apache.tinkerpop.gremlin.structure.Vertex;
+import org.apache.tinkerpop.gremlin.util.ExceptionHelper;
+import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
+import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
+import org.apache.tinkerpop.gremlin.util.ser.Serializers;
+import org.ietf.jgss.GSSException;
 import org.junit.Test;
 
+import java.time.Duration;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 
-import org.apache.tinkerpop.gremlin.util.ser.Serializers;
-
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.AnyOf.anyOf;
 import static org.hamcrest.core.IsInstanceOf.instanceOf;
@@ -163,6 +171,46 @@ public class GremlinServerAuthIntegrateTest extends 
AbstractGremlinServerIntegra
         }
     }
 
+    @Test
+    public void 
shouldFailAuthenticateWithUnAuthenticatedRequestAfterMaxDeferrableDuration() 
throws Exception {
+        try (WebSocketClient client = 
TestClientFactory.createWebSocketClient()) {
+            // First request will initiate the authentication handshake
+            // Subsequent requests will be deferred
+            CompletableFuture<List<ResponseMessage>> future1 = 
client.submitAsync("");
+            CompletableFuture<List<ResponseMessage>> future2 = 
client.submitAsync("");
+            CompletableFuture<List<ResponseMessage>> future3 = 
client.submitAsync("");
+
+            // After the maximum allowed deferred request duration,
+            // any non-authenticated request will invalidate all requests with 
429 error
+            CompletableFuture<List<ResponseMessage>> future4 = 
CompletableFuture.runAsync(() -> {
+                try {
+                    
Thread.sleep(SaslAuthenticationHandler.MAX_REQUEST_DEFERRABLE_DURATION.plus(Duration.ofSeconds(1)).toMillis());
+                } catch (InterruptedException e) {
+                    throw new RuntimeException(e);
+                }
+            }).thenCompose((__) -> {
+                try {
+                    return client.submitAsync("");
+                } catch (Exception e) {
+                    throw new RuntimeException(e);
+                }
+            });
+
+            List<ResponseMessage> responses = new ArrayList<>();
+
+            responses.addAll(future1.join());
+            responses.addAll(future2.join());
+            responses.addAll(future3.join());
+            responses.addAll(future4.join());
+
+            for (ResponseMessage response : responses) {
+                if (response.getStatus().getCode() != 
ResponseStatusCode.AUTHENTICATE) {
+                    assertEquals(ResponseStatusCode.TOO_MANY_REQUESTS, 
response.getStatus().getCode());
+                }
+            }
+        }
+    }
+
     @Test
     public void shouldAuthenticateWithPlainTextOverDefaultJSONSerialization() 
throws Exception {
         final Cluster cluster = 
TestClientFactory.build().serializer(Serializers.GRAPHSON)
@@ -218,9 +266,19 @@ public class GremlinServerAuthIntegrateTest extends 
AbstractGremlinServerIntegra
 
     private static void assertConnection(final Cluster cluster, final Client 
client) throws InterruptedException, ExecutionException {
         try {
-            assertEquals(2, client.submit("1+1").all().get().get(0).getInt());
-            assertEquals(3, client.submit("1+2").all().get().get(0).getInt());
-            assertEquals(4, client.submit("1+3").all().get().get(0).getInt());
+            CompletableFuture<List<Result>> future1 = 
client.submitAsync("1+1").thenCompose(ResultSet::all);
+            CompletableFuture<List<Result>> future2 = 
client.submitAsync("1+2").thenCompose(ResultSet::all);
+            CompletableFuture<List<Result>> future3 = 
client.submitAsync("1+3").thenCompose(ResultSet::all);
+            CompletableFuture<List<Result>> future4 = 
client.submitAsync("1+4").thenCompose(ResultSet::all);
+            CompletableFuture<List<Result>> future5 = 
client.submitAsync("1+5").thenCompose(ResultSet::all);
+            CompletableFuture<List<Result>> future6 = 
client.submitAsync("1+6").thenCompose(ResultSet::all);
+
+            assertEquals(2, future1.join().get(0).getInt());
+            assertEquals(3, future2.join().get(0).getInt());
+            assertEquals(4, future3.join().get(0).getInt());
+            assertEquals(5, future4.join().get(0).getInt());
+            assertEquals(6, future5.join().get(0).getInt());
+            assertEquals(7, future6.join().get(0).getInt());
         } finally {
             cluster.close();
         }

Reply via email to