This is an automated email from the ASF dual-hosted git repository. kenhuuu pushed a commit to branch sasl-initial-authn in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit f1f39dca56964d03d3cfef0c21cd0ab7bcb4c407 Author: Ken Hu <[email protected]> AuthorDate: Mon Apr 8 14:28:11 2024 -0700 Fix nits in PR for TINKERPOP-3061 regarding sasl authentication. --- .../gremlin/driver/simple/AbstractClient.java | 15 +++-- .../server/handler/SaslAuthenticationHandler.java | 56 +++++++++--------- .../server/GremlinServerAuthIntegrateTest.java | 69 +++++++++++++--------- 3 files changed, 78 insertions(+), 62 deletions(-) diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java index dfbbcef60a..c353f45c07 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/AbstractClient.java @@ -25,9 +25,12 @@ import io.netty.channel.nio.NioEventLoopGroup; import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseMessage; -import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -48,7 +51,7 @@ public abstract class AbstractClient implements SimpleClient { @Override public void submit(final RequestMessage requestMessage, final Consumer<ResponseMessage> callback) throws Exception { - callbackResponseHandler.callback.put(requestMessage.getRequestId(), callback); + callbackResponseHandler.callbackByRequestId.put(requestMessage.getRequestId(), callback); writeAndFlush(requestMessage); } @@ -64,7 +67,7 @@ public abstract class AbstractClient implements SimpleClient { public CompletableFuture<List<ResponseMessage>> submitAsync(final RequestMessage requestMessage) throws Exception { final List<ResponseMessage> results = new ArrayList<>(); final CompletableFuture<List<ResponseMessage>> f = new CompletableFuture<>(); - callbackResponseHandler.callback.put(requestMessage.getRequestId(), response -> { + callbackResponseHandler.callbackByRequestId.put(requestMessage.getRequestId(), response -> { if (f.isDone()) throw new RuntimeException("A terminating message was already encountered - no more messages should have been received"); @@ -82,11 +85,11 @@ public abstract class AbstractClient implements SimpleClient { } static class CallbackResponseHandler extends SimpleChannelInboundHandler<ResponseMessage> { - public Map<UUID, Consumer<ResponseMessage>> callback = new HashMap<>(); + public Map<UUID, Consumer<ResponseMessage>> callbackByRequestId = new HashMap<>(); @Override protected void channelRead0(final ChannelHandlerContext channelHandlerContext, final ResponseMessage response) throws Exception { - callback.get(response.getRequestId()).accept(response); + callbackByRequestId.get(response.getRequestId()).accept(response); } } } diff --git a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java index 4e4fdd8af6..34345ecf5e 100644 --- a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java +++ b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/SaslAuthenticationHandler.java @@ -108,16 +108,14 @@ public class SaslAuthenticationHandler extends AbstractAuthenticationHandler { authenticator.getClass().getSimpleName()), ex); respondWithError( - requestMessage, - builder -> builder.statusMessage("Authenticator is not ready to handle requests").code(ResponseStatusCode.SERVER_ERROR), - ctx); + requestMessage, + builder -> builder.statusMessage("Authenticator is not ready to handle requests").code(ResponseStatusCode.SERVER_ERROR), + ctx); } return; - } - - // If authentication negotiation is pending, store subsequent non-authentication requests for later processing - if (negotiator.get() != null && !requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)) { + } else if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)) { + // If authentication negotiation is pending, store subsequent non-authentication requests for later processing deferredRequests.setIfAbsent(new ImmutablePair<>(LocalDateTime.now(), new ArrayList<>())); deferredRequests.get().getValue().add(requestMessage); @@ -125,20 +123,20 @@ public class SaslAuthenticationHandler extends AbstractAuthenticationHandler { if (deferredDuration.compareTo(MAX_REQUEST_DEFERRABLE_DURATION) > 0) { respondWithError( - requestMessage, - builder -> builder.statusMessage("Too many unauthenticated requests").code(ResponseStatusCode.TOO_MANY_REQUESTS), - ctx); + requestMessage, + builder -> builder.statusMessage("Authentication did not finish in the allowed duration (" + MAX_REQUEST_DEFERRABLE_DURATION + "s).") + .code(ResponseStatusCode.UNAUTHORIZED), + ctx); return; } return; - } - - if (!requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION) || !requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) { + } else if (!requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) { + // This is an authentication request that is missing a "sasl" argument. respondWithError( - requestMessage, - builder -> builder.statusMessage("Failed to authenticate").code(ResponseStatusCode.UNAUTHORIZED), - ctx); + requestMessage, + builder -> builder.statusMessage("Failed to authenticate").code(ResponseStatusCode.UNAUTHORIZED), + ctx); return; } @@ -146,11 +144,11 @@ public class SaslAuthenticationHandler extends AbstractAuthenticationHandler { if (!(saslObject instanceof String)) { respondWithError( - requestMessage, - builder -> builder - .statusMessage("Incorrect type for : " + Tokens.ARGS_SASL + " - base64 encoded String is expected") - .code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST), - ctx); + requestMessage, + builder -> builder + .statusMessage("Incorrect type for : " + Tokens.ARGS_SASL + " - base64 encoded String is expected") + .code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST), + ctx); return; } @@ -191,9 +189,9 @@ public class SaslAuthenticationHandler extends AbstractAuthenticationHandler { } } catch (AuthenticationException ae) { respondWithError( - requestMessage, - builder -> builder.statusMessage(ae.getMessage()).code(ResponseStatusCode.UNAUTHORIZED), - ctx); + requestMessage, + builder -> builder.statusMessage(ae.getMessage()).code(ResponseStatusCode.UNAUTHORIZED), + ctx); } } @@ -211,11 +209,11 @@ public class SaslAuthenticationHandler extends AbstractAuthenticationHandler { if (deferredRequests.get() != null) { deferredRequests - .getAndSet(null).getValue().stream() - .map(ResponseMessage::build) - .map(buildResponse) - .map(ResponseMessage.Builder::create) - .forEach(ctx::write); + .getAndSet(null).getValue().stream() + .map(ResponseMessage::build) + .map(buildResponse) + .map(ResponseMessage.Builder::create) + .forEach(ctx::write); } ctx.flush(); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java index 15dcd041f7..01367ec49c 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java @@ -30,6 +30,8 @@ import org.apache.tinkerpop.gremlin.server.handler.SaslAuthenticationHandler; import org.apache.tinkerpop.gremlin.structure.Property; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.util.ExceptionHelper; +import org.apache.tinkerpop.gremlin.util.Tokens; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode; import org.apache.tinkerpop.gremlin.util.ser.Serializers; @@ -43,6 +45,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.AnyOf.anyOf; @@ -176,13 +179,13 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra try (WebSocketClient client = TestClientFactory.createWebSocketClient()) { // First request will initiate the authentication handshake // Subsequent requests will be deferred - CompletableFuture<List<ResponseMessage>> future1 = client.submitAsync(""); - CompletableFuture<List<ResponseMessage>> future2 = client.submitAsync(""); - CompletableFuture<List<ResponseMessage>> future3 = client.submitAsync(""); + CompletableFuture<List<ResponseMessage>> futureOfRequestWithinAuthDuration1 = client.submitAsync(""); + CompletableFuture<List<ResponseMessage>> futureOfRequestWithinAuthDuration2 = client.submitAsync(""); + CompletableFuture<List<ResponseMessage>> futureOfRequestWithinAuthDuration3 = client.submitAsync(""); // After the maximum allowed deferred request duration, // any non-authenticated request will invalidate all requests with 429 error - CompletableFuture<List<ResponseMessage>> future4 = CompletableFuture.runAsync(() -> { + CompletableFuture<List<ResponseMessage>> futureOfRequestSubmittedTooLate = CompletableFuture.runAsync(() -> { try { Thread.sleep(SaslAuthenticationHandler.MAX_REQUEST_DEFERRABLE_DURATION.plus(Duration.ofSeconds(1)).toMillis()); } catch (InterruptedException e) { @@ -196,18 +199,40 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra } }); - List<ResponseMessage> responses = new ArrayList<>(); + assertEquals(2, futureOfRequestWithinAuthDuration1.get().size()); + assertEquals(1, futureOfRequestWithinAuthDuration2.get().size()); + assertEquals(1, futureOfRequestWithinAuthDuration3.get().size()); + assertEquals(1, futureOfRequestSubmittedTooLate.get().size()); - responses.addAll(future1.join()); - responses.addAll(future2.join()); - responses.addAll(future3.join()); - responses.addAll(future4.join()); + assertEquals(ResponseStatusCode.AUTHENTICATE, futureOfRequestWithinAuthDuration1.get().get(0).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, futureOfRequestWithinAuthDuration1.get().get(1).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, futureOfRequestWithinAuthDuration2.get().get(0).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, futureOfRequestWithinAuthDuration3.get().get(0).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, futureOfRequestSubmittedTooLate.get().get(0).getStatus().getCode()); + } + } - for (ResponseMessage response : responses) { - if (response.getStatus().getCode() != ResponseStatusCode.AUTHENTICATE) { - assertEquals(ResponseStatusCode.TOO_MANY_REQUESTS, response.getStatus().getCode()); - } - } + @Test + public void shouldFailAuthenticateWithIncorrectParallelRequests() throws Exception { + try (WebSocketClient client = TestClientFactory.createWebSocketClient()) { + + CompletableFuture<List<ResponseMessage>> firstRequest = client.submitAsync("1"); + CompletableFuture<List<ResponseMessage>> secondRequest = client.submitAsync("2"); + CompletableFuture<List<ResponseMessage>> thirdRequest = client.submitAsync("3"); + + Thread.sleep(500); + + // send some incorrect value for username password which should cause all requests to fail. + client.submitAsync(RequestMessage.build(Tokens.OPS_AUTHENTICATION).addArg(Tokens.ARGS_SASL, "someincorrectvalue").create()); + + assertEquals(2, firstRequest.get().size()); + assertEquals(1, secondRequest.get().size()); + assertEquals(1, thirdRequest.get().size()); + + assertEquals(ResponseStatusCode.AUTHENTICATE, firstRequest.get().get(0).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, firstRequest.get().get(1).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, secondRequest.get().get(0).getStatus().getCode()); + assertEquals(ResponseStatusCode.UNAUTHORIZED, thirdRequest.get().get(0).getStatus().getCode()); } } @@ -266,19 +291,9 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra private static void assertConnection(final Cluster cluster, final Client client) throws InterruptedException, ExecutionException { try { - CompletableFuture<List<Result>> future1 = client.submitAsync("1+1").thenCompose(ResultSet::all); - CompletableFuture<List<Result>> future2 = client.submitAsync("1+2").thenCompose(ResultSet::all); - CompletableFuture<List<Result>> future3 = client.submitAsync("1+3").thenCompose(ResultSet::all); - CompletableFuture<List<Result>> future4 = client.submitAsync("1+4").thenCompose(ResultSet::all); - CompletableFuture<List<Result>> future5 = client.submitAsync("1+5").thenCompose(ResultSet::all); - CompletableFuture<List<Result>> future6 = client.submitAsync("1+6").thenCompose(ResultSet::all); - - assertEquals(2, future1.join().get(0).getInt()); - assertEquals(3, future2.join().get(0).getInt()); - assertEquals(4, future3.join().get(0).getInt()); - assertEquals(5, future4.join().get(0).getInt()); - assertEquals(6, future5.join().get(0).getInt()); - assertEquals(7, future6.join().get(0).getInt()); + assertEquals(2, client.submit("1+1").all().get().get(0).getInt()); + assertEquals(3, client.submit("1+2").all().get().get(0).getInt()); + assertEquals(4, client.submit("1+3").all().get().get(0).getInt()); } finally { cluster.close(); }
