This is an automated email from the ASF dual-hosted git repository. rsivaram pushed a commit to branch 2.5 in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/2.5 by this push: new 9ee465f KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705) 9ee465f is described below commit 9ee465f8ccc5933ea10d6cb211c5045010840fb5 Author: Rajini Sivaram <rajinisiva...@googlemail.com> AuthorDate: Fri May 29 09:32:57 2020 +0100 KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705) Reviewers: Ismael Juma <ism...@juma.me.uk>, Chia-Ping Tsai <chia7...@gmail.com> --- .../apache/kafka/common/network/KafkaChannel.java | 5 +- .../org/apache/kafka/common/network/Selector.java | 35 +++++++-- .../apache/kafka/common/network/SelectorTest.java | 85 +++++++++++++++++++++ .../main/scala/kafka/network/SocketServer.scala | 4 +- .../unit/kafka/network/SocketServerTest.scala | 87 +++++++++++++++++----- 5 files changed, 188 insertions(+), 28 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java index 4e4edd4..0ed9ee0 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java +++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java @@ -28,7 +28,6 @@ import java.net.Socket; import java.net.SocketAddress; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; -import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; @@ -471,12 +470,12 @@ public class KafkaChannel implements AutoCloseable { return false; } KafkaChannel that = (KafkaChannel) o; - return Objects.equals(id, that.id); + return id.equals(that.id); } @Override public int hashCode() { - return Objects.hash(id); + return id.hashCode(); } @Override diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 
cb91cad..06f7048 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -107,7 +107,7 @@ public class Selector implements Selectable, AutoCloseable { private final Set<KafkaChannel> explicitlyMutedChannels; private boolean outOfMemory; private final List<Send> completedSends; - private final LinkedHashMap<KafkaChannel, NetworkReceive> completedReceives; + private final LinkedHashMap<String, NetworkReceive> completedReceives; private final Set<SelectionKey> immediatelyConnectedKeys; private final Map<String, KafkaChannel> closingChannels; private Set<SelectionKey> keysWithBufferedRead; @@ -804,7 +804,33 @@ public class Selector implements Selectable, AutoCloseable { } /** - * Clear the results from the prior poll + * Clears completed receives. This is used by SocketServer to remove references to + * receive buffers after processing completed receives, without waiting for the next + * poll(). + */ + public void clearCompletedReceives() { + this.completedReceives.clear(); + } + + /** + * Clears completed sends. This is used by SocketServer to remove references to + * send buffers after processing completed sends, without waiting for the next + * poll(). + */ + public void clearCompletedSends() { + this.completedSends.clear(); + } + + /** + * Clears all the results from the previous poll. This is invoked by Selector at the start of + * a poll() when all the results from the previous poll are expected to have been handled. + * <p> + * SocketServer uses {@link #clearCompletedSends()} and {@link #clearCompletedReceives()} to + * clear `completedSends` and `completedReceives` as soon as they are processed to avoid + * holding onto large request/response buffers from multiple connections longer than necessary.
+ * Clients rely on Selector invoking {@link #clear()} at the start of each poll() since memory usage + * is less critical and clearing once-per-poll provides the flexibility to process these results in + * any order before the next poll. */ private void clear() { this.completedSends.clear(); @@ -935,7 +961,6 @@ public class Selector implements Selectable, AutoCloseable { } this.sensors.connectionClosed.record(); - this.completedReceives.remove(channel); this.explicitlyMutedChannels.remove(channel); if (notifyDisconnect) this.disconnected.put(channel.id(), channel.state()); @@ -1015,7 +1040,7 @@ public class Selector implements Selectable, AutoCloseable { * Check if given channel has a completed receive */ private boolean hasCompletedReceive(KafkaChannel channel) { - return completedReceives.containsKey(channel); + return completedReceives.containsKey(channel.id()); } /** @@ -1025,7 +1050,7 @@ public class Selector implements Selectable, AutoCloseable { if (hasCompletedReceive(channel)) throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id()); - this.completedReceives.put(channel, networkReceive); + this.completedReceives.put(channel.id(), networkReceive); sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs); } diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java index 57b0153..ac773ee 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java @@ -48,6 +48,7 @@ import java.nio.channels.SocketChannel; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -341,6 +342,36 @@ public class SelectorTest { assertEquals("", blockingRequest(node, "")); } + @Test 
+ public void testClearCompletedSendsAndReceives() throws Exception { + int bufferSize = 1024; + String node = "0"; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + connect(node, addr); + String request = TestUtils.randomString(bufferSize); + selector.send(createSend(node, request)); + boolean sent = false; + boolean received = false; + while (!sent || !received) { + selector.poll(1000L); + assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size()); + if (!selector.completedSends().isEmpty()) { + assertEquals(1, selector.completedSends().size()); + selector.clearCompletedSends(); + assertEquals(0, selector.completedSends().size()); + sent = true; + } + + if (!selector.completedReceives().isEmpty()) { + assertEquals(1, selector.completedReceives().size()); + assertEquals(request, asString(selector.completedReceives().iterator().next())); + selector.clearCompletedReceives(); + assertEquals(0, selector.completedReceives().size()); + received = true; + } + } + } + @Test(expected = IllegalStateException.class) public void testExistingConnectionId() throws IOException { blockingConnect("0"); @@ -904,6 +935,60 @@ public class SelectorTest { assertEquals(asList(send), selector.completedSends()); } + /** + * Ensure that no errors are thrown if channels are closed while processing multiple completed receives + */ + @Test + public void testChannelCloseWhileProcessingReceives() throws Exception { + int numChannels = 4; + Map<String, KafkaChannel> channels = TestUtils.fieldValue(selector, Selector.class, "channels"); + Set<SelectionKey> selectionKeys = new HashSet<>(); + for (int i = 0; i < numChannels; i++) { + String id = String.valueOf(i); + KafkaChannel channel = mock(KafkaChannel.class); + channels.put(id, channel); + when(channel.id()).thenReturn(id); + when(channel.state()).thenReturn(ChannelState.READY); + when(channel.isConnected()).thenReturn(true); + when(channel.ready()).thenReturn(true); + 
when(channel.read()).thenReturn(1L); + + SelectionKey selectionKey = mock(SelectionKey.class); + when(channel.selectionKey()).thenReturn(selectionKey); + when(selectionKey.isValid()).thenReturn(true); + when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ); + selectionKey.attach(channel); + selectionKeys.add(selectionKey); + + NetworkReceive receive = mock(NetworkReceive.class); + when(receive.source()).thenReturn(id); + when(receive.size()).thenReturn(10); + when(receive.bytesRead()).thenReturn(1); + when(receive.payload()).thenReturn(ByteBuffer.allocate(10)); + when(channel.maybeCompleteReceive()).thenReturn(receive); + } + + selector.pollSelectionKeys(selectionKeys, false, System.nanoTime()); + assertEquals(numChannels, selector.completedReceives().size()); + Set<KafkaChannel> closed = new HashSet<>(); + Set<KafkaChannel> notClosed = new HashSet<>(); + for (NetworkReceive receive : selector.completedReceives()) { + KafkaChannel channel = selector.channel(receive.source()); + assertNotNull(channel); + if (closed.size() < 2) { + selector.close(channel.id()); + closed.add(channel); + } else + notClosed.add(channel); + } + assertEquals(notClosed, new HashSet<>(selector.channels())); + closed.forEach(channel -> assertNull(selector.channel(channel.id()))); + + selector.poll(0); + assertEquals(0, selector.completedReceives().size()); + } + + private String blockingRequest(String node, String s) throws IOException { selector.send(createSend(node, s)); selector.poll(1000L); diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 35d9d7c..fb69ab9 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -779,7 +779,7 @@ private[kafka] class Processor(val id: Int, } } - private def processException(errorMessage: String, throwable: Throwable): Unit = { + private[network] def processException(errorMessage: String, throwable: 
Throwable): Unit = { throwable match { case e: ControlThrowable => throw e case e => error(errorMessage, e) @@ -915,6 +915,7 @@ private[kafka] class Processor(val id: Int, processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e) } } + selector.clearCompletedReceives() } private def processCompletedSends(): Unit = { @@ -938,6 +939,7 @@ private[kafka] class Processor(val id: Int, s"Exception while processing completed send to ${send.destination}", e) } } + selector.clearCompletedSends() } private def updateRequestMetrics(response: RequestChannel.Response): Unit = { diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 91fce5b..b40c763 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1570,6 +1570,8 @@ class SocketServerTest { testableSelector.waitForOperations(SelectorOperation.Poll, 1) testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1) + assertEquals(1, testableServer.uncaughtExceptions) + testableServer.uncaughtExceptions = 0 }) } @@ -1648,6 +1650,7 @@ class SocketServerTest { testWithServer(testableServer) } finally { shutdownServerAndMetrics(testableServer) + assertEquals(0, testableServer.uncaughtExceptions) } } @@ -1702,6 +1705,7 @@ class SocketServerTest { new Metrics, time, credentialProvider) { @volatile var selector: Option[TestableSelector] = None + @volatile var uncaughtExceptions = 0 override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { @@ -1714,6 +1718,12 @@ class SocketServerTest { selector = Some(testableSelector) testableSelector } + + override private[network] def processException(errorMessage: String, throwable: Throwable): Unit = { + if 
(errorMessage.contains("uncaught exception")) + uncaughtExceptions += 1 + super.processException(errorMessage, throwable) + } } } @@ -1766,12 +1776,19 @@ class SocketServerTest { // Enable data from `Selector.poll()` to be deferred to a subsequent poll() until // the number of elements of that type reaches `minPerPoll`. This enables tests to verify // that failed processing doesn't impact subsequent processing within the same iteration. - class PollData[T] { + abstract class PollData[T] { var minPerPoll = 1 val deferredValues = mutable.Buffer[T]() - val currentPollValues = mutable.Buffer[T]() - def update(newValues: mutable.Buffer[T]): Unit = { - if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) { + + /** + * Process new results and return the results for the current poll if at least + * `minPerPoll` results are available including any deferred results. Otherwise + * add the provided values to the deferred set and return an empty buffer. This allows + * tests to process `minPerPoll` elements as the results of a single poll iteration. + */ + protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = { + val currentPollValues = mutable.Buffer[T]() + if (deferredValues.size + newValues.size >= minPerPoll) { if (deferredValues.nonEmpty) { currentPollValues ++= deferredValues deferredValues.clear() @@ -1779,14 +1796,49 @@ class SocketServerTest { currentPollValues ++= newValues } else deferredValues ++= newValues + + currentPollValues } - def reset(): Unit = { - currentPollValues.clear() + + /** + * Process results from the appropriate buffer in Selector and update the buffer to either + * defer and return nothing or return all results including previously deferred values. 
+ */ + def updateResults(): Unit + } + + class CompletedReceivesPollData(selector: TestableSelector) extends PollData[NetworkReceive] { + val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(selector, classOf[Selector], "completedReceives") + + override def updateResults(): Unit = { + val currentReceives = update(selector.completedReceives.asScala.toBuffer) + completedReceivesMap.clear() + currentReceives.foreach { receive => + val channelOpt = Option(selector.channel(receive.source)).orElse(Option(selector.closingChannel(receive.source))) + channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) } + } } } - val cachedCompletedReceives = new PollData[NetworkReceive]() - val cachedCompletedSends = new PollData[Send]() - val cachedDisconnected = new PollData[(String, ChannelState)]() + + class CompletedSendsPollData(selector: TestableSelector) extends PollData[Send] { + override def updateResults(): Unit = { + val currentSends = update(selector.completedSends.asScala) + selector.completedSends.clear() + currentSends.foreach { selector.completedSends.add } + } + } + + class DisconnectedPollData(selector: TestableSelector) extends PollData[(String, ChannelState)] { + override def updateResults(): Unit = { + val currentDisconnected = update(selector.disconnected.asScala.toBuffer) + selector.disconnected.clear() + currentDisconnected.foreach { case (channelId, state) => selector.disconnected.put(channelId, state) } + } + } + + val cachedCompletedReceives = new CompletedReceivesPollData(this) + val cachedCompletedSends = new CompletedSendsPollData(this) + val cachedDisconnected = new DisconnectedPollData(this) val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected) val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]() @volatile var minWakeupCount = 0 @@ -1833,20 +1885,23 @@ class SocketServerTest { override def poll(timeout: Long): Unit = { try { + assertEquals(0, 
super.completedReceives().size) + assertEquals(0, super.completedSends().size) + pollCallback.apply() while (!pendingClosingChannels.isEmpty) { makeClosing(pendingClosingChannels.poll()) } - allCachedPollData.foreach(_.reset) runOp(SelectorOperation.Poll, None) { super.poll(pollTimeoutOverride.getOrElse(timeout)) } } finally { super.channels.asScala.foreach(allChannels += _.id) allDisconnectedChannels ++= super.disconnected.asScala.keys - cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer) - cachedCompletedSends.update(super.completedSends.asScala) - cachedDisconnected.update(super.disconnected.asScala.toBuffer) + + cachedCompletedReceives.updateResults() + cachedCompletedSends.updateResults() + cachedDisconnected.updateResults() } } @@ -1871,12 +1926,6 @@ class SocketServerTest { } } - override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava - - override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava - - override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava - override def close(id: String): Unit = { runOp(SelectorOperation.Close, Some(id)) { super.close(id)