[ https://issues.apache.org/jira/browse/KAFKA-6950?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16598593#comment-16598593 ]
ASF GitHub Bot commented on KAFKA-6950: --------------------------------------- rajinisivaram closed pull request #5082: KAFKA-6950: Delay response to failed client authentication to prevent potential DoS issues (KIP-306) URL: https://github.com/apache/kafka/pull/5082 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java index f6458c6f22d..7a05eba03f2 100644 --- a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java +++ b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java @@ -40,6 +40,10 @@ public AuthenticationException(String message) { super(message); } + public AuthenticationException(Throwable cause) { + super(cause); + } + public AuthenticationException(String message, Throwable cause) { super(message, cause); } diff --git a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java index 4e2e7273a68..33c2e908516 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java @@ -37,6 +37,14 @@ */ void authenticate() throws AuthenticationException, IOException; + /** + * Perform any processing related to authentication failure. This is invoked when the channel is about to be closed + * because of an {@link AuthenticationException} thrown from a prior {@link #authenticate()} call. + * @throws IOException if read/write fails due to an I/O error + */ + default void handleAuthenticationFailure() throws IOException { + } + /** * Returns Principal using PrincipalBuilder */ @@ -46,5 +54,4 @@ * returns true if authentication is complete otherwise returns false; */ boolean complete(); - } diff --git a/clients/src/main/java/org/apache/kafka/common/network/DelayedResponseAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/network/DelayedResponseAuthenticationException.java new file mode 100644 index 00000000000..8474426c609 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/DelayedResponseAuthenticationException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.errors.AuthenticationException; + +public class DelayedResponseAuthenticationException extends AuthenticationException { + private static final long serialVersionUID = 1L; + + public DelayedResponseAuthenticationException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java index 1839729f2e7..17dc6a33ef2 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java +++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java @@ -120,15 +120,22 @@ public KafkaPrincipal principal() { * authentication. For SASL, authentication is performed by {@link Authenticator#authenticate()}. */ public void prepare() throws AuthenticationException, IOException { + boolean authenticating = false; try { if (!transportLayer.ready()) transportLayer.handshake(); - if (transportLayer.ready() && !authenticator.complete()) + if (transportLayer.ready() && !authenticator.complete()) { + authenticating = true; authenticator.authenticate(); + } } catch (AuthenticationException e) { // Clients are notified of authentication exceptions to enable operations to be terminated // without retries. Other errors are handled as network exceptions in Selector. state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e); + if (authenticating) { + delayCloseOnAuthenticationFailure(); + throw new DelayedResponseAuthenticationException(e); + } throw e; } if (ready()) @@ -236,6 +243,24 @@ public ChannelMuteState muteState() { return muteState; } + /** + * Delay channel close on authentication failure. This will remove all read/write operations from the channel until + * {@link #completeCloseOnAuthenticationFailure()} is called to finish up the channel close. + */ + private void delayCloseOnAuthenticationFailure() { + transportLayer.removeInterestOps(SelectionKey.OP_WRITE); + } + + /** + * Finish up any processing on {@link #prepare()} failure. + * @throws IOException + */ + void completeCloseOnAuthenticationFailure() throws IOException { + transportLayer.addInterestOps(SelectionKey.OP_WRITE); + // Invoke the underlying handler to finish up any processing on authentication failure + authenticator.handleAuthenticationFailure(); + } + /** * Returns true if this channel has been explicitly muted using {@link KafkaChannel#mute()} */ diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 7e32509933e..806bda700e3 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -86,6 +86,7 @@ public class Selector implements Selectable, AutoCloseable { public static final long NO_IDLE_TIMEOUT_MS = -1; + public static final int NO_FAILED_AUTHENTICATION_DELAY = 0; private enum CloseMode { GRACEFUL(true), // process outstanding staged receives, notify disconnect @@ -119,8 +120,11 @@ private final int maxReceiveSize; private final boolean recordTimePerConnection; private final IdleExpiryManager idleExpiryManager; + private final LinkedHashMap<String, DelayedAuthenticationFailureClose> delayedClosingChannels; private final MemoryPool memoryPool; private final long lowMemThreshold; + private final int failedAuthenticationDelayMs; + //indicates if the previous call to poll was able to make progress in reading already-buffered data. //this is used to prevent tight loops when memory is not available to read any more data private boolean madeReadProgressLastPoll = true; @@ -129,6 +133,8 @@ * Create a new nioSelector * @param maxReceiveSize Max size in bytes of a single network receive (use {@link NetworkReceive#UNLIMITED} for no limit) * @param connectionMaxIdleMs Max idle connection time (use {@link #NO_IDLE_TIMEOUT_MS} to disable idle timeout) + * @param failedAuthenticationDelayMs Minimum time by which failed authentication response and channel close should be delayed by. + * Use {@link #NO_FAILED_AUTHENTICATION_DELAY} to disable this delay. * @param metrics Registry for Selector metrics * @param time Time implementation * @param metricGrpPrefix Prefix for the group of metrics registered by Selector @@ -139,6 +145,7 @@ */ public Selector(int maxReceiveSize, long connectionMaxIdleMs, + int failedAuthenticationDelayMs, Metrics metrics, Time time, String metricGrpPrefix, @@ -174,8 +181,39 @@ public Selector(int maxReceiveSize, this.memoryPool = memoryPool; this.lowMemThreshold = (long) (0.1 * this.memoryPool.size()); this.log = logContext.logger(Selector.class); + this.failedAuthenticationDelayMs = failedAuthenticationDelayMs; + this.delayedClosingChannels = (failedAuthenticationDelayMs > NO_FAILED_AUTHENTICATION_DELAY) ? new LinkedHashMap<String, DelayedAuthenticationFailureClose>() : null; + } + + public Selector(int maxReceiveSize, + long connectionMaxIdleMs, + Metrics metrics, + Time time, + String metricGrpPrefix, + Map<String, String> metricTags, + boolean metricsPerConnection, + boolean recordTimePerConnection, + ChannelBuilder channelBuilder, + MemoryPool memoryPool, + LogContext logContext) { + this(maxReceiveSize, connectionMaxIdleMs, NO_FAILED_AUTHENTICATION_DELAY, metrics, time, metricGrpPrefix, metricTags, + metricsPerConnection, recordTimePerConnection, channelBuilder, memoryPool, logContext); } + public Selector(int maxReceiveSize, + long connectionMaxIdleMs, + int failedAuthenticationDelayMs, + Metrics metrics, + Time time, + String metricGrpPrefix, + Map<String, String> metricTags, + boolean metricsPerConnection, + ChannelBuilder channelBuilder, + LogContext logContext) { + this(maxReceiveSize, connectionMaxIdleMs, failedAuthenticationDelayMs, metrics, time, metricGrpPrefix, metricTags, metricsPerConnection, false, channelBuilder, MemoryPool.NONE, logContext); + } + + public Selector(int maxReceiveSize, long connectionMaxIdleMs, Metrics metrics, @@ -185,13 +223,17 @@ public Selector(int maxReceiveSize, boolean metricsPerConnection, ChannelBuilder channelBuilder, LogContext logContext) { - this(maxReceiveSize, connectionMaxIdleMs, metrics, time, metricGrpPrefix, metricTags, metricsPerConnection, false, channelBuilder, MemoryPool.NONE, logContext); + this(maxReceiveSize, connectionMaxIdleMs, NO_FAILED_AUTHENTICATION_DELAY, metrics, time, metricGrpPrefix, metricTags, metricsPerConnection, channelBuilder, logContext); } public Selector(long connectionMaxIdleMS, Metrics metrics, Time time, String metricGrpPrefix, ChannelBuilder channelBuilder, LogContext logContext) { this(NetworkReceive.UNLIMITED, connectionMaxIdleMS, metrics, time, metricGrpPrefix, Collections.<String, String>emptyMap(), true, channelBuilder, logContext); } + public Selector(long connectionMaxIdleMS, int failedAuthenticationDelayMs, Metrics metrics, Time time, String metricGrpPrefix, ChannelBuilder channelBuilder, LogContext logContext) { + this(NetworkReceive.UNLIMITED, connectionMaxIdleMS, failedAuthenticationDelayMs, metrics, time, metricGrpPrefix, Collections.<String, String>emptyMap(), true, channelBuilder, logContext); + } + /** * Begin connecting to the given address and add the connection to this nioSelector associated with the given id * number. @@ -435,6 +477,9 @@ public void poll(long timeout) throws IOException { long endIo = time.nanoseconds(); this.sensors.ioTime.record(endIo - endSelect, time.milliseconds()); + // Close channels that were delayed and are now ready to be closed + completeDelayedChannelClose(endIo); + // we use the time at the end of select to ensure that we don't close any connections that // have just been processed in pollSelectionKeys maybeCloseOldestConnection(endSelect); @@ -457,15 +502,14 @@ void pollSelectionKeys(Set<SelectionKey> selectionKeys, for (SelectionKey key : determineHandlingOrder(selectionKeys)) { KafkaChannel channel = channel(key); long channelStartTimeNanos = recordTimePerConnection ? time.nanoseconds() : 0; + boolean sendFailed = false; // register all per-connection metrics at once sensors.maybeRegisterConnectionMetrics(channel.id()); if (idleExpiryManager != null) idleExpiryManager.update(channel.id(), currentTimeNanos); - boolean sendFailed = false; try { - /* complete any connections that have finished their handshake (either normally or immediately) */ if (isImmediatelyConnected || key.isConnectable()) { if (channel.finishConnect()) { @@ -477,8 +521,9 @@ void pollSelectionKeys(Set<SelectionKey> selectionKeys, socketChannel.socket().getSendBufferSize(), socketChannel.socket().getSoTimeout(), channel.id()); - } else + } else { continue; + } } /* if channel is not ready finish prepare */ @@ -532,7 +577,11 @@ else if (e instanceof AuthenticationException) // will be logged later as error log.debug("Connection with {} disconnected due to authentication exception", desc, e); else log.warn("Unexpected error from {}; closing connection", desc, e); - close(channel, sendFailed ? CloseMode.NOTIFY_ONLY : CloseMode.GRACEFUL); + + if (e instanceof DelayedResponseAuthenticationException) + maybeDelayCloseOnAuthenticationFailure(channel); + else + close(channel, sendFailed ? CloseMode.NOTIFY_ONLY : CloseMode.GRACEFUL); } finally { maybeRecordTimePerConnection(channel, channelStartTimeNanos); } @@ -631,6 +680,18 @@ public void unmuteAll() { unmute(channel); } + // package-private for testing + void completeDelayedChannelClose(long currentTimeNanos) { + if (delayedClosingChannels == null) + return; + + while (!delayedClosingChannels.isEmpty()) { + DelayedAuthenticationFailureClose delayedClose = delayedClosingChannels.values().iterator().next(); + if (!delayedClose.tryClose(currentTimeNanos)) + break; + } + } + private void maybeCloseOldestConnection(long currentTimeNanos) { if (idleExpiryManager == null) return; @@ -657,6 +718,7 @@ private void clear() { this.completedReceives.clear(); this.connected.clear(); this.disconnected.clear(); + // Remove closed channels after all their staged receives have been processed or if a send was requested for (Iterator<Map.Entry<String, KafkaChannel>> it = closingChannels.entrySet().iterator(); it.hasNext(); ) { KafkaChannel channel = it.next().getValue(); @@ -667,6 +729,7 @@ private void clear() { it.remove(); } } + for (String channel : this.failedSends) this.disconnected.put(channel, ChannelState.FAILED_SEND); this.failedSends.clear(); @@ -707,6 +770,24 @@ public void close(String id) { } } + private void maybeDelayCloseOnAuthenticationFailure(KafkaChannel channel) { + DelayedAuthenticationFailureClose delayedClose = new DelayedAuthenticationFailureClose(channel, failedAuthenticationDelayMs); + if (delayedClosingChannels != null) + delayedClosingChannels.put(channel.id(), delayedClose); + else + delayedClose.closeNow(); + } + + private void handleCloseOnAuthenticationFailure(KafkaChannel channel) { + try { + channel.completeCloseOnAuthenticationFailure(); + } catch (Exception e) { + log.error("Exception handling close on authentication failure node {}", channel.id(), e); + } finally { + close(channel, CloseMode.GRACEFUL); + } + } + /** * Begin closing this connection. * If 'closeMode' is `CloseMode.GRACEFUL`, the channel is disconnected here, but staged receives @@ -735,10 +816,14 @@ private void close(KafkaChannel channel, CloseMode closeMode) { // stagedReceives will be moved to completedReceives later along with receives from other channels closingChannels.put(channel.id(), channel); log.debug("Tracking closing connection {} to process outstanding requests", channel.id()); - } else + } else { doClose(channel, closeMode.notifyDisconnect); + } this.channels.remove(channel.id()); + if (delayedClosingChannels != null) + delayedClosingChannels.remove(channel.id()); + if (idleExpiryManager != null) idleExpiryManager.remove(channel.id()); } @@ -1064,6 +1149,46 @@ public void close() { } } + /** + * Encapsulate a channel that must be closed after a specific delay has elapsed due to authentication failure. + */ + private class DelayedAuthenticationFailureClose { + private final KafkaChannel channel; + private final long endTimeNanos; + private boolean closed; + + /** + * @param channel The channel whose close is being delayed + * @param delayMs The amount of time by which the operation should be delayed + */ + public DelayedAuthenticationFailureClose(KafkaChannel channel, int delayMs) { + this.channel = channel; + this.endTimeNanos = time.nanoseconds() + (delayMs * 1000L * 1000L); + this.closed = false; + } + + /** + * Try to close this channel if the delay has expired. + * @param currentTimeNanos The current time + * @return True if the delay has expired and the channel was closed; false otherwise + */ + public final boolean tryClose(long currentTimeNanos) { + if (endTimeNanos <= currentTimeNanos) + closeNow(); + return closed; + } + + /** + * Close the channel now, regardless of whether the delay has expired or not. + */ + public final void closeNow() { + if (closed) + throw new IllegalStateException("Attempt to close a channel that has already been closed"); + handleCloseOnAuthenticationFailure(channel); + closed = true; + } + } + // helper class for tracking least recently used connections to enable idle connection closing private static class IdleExpiryManager { private final Map<String, Long> lruConnections; @@ -1114,4 +1239,9 @@ boolean isOutOfMemory() { boolean isMadeReadProgressLastPoll() { return madeReadProgressLastPoll; } + + // package-private for testing + Map<?, ?> delayedClosingChannels() { + return delayedClosingChannels; + } } diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java index e8f77a53e22..43cb0a44870 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java @@ -118,6 +118,7 @@ // buffers used in `authenticate` private NetworkReceive netInBuffer; private Send netOutBuffer; + private Send authenticationFailureSend = null; // flag indicating if sasl tokens are sent as Kafka SaslAuthenticate request/responses private boolean enableKafkaSaslAuthenticateHeaders; @@ -294,6 +295,11 @@ public boolean complete() { return saslState == SaslState.COMPLETE; } + @Override + public void handleAuthenticationFailure() throws IOException { + sendAuthenticationFailureResponse(); + } + @Override public void close() throws IOException { if (principalBuilder instanceof Closeable) @@ -362,7 +368,7 @@ private void handleSaslToken(byte[] clientToken) throws IOException { RequestAndSize requestAndSize = requestContext.parseRequest(requestBuffer); if (apiKey != ApiKeys.SASL_AUTHENTICATE) { IllegalSaslStateException e = new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL authentication."); - sendKafkaResponse(requestContext, requestAndSize.request.getErrorResponse(e)); + buildResponseOnAuthenticateFailure(requestContext, requestAndSize.request.getErrorResponse(e)); throw e; } if (!apiKey.isVersionSupported(version)) { @@ -378,7 +384,8 @@ private void handleSaslToken(byte[] clientToken) throws IOException { ByteBuffer responseBuf = responseToken == null ? EMPTY_BUFFER : ByteBuffer.wrap(responseToken); sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.NONE, null, responseBuf)); } catch (SaslAuthenticationException e) { - sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED, e.getMessage())); + buildResponseOnAuthenticateFailure(requestContext, + new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED, e.getMessage())); throw e; } catch (SaslException e) { KerberosError kerberosError = KerberosError.fromException(e); @@ -464,7 +471,7 @@ private String handleHandshakeRequest(RequestContext context, SaslHandshakeReque return clientMechanism; } else { LOG.debug("SASL mechanism '{}' requested by client is not supported", clientMechanism); - sendKafkaResponse(context, new SaslHandshakeResponse(Errors.UNSUPPORTED_SASL_MECHANISM, enabledMechanisms)); + buildResponseOnAuthenticateFailure(context, new SaslHandshakeResponse(Errors.UNSUPPORTED_SASL_MECHANISM, enabledMechanisms)); throw new UnsupportedSaslMechanismException("Unsupported SASL mechanism " + clientMechanism); } } @@ -491,6 +498,24 @@ private void handleApiVersionsRequest(RequestContext context, ApiVersionsRequest } } + /** + * Build a {@link Send} response on {@link #authenticate()} failure. The actual response is sent out when + * {@link #sendAuthenticationFailureResponse()} is called. + */ + private void buildResponseOnAuthenticateFailure(RequestContext context, AbstractResponse response) { + authenticationFailureSend = context.buildResponse(response); + } + + /** + * Send any authentication failure response that may have been previously built. + */ + private void sendAuthenticationFailureResponse() throws IOException { + if (authenticationFailureSend == null) + return; + sendKafkaResponse(authenticationFailureSend); + authenticationFailureSend = null; + } + private void sendKafkaResponse(RequestContext context, AbstractResponse response) throws IOException { sendKafkaResponse(context.buildResponse(response)); } diff --git a/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java b/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java index 59980490e68..b08c8c19a87 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java +++ b/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -28,6 +29,7 @@ import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.security.authenticator.CredentialCache; import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Utils; import org.apache.kafka.test.TestUtils; @@ -35,16 +37,22 @@ * Common utility functions used by transport layer and authenticator tests. */ public class NetworkTestUtils { + public static NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, + AbstractConfig serverConfig, CredentialCache credentialCache, Time time) throws Exception { + return createEchoServer(listenerName, securityProtocol, serverConfig, credentialCache, 100, time); + } public static NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, - AbstractConfig serverConfig, CredentialCache credentialCache) throws Exception { - NioEchoServer server = new NioEchoServer(listenerName, securityProtocol, serverConfig, "localhost", null, credentialCache); + AbstractConfig serverConfig, CredentialCache credentialCache, + int failedAuthenticationDelayMs, Time time) throws Exception { + NioEchoServer server = new NioEchoServer(listenerName, securityProtocol, serverConfig, "localhost", + null, credentialCache, failedAuthenticationDelayMs, time); server.start(); return server; } - public static Selector createSelector(ChannelBuilder channelBuilder) { - return new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + public static Selector createSelector(ChannelBuilder channelBuilder, Time time) { + return new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); } public static void checkClientConnection(Selector selector, String node, int minMessageSize, int messageCount) throws Exception { @@ -79,19 +87,33 @@ public static void waitForChannelReady(Selector selector, String node) throws IO assertTrue(selector.isChannelReady(node)); } - public static ChannelState waitForChannelClose(Selector selector, String node, ChannelState.State channelState) + public static ChannelState waitForChannelClose(Selector selector, String node, ChannelState.State channelState, MockTime mockTime) throws IOException { boolean closed = false; - for (int i = 0; i < 30; i++) { - selector.poll(1000L); + for (int i = 0; i < 300; i++) { + selector.poll(100L); if (selector.channel(node) == null && selector.closingChannel(node) == null) { closed = true; break; } + if (mockTime != null) + mockTime.setCurrentTimeMs(mockTime.milliseconds() + 150); } assertTrue("Channel was not closed by timeout", closed); ChannelState finalState = selector.disconnected().get(node); assertEquals(channelState, finalState.state()); return finalState; } + + public static ChannelState waitForChannelClose(Selector selector, String node, ChannelState.State channelState) throws IOException { + return waitForChannelClose(selector, node, channelState, null); + } + + public static void completeDelayedChannelClose(Selector selector, long currentTimeNanos) { + selector.completeDelayedChannelClose(currentTimeNanos); + } + + public static Map<?, ?> delayedClosingChannels(Selector selector) { + return selector.delayedClosingChannels(); + } } diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java index 64b7e4e6792..bd212fdc172 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java +++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java @@ -27,7 +27,7 @@ import org.apache.kafka.common.security.scram.ScramCredential; import org.apache.kafka.common.security.scram.internals.ScramMechanism; import org.apache.kafka.common.utils.LogContext; -import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; import org.apache.kafka.test.TestCondition; import org.apache.kafka.test.TestUtils; @@ -69,7 +69,13 @@ private final DelegationTokenCache tokenCache; public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, - String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache) throws Exception { + String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, Time time) throws Exception { + this(listenerName, securityProtocol, config, serverHost, channelBuilder, credentialCache, 100, time); + } + + public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, + String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, + int failedAuthenticationDelayMs, Time time) throws Exception { super("echoserver"); setDaemon(true); serverSocketChannel = ServerSocketChannel.open(); @@ -89,7 +95,7 @@ public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtoco if (channelBuilder == null) channelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, credentialCache, tokenCache); this.metrics = new Metrics(); - this.selector = new Selector(5000, metrics, new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + this.selector = new Selector(5000, failedAuthenticationDelayMs, metrics, time, "MetricGroup", channelBuilder, new LogContext()); acceptorThread = new AcceptorThread(); } diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java index d70a448df22..efddd469a1a 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java @@ -27,7 +27,6 @@ import org.apache.kafka.common.security.TestSecurityConfig; import org.apache.kafka.common.security.ssl.SslFactory; import org.apache.kafka.common.utils.LogContext; -import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Utils; import org.apache.kafka.test.TestCondition; @@ -64,6 +63,7 @@ public class SslTransportLayerTest { private static final int BUFFER_SIZE = 4 * 1024; + private static Time time = Time.SYSTEM; private NioEchoServer server; private Selector selector; @@ -82,7 +82,7 @@ public void setup() throws Exception { sslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores); this.channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false); this.channelBuilder.configure(sslClientConfigs); - this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); } @After @@ -204,7 +204,7 @@ protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, S }; serverChannelBuilder.configure(sslServerConfigs); server = new NioEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), SecurityProtocol.SSL, - new TestSecurityConfig(sslServerConfigs), "localhost", serverChannelBuilder, null); + new TestSecurityConfig(sslServerConfigs), "localhost", serverChannelBuilder, null, time); server.start(); createSelector(sslClientConfigs); @@ -784,7 +784,7 @@ private void testIOExceptionsDuringHandshake(FailureAction readFailureAction, channelBuilder.flushFailureAction = flushFailureAction; channelBuilder.failureIndex = i; channelBuilder.configure(sslClientConfigs); - this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); @@ -827,7 +827,7 @@ public void testPeerNotifiedOfHandshakeFailure() throws Exception { serverChannelBuilder.flushDelayCount = i; server = new NioEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs), - "localhost", serverChannelBuilder, null); + "localhost", serverChannelBuilder, null, time); server.start(); createSelector(sslClientConfigs); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); @@ -853,7 +853,7 @@ private void testClose(SecurityProtocol securityProtocol, ChannelBuilder clientC String node = "0"; server = createEchoServer(securityProtocol); clientChannelBuilder.configure(sslClientConfigs); - this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", clientChannelBuilder, new LogContext()); + this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", clientChannelBuilder, new LogContext()); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); @@ -893,7 +893,7 @@ public void testServerKeystoreDynamicUpdate() throws Exception { ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, null, null); server = new NioEchoServer(listenerName, securityProtocol, config, - "localhost", serverChannelBuilder, null); + "localhost", serverChannelBuilder, null, time); server.start(); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); @@ -953,7 +953,7 @@ public void testServerTruststoreDynamicUpdate() throws Exception { ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, null, null); server = new NioEchoServer(listenerName, securityProtocol, config, - "localhost", serverChannelBuilder, null); + "localhost", serverChannelBuilder, null, time); server.start(); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); @@ -1025,12 +1025,12 @@ private Selector createSelector(Map<String, Object> sslClientConfigs, final Inte channelBuilder.configureBufferSizes(netReadBufSize, netWriteBufSize, appBufSize); this.channelBuilder = channelBuilder; this.channelBuilder.configure(sslClientConfigs); - this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); return selector; } private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { - return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, new TestSecurityConfig(sslServerConfigs), null); + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, new TestSecurityConfig(sslServerConfigs), null, time); } private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception { diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java index 413997f2931..2fce4c56373 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java @@ -34,6 +34,7 @@ import org.apache.kafka.common.security.auth.SecurityProtocol; import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,6 +49,7 @@ import static org.junit.Assert.fail; public class ClientAuthenticationFailureTest { + private static MockTime time = new MockTime(50); private NioEchoServer server; private Map<String, Object> saslServerConfigs; @@ -147,6 +149,6 @@ private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, - new TestSecurityConfig(saslServerConfigs), new CredentialCache()); + new TestSecurityConfig(saslServerConfigs), new CredentialCache(), time); } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureDelayTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureDelayTest.java new file mode 100644 index 00000000000..b0dfc7a123b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureDelayTest.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.network.CertStores; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.ChannelBuilders; +import org.apache.kafka.common.network.ChannelState; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.NetworkTestUtils; +import org.apache.kafka.common.network.NioEchoServer; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(value = Parameterized.class) +public class SaslAuthenticatorFailureDelayTest { + private static final int BUFFER_SIZE = 4 * 1024; + private static MockTime time = new MockTime(50); + + private NioEchoServer server; + private Selector selector; + private ChannelBuilder channelBuilder; + private CertStores serverCertStores; + private CertStores clientCertStores; + private Map<String, Object> saslClientConfigs; + private Map<String, Object> saslServerConfigs; + private CredentialCache credentialCache; + private long startTimeMs; + private final int failedAuthenticationDelayMs; + + public SaslAuthenticatorFailureDelayTest(int failedAuthenticationDelayMs) { + this.failedAuthenticationDelayMs = failedAuthenticationDelayMs; + } + + @Parameterized.Parameters(name = "failedAuthenticationDelayMs={0}") + public static Collection<Object[]> data() { + List<Object[]> values = new ArrayList<>(); + values.add(new Object[]{0}); + values.add(new Object[]{200}); + return values; + } + + @Before + public void setup() throws Exception { + LoginManager.closeAll(); + serverCertStores = new CertStores(true, "localhost"); + clientCertStores = new CertStores(false, "localhost"); + saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores); + saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores); + credentialCache = new CredentialCache(); + SaslAuthenticatorTest.TestLogin.loginCount.set(0); + startTimeMs = time.milliseconds(); + } + + @After + public void teardown() throws Exception { + long now = time.milliseconds(); + if (server != null) + this.server.close(); + if (selector != null) + this.selector.close(); + if (failedAuthenticationDelayMs != -1) + assertTrue("timeSpent: " + (now - startTimeMs), now - startTimeMs >= failedAuthenticationDelayMs); + } + + /** + * Tests that SASL/PLAIN clients with invalid password fail authentication. + */ + @Test + public void testInvalidPasswordSaslPlain() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN", + "Authentication failed: Invalid username or password"); + server.verifyAuthenticationMetrics(0, 1); + } + + /** + * Tests client connection close before response for authentication failure is sent. + */ + @Test + public void testClientConnectionClose() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createClientConnection(securityProtocol, node); + + Map<?, ?> delayedClosingChannels = NetworkTestUtils.delayedClosingChannels(server.selector()); + + // Wait until server has established connection with client and has processed the auth failure + TestUtils.waitForCondition(() -> { + poll(selector); + return !server.selector().channels().isEmpty(); + }, "Timeout waiting for connection"); + TestUtils.waitForCondition(() -> { + poll(selector); + return failedAuthenticationDelayMs == 0 || !delayedClosingChannels.isEmpty(); + }, "Timeout waiting for auth failure"); + + selector.close(); + selector = null; + + // Now that client connection is closed, wait until server notices the disconnection and removes it from the + // list of connected channels and from delayed response for auth failure + TestUtils.waitForCondition(() -> failedAuthenticationDelayMs == 0 || delayedClosingChannels.isEmpty(), + "Timeout waiting for delayed response remove"); + TestUtils.waitForCondition(() -> server.selector().channels().isEmpty(), + "Timeout waiting for connection close"); + + // Try forcing completion of delayed channel close + TestUtils.waitForCondition(() -> time.milliseconds() > startTimeMs + failedAuthenticationDelayMs + 1, + "Timeout when waiting for auth failure response timeout to elapse"); + NetworkTestUtils.completeDelayedChannelClose(server.selector(), time.nanoseconds()); + } + + private void poll(Selector selector) { + try { + selector.poll(50); + } catch (IOException e) { + Assert.fail("Caught unexpected exception " + e); + } + } + + private TestJaasConfig configureMechanisms(String clientMechanism, List<String> serverMechanisms) { + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms); + if (serverMechanisms.contains("DIGEST-MD5")) { + saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestDigestLoginModule.DigestServerCallbackHandler.class.getName()); + } + return TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms); + } + + private void createSelector(SecurityProtocol securityProtocol, Map<String, Object> clientConfigs) { + if (selector != null) { + selector.close(); + selector = null; + } + + String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); + this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, + new TestSecurityConfig(clientConfigs), null, saslMechanism, true); + this.selector = NetworkTestUtils.createSelector(channelBuilder, time); + } + + private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception { + return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol); + } + + private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { + if (failedAuthenticationDelayMs != -1) + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, + new TestSecurityConfig(saslServerConfigs), credentialCache, failedAuthenticationDelayMs, time); + else + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, + new TestSecurityConfig(saslServerConfigs), credentialCache, time); + } + + private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { + createSelector(securityProtocol, saslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + } + + private void createAndCheckClientAuthenticationFailure(SecurityProtocol securityProtocol, String node, + String mechanism, String expectedErrorMessage) throws Exception { + ChannelState finalState = createAndCheckClientConnectionFailure(securityProtocol, node); + Exception exception = finalState.exception(); + assertTrue("Invalid exception class " + exception.getClass(), exception instanceof SaslAuthenticationException); + if (expectedErrorMessage == null) + expectedErrorMessage = "Authentication failed due to invalid credentials with SASL mechanism " + mechanism; + assertEquals(expectedErrorMessage, exception.getMessage()); + } + + private ChannelState createAndCheckClientConnectionFailure(SecurityProtocol securityProtocol, String node) + throws Exception { + createClientConnection(securityProtocol, node); + ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, + ChannelState.State.AUTHENTICATION_FAILED, time); + selector.close(); + selector = null; + return finalState; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java index b8894f1370f..74058eb1cd5 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java @@ -92,6 +92,7 @@ import org.apache.kafka.common.security.authenticator.TestDigestLoginModule.DigestServerCallbackHandler; import org.apache.kafka.common.security.plain.internals.PlainServerCallbackHandler; +import org.apache.kafka.common.utils.Time; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -107,6 +108,7 @@ public class SaslAuthenticatorTest { private static final int BUFFER_SIZE = 4 * 1024; + private static Time time = Time.SYSTEM; private NioEchoServer server; private Selector selector; @@ -1308,7 +1310,7 @@ protected void enableKafkaSaslAuthenticateHeaders(boolean flag) { }; serverChannelBuilder.configure(saslServerConfigs); server = new NioEchoServer(listenerName, securityProtocol, new TestSecurityConfig(saslServerConfigs), - "localhost", serverChannelBuilder, credentialCache); + "localhost", serverChannelBuilder, credentialCache, time); server.start(); return server; } @@ -1347,7 +1349,7 @@ protected void saslAuthenticateVersion(short version) { } }; clientChannelBuilder.configure(saslClientConfigs); - this.selector = NetworkTestUtils.createSelector(clientChannelBuilder); + this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, time); InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); } @@ -1452,7 +1454,7 @@ private void createSelector(SecurityProtocol securityProtocol, Map<String, Objec String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfigs), null, saslMechanism, true); - this.selector = NetworkTestUtils.createSelector(channelBuilder); + this.selector = NetworkTestUtils.createSelector(channelBuilder, time); } private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception { @@ -1461,7 +1463,7 @@ private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, - new TestSecurityConfig(saslServerConfigs), credentialCache); + new TestSecurityConfig(saslServerConfigs), credentialCache, time); } private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { @@ -1490,8 +1492,7 @@ private void createAndCheckClientAuthenticationFailure(SecurityProtocol security private ChannelState createAndCheckClientConnectionFailure(SecurityProtocol securityProtocol, String node) throws Exception { createClientConnection(securityProtocol, node); - ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, - ChannelState.State.AUTHENTICATION_FAILED); + ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); selector.close(); selector = null; return finalState; diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 749c921ee02..34eb2662714 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -245,6 +245,7 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time requestChannel, connectionQuotas, config.connectionsMaxIdleMs, + config.failedAuthenticationDelayMs, listenerName, securityProtocol, config, @@ -502,6 +503,7 @@ private[kafka] class Processor(val id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, connectionsMaxIdleMs: Long, + failedAuthenticationDelayMs: Int, listenerName: ListenerName, securityProtocol: SecurityProtocol, config: KafkaConfig, @@ -562,6 +564,7 @@ private[kafka] class Processor(val id: Int, new KSelector( maxRequestSize, connectionsMaxIdleMs, + failedAuthenticationDelayMs, metrics, time, "socket-server", diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala index 334d496a6f4..753a5b91aa0 100755 --- a/core/src/main/scala/kafka/server/KafkaConfig.scala +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -77,6 +77,7 @@ object Defaults { val MaxConnectionsPerIpOverrides: String = "" val ConnectionsMaxIdleMs = 10 * 60 * 1000L val RequestTimeoutMs = 30000 + val FailedAuthenticationDelayMs = 100 /** ********* Log Configuration ***********/ val NumPartitions = 1 @@ -282,6 +283,7 @@ object KafkaConfig { val MaxConnectionsPerIpProp = "max.connections.per.ip" val MaxConnectionsPerIpOverridesProp = "max.connections.per.ip.overrides" val ConnectionsMaxIdleMsProp = "connections.max.idle.ms" + val FailedAuthenticationDelayMsProp = "connection.failed.authentication.delay.ms" /***************** rack configuration *************/ val RackProp = "broker.rack" /** ********* Log Configuration ***********/ @@ -537,6 +539,8 @@ object KafkaConfig { "configured using " + MaxConnectionsPerIpOverridesProp + " property" val MaxConnectionsPerIpOverridesDoc = "A comma-separated list of per-ip or hostname overrides to the default maximum number of connections. An example value is \"hostName:100,127.0.0.1:200\"" val ConnectionsMaxIdleMsDoc = "Idle connections timeout: the server socket processor threads close the connections that idle more than this" + val FailedAuthenticationDelayMsDoc = "Connection close delay on failed authentication: this is the time (in milliseconds) by which connection close will be delayed on authentication failure. " + + s"This must be configured to be less than $ConnectionsMaxIdleMsProp to prevent connection timeout." /************* Rack Configuration **************/ val RackDoc = "Rack of the broker. This will be used in rack aware replication assignment for fault tolerance. Examples: `RACK1`, `us-east-1d`" /** ********* Log Configuration ***********/ @@ -820,6 +824,7 @@ object KafkaConfig { .define(MaxConnectionsPerIpProp, INT, Defaults.MaxConnectionsPerIp, atLeast(0), MEDIUM, MaxConnectionsPerIpDoc) .define(MaxConnectionsPerIpOverridesProp, STRING, Defaults.MaxConnectionsPerIpOverrides, MEDIUM, MaxConnectionsPerIpOverridesDoc) .define(ConnectionsMaxIdleMsProp, LONG, Defaults.ConnectionsMaxIdleMs, MEDIUM, ConnectionsMaxIdleMsDoc) + .define(FailedAuthenticationDelayMsProp, INT, Defaults.FailedAuthenticationDelayMs, atLeast(0), LOW, FailedAuthenticationDelayMsDoc) /************ Rack Configuration ******************/ .define(RackProp, STRING, null, MEDIUM, RackDoc) @@ -1101,6 +1106,7 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO val maxConnectionsPerIpOverrides: Map[String, Int] = getMap(KafkaConfig.MaxConnectionsPerIpOverridesProp, getString(KafkaConfig.MaxConnectionsPerIpOverridesProp)).map { case (k, v) => (k, v.toInt)} val connectionsMaxIdleMs = getLong(KafkaConfig.ConnectionsMaxIdleMsProp) + val failedAuthenticationDelayMs = getInt(KafkaConfig.FailedAuthenticationDelayMsProp) /***************** rack configuration **************/ val rack = Option(getString(KafkaConfig.RackProp)) @@ -1397,5 +1403,11 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO val invalidAddresses = maxConnectionsPerIpOverrides.keys.filterNot(address => Utils.validHostPattern(address)) if (!invalidAddresses.isEmpty) throw new IllegalArgumentException(s"${KafkaConfig.MaxConnectionsPerIpOverridesProp} contains invalid addresses : ${invalidAddresses.mkString(",")}") + + if (connectionsMaxIdleMs >= 0) + require(failedAuthenticationDelayMs < connectionsMaxIdleMs, + s"${KafkaConfig.FailedAuthenticationDelayMsProp}=$failedAuthenticationDelayMs should always be less than" + + s" ${KafkaConfig.ConnectionsMaxIdleMsProp}=$connectionsMaxIdleMs to prevent failed" + + " authentication responses from timing out") } } diff --git a/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala index a9b4a60ac72..a630293bb66 100644 --- a/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala +++ b/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala @@ -24,10 +24,10 @@ import scala.collection.immutable.List class SaslGssapiSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest { override val clientPrincipal = JaasTestUtils.KafkaClientPrincipalUnqualifiedName override val kafkaPrincipal = JaasTestUtils.KafkaServerPrincipalUnqualifiedName - + override protected def kafkaClientSaslMechanism = "GSSAPI" override protected def kafkaServerSaslMechanisms = List("GSSAPI") - + // Configure brokers to require SSL client authentication in order to verify that SASL_SSL works correctly even if the // client doesn't have a keystore. We want to cover the scenario where a broker requires either SSL client // authentication or SASL authentication with SSL as the transport layer (but not both). diff --git a/core/src/test/scala/integration/kafka/api/SaslSetup.scala b/core/src/test/scala/integration/kafka/api/SaslSetup.scala index 391321227bd..81de1059068 100644 --- a/core/src/test/scala/integration/kafka/api/SaslSetup.scala +++ b/core/src/test/scala/integration/kafka/api/SaslSetup.scala @@ -138,8 +138,12 @@ trait SaslSetup { props } - def jaasClientLoginModule(clientSaslMechanism: String): String = - JaasTestUtils.clientLoginModule(clientSaslMechanism, clientKeytabFile) + def jaasClientLoginModule(clientSaslMechanism: String, serviceName: Option[String] = None): String = { + if (serviceName.isDefined) + JaasTestUtils.clientLoginModule(clientSaslMechanism, clientKeytabFile, serviceName.get) + else + JaasTestUtils.clientLoginModule(clientSaslMechanism, clientKeytabFile) + } def createScramCredentials(zkConnect: String, userName: String, password: String): Unit = { val credentials = ScramMechanism.values.map(m => s"${m.mechanismName}=[iterations=4096,password=$password]") diff --git a/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala b/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala index 74b2a152e23..cbe8462b3fd 100644 --- a/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala +++ b/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala @@ -19,19 +19,25 @@ package kafka.server import java.net.InetSocketAddress +import java.time.Duration import java.util.Properties import java.util.concurrent.{Executors, TimeUnit} import kafka.api.{Both, IntegrationTestHarness, SaslSetup} import kafka.utils.TestUtils import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.errors.SaslAuthenticationException import org.apache.kafka.common.network._ import org.apache.kafka.common.security.{JaasContext, TestSecurityConfig} import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.MockTime import org.junit.Assert._ import org.junit.{After, Before, Test} +import scala.collection.JavaConverters._ + class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup { override val serverCount = 1 override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT @@ -43,10 +49,17 @@ class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup { private val executor = Executors.newFixedThreadPool(numThreads) private val clientConfig: Properties = new Properties private var serverAddr: InetSocketAddress = _ + private val time = new MockTime(10) + val topic = "topic" + val part = 0 + val tp = new TopicPartition(topic, part) + private val failedAuthenticationDelayMs = 2000 @Before override def setUp() { startSasl(jaasSections(kafkaServerSaslMechanisms, Option(kafkaClientSaslMechanism), Both)) + serverConfig.put(KafkaConfig.SslClientAuthProp, "required") + serverConfig.put(KafkaConfig.FailedAuthenticationDelayMsProp, failedAuthenticationDelayMs.toString) super.setUp() serverAddr = new InetSocketAddress("localhost", servers.head.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SASL_PLAINTEXT))) @@ -55,6 +68,9 @@ class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup { clientConfig.put(SaslConfigs.SASL_MECHANISM, kafkaClientSaslMechanism) clientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, jaasClientLoginModule(kafkaClientSaslMechanism)) clientConfig.put(CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG, "5000") + + // create the test topic with all the brokers as replicas + createTopic(topic, 2, serverCount) } @After @@ -93,6 +109,31 @@ class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup { verifyNonRetriableAuthenticationFailure() } + /** + * Test that when client fails to verify authenticity of the server, the resulting failed authentication exception + * is thrown immediately, and is not affected by <code>connection.failed.authentication.delay.ms</code>. + */ + @Test + def testServerAuthenticationFailure(): Unit = { + // Setup client with a non-existent service principal, so that server authentication fails on the client + val clientLoginContext = jaasClientLoginModule(kafkaClientSaslMechanism, Some("another-kafka-service")) + val configOverrides = new Properties() + configOverrides.setProperty(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + val consumer = createConsumer(configOverrides = configOverrides) + consumer.assign(List(tp).asJava) + + val startMs = System.currentTimeMillis() + try { + consumer.poll(Duration.ofMillis(50)) + fail() + } catch { + case _: SaslAuthenticationException => + } + val endMs = System.currentTimeMillis() + require(endMs - startMs < failedAuthenticationDelayMs, "Failed authentication must not be delayed on the client") + consumer.close() + } + /** * Verifies that any exceptions during authentication with the current `clientConfig` are * notified with disconnect state `AUTHENTICATE` (and not `AUTHENTICATION_FAILED`). This @@ -148,6 +189,6 @@ class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup { private def createSelector(): Selector = { val channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfig), null, kafkaClientSaslMechanism, true) - NetworkTestUtils.createSelector(channelBuilder) + NetworkTestUtils.createSelector(channelBuilder, time) } } diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index da611490740..b5983377b77 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -307,13 +307,14 @@ class SocketServerTest extends JUnitSuite { override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, - config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new LogContext()) { - override protected[network] def connectionId(socket: Socket): String = overrideConnectionId - override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { - val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) - selector = testableSelector - testableSelector - } + config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, + credentialProvider, memoryPool, new LogContext()) { + override protected[network] def connectionId(socket: Socket): String = overrideConnectionId + override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { + val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) + selector = testableSelector + testableSelector + } } } } @@ -652,7 +653,8 @@ class SocketServerTest extends JUnitSuite { override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, - config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, MemoryPool.NONE, new LogContext()) { + config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, + credentialProvider, MemoryPool.NONE, new LogContext()) { override protected[network] def sendResponse(response: RequestChannel.Response, responseSend: Send) { conn.close() super.sendResponse(response, responseSend) @@ -697,7 +699,8 @@ class SocketServerTest extends JUnitSuite { override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, - config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new LogContext()) { + config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, + credentialProvider, memoryPool, new LogContext()) { override protected[network] def connectionId(socket: Socket): String = overrideConnectionId override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) @@ -743,7 +746,7 @@ class SocketServerTest extends JUnitSuite { @Test def testBrokerSendAfterChannelClosedUpdatesRequestMetrics() { val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) - props.setProperty(KafkaConfig.ConnectionsMaxIdleMsProp, "100") + props.setProperty(KafkaConfig.ConnectionsMaxIdleMsProp, "110") val serverMetrics = new Metrics var conn: Socket = null val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider) @@ -1100,10 +1103,8 @@ class SocketServerTest extends JUnitSuite { override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { - new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, - config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new - LogContext()) { - + new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, config.connectionsMaxIdleMs, + config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new LogContext()) { override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) assertEquals(None, selector) @@ -1149,7 +1150,7 @@ class SocketServerTest extends JUnitSuite { } class TestableSelector(config: KafkaConfig, channelBuilder: ChannelBuilder, time: Time, metrics: Metrics) - extends Selector(config.socketRequestMaxBytes, config.connectionsMaxIdleMs, + extends Selector(config.socketRequestMaxBytes, config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, metrics, time, "socket-server", new HashMap, false, true, channelBuilder, MemoryPool.NONE, new LogContext()) { val failures = mutable.Map[SelectorOperation, Exception]() diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala index 927dd1c203b..b7a8951ecc0 100755 --- a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala @@ -587,6 +587,7 @@ class KafkaConfigTest { case KafkaConfig.MaxConnectionsPerIpOverridesProp => assertPropertyInvalid(getBaseProperties(), name, "127.0.0.1:not_a_number") case KafkaConfig.ConnectionsMaxIdleMsProp => assertPropertyInvalid(getBaseProperties(), name, "not_a_number") + case KafkaConfig.FailedAuthenticationDelayMsProp => assertPropertyInvalid(getBaseProperties(), name, "not_a_number", "-1") case KafkaConfig.NumPartitionsProp => assertPropertyInvalid(getBaseProperties(), name, "not_a_number", "0") case KafkaConfig.LogDirsProp => // ignore string diff --git a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala index e8d9c30a383..1870a4996fe 100644 --- a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala +++ b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala @@ -170,8 +170,8 @@ object JaasTestUtils { } // Returns the dynamic configuration, using credentials for user #1 - def clientLoginModule(mechanism: String, keytabLocation: Option[File]): String = - kafkaClientModule(mechanism, keytabLocation, KafkaClientPrincipal, KafkaPlainUser, KafkaPlainPassword, KafkaScramUser, KafkaScramPassword, KafkaOAuthBearerUser).toString + def clientLoginModule(mechanism: String, keytabLocation: Option[File], serviceName: String = serviceName): String = + kafkaClientModule(mechanism, keytabLocation, KafkaClientPrincipal, KafkaPlainUser, KafkaPlainPassword, KafkaScramUser, KafkaScramPassword, KafkaOAuthBearerUser, serviceName).toString def tokenClientLoginModule(tokenId: String, password: String): String = { ScramLoginModule( @@ -223,10 +223,11 @@ object JaasTestUtils { } // consider refactoring if more mechanisms are added - private def kafkaClientModule(mechanism: String, + private def kafkaClientModule(mechanism: String, keytabLocation: Option[File], clientPrincipal: String, plainUser: String, plainPassword: String, - scramUser: String, scramPassword: String, oauthBearerUser: String): JaasModule = { + scramUser: String, scramPassword: String, + oauthBearerUser: String, serviceName: String = serviceName): JaasModule = { mechanism match { case "GSSAPI" => Krb5LoginModule( ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Add mechanism to delay response to failed client authentication > --------------------------------------------------------------- > > Key: KAFKA-6950 > URL: https://issues.apache.org/jira/browse/KAFKA-6950 > Project: Kafka > Issue Type: Improvement > Components: core > Reporter: Dhruvil Shah > Assignee: Dhruvil Shah > Priority: Major > Fix For: 2.1.0 > > > This Jira is for tracking the implementation for > [KIP-306|https://cwiki.apache.org/confluence/display/KAFKA/KIP-306%3A+Configuration+for+Delaying+Response+to+Failed+Client+Authentication]. -- This message was sent by Atlassian JIRA (v7.6.3#76005)