This is an automated email from the ASF dual-hosted git repository. penghui pushed a commit to branch branch-2.8 in repository https://gitbox.apache.org/repos/asf/pulsar.git
commit 9ccfe96bae46fa4cc17130c0a2038b3b57c0cb59 Author: Lari Hotari <[email protected]> AuthorDate: Tue Sep 21 12:39:04 2021 +0300 [Client] Fix ConcurrentModificationException in sendAsync (#11884) (cherry picked from commit a1c10288f2fb011443b6edb98def1841a310157d) --- .../pulsar/client/impl/PulsarTestClient.java | 11 ++- .../apache/pulsar/client/impl/ProducerImpl.java | 78 ++++++++++++++++++-- .../pulsar/client/impl/OpSendMsgQueueTest.java | 85 ++++++++++++++++++++++ 3 files changed, 161 insertions(+), 13 deletions(-) diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/PulsarTestClient.java b/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/PulsarTestClient.java index 8fede95..eebcf5b 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/PulsarTestClient.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/PulsarTestClient.java @@ -22,8 +22,6 @@ import static org.testng.Assert.assertEquals; import io.netty.channel.EventLoopGroup; import io.netty.util.concurrent.DefaultThreadFactory; import java.io.IOException; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -139,14 +137,15 @@ public class PulsarTestClient extends PulsarClientImpl { return new ProducerImpl<T>(this, topic, conf, producerCreatedFuture, partitionIndex, schema, interceptors) { @Override - protected BlockingQueue<OpSendMsg> createPendingMessagesQueue() { - return new ArrayBlockingQueue<OpSendMsg>(conf.getMaxPendingMessages()) { + protected OpSendMsgQueue createPendingMessagesQueue() { + return new OpSendMsgQueue() { @Override - public void put(OpSendMsg opSendMsg) throws InterruptedException { - super.put(opSendMsg); + public boolean add(OpSendMsg opSendMsg) { + boolean added = super.add(opSendMsg); if (pendingMessageCallback != null) { pendingMessageCallback.accept(opSendMsg); } + return added; } }; } diff --git a/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ProducerImpl.java b/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ProducerImpl.java index 84062fb..8ba2694 100644 --- a/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ProducerImpl.java +++ b/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ProducerImpl.java @@ -29,9 +29,7 @@ import static org.apache.pulsar.client.impl.ProducerBase.MultiSchemaMode.Auto; import static org.apache.pulsar.client.impl.ProducerBase.MultiSchemaMode.Enabled; import static org.apache.pulsar.common.protocol.Commands.hasChecksum; import static org.apache.pulsar.common.protocol.Commands.readChecksum; - import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Queues; import io.netty.buffer.ByteBuf; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; @@ -39,10 +37,10 @@ import io.netty.util.ReferenceCountUtil; import io.netty.util.Timeout; import io.netty.util.TimerTask; import io.netty.util.concurrent.ScheduledFuture; - import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -50,12 +48,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Queue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.Consumer; import org.apache.commons.lang3.StringUtils; import org.apache.pulsar.client.api.BatcherBuilder; import org.apache.pulsar.client.api.CompressionType; @@ -99,7 +97,7 @@ public class ProducerImpl<T> extends ProducerBase<T> implements TimerTask, Conne // Variable is used through the atomic updater private volatile long msgIdGenerator; - private final Queue<OpSendMsg> pendingMessages; + private final OpSendMsgQueue pendingMessages; private final Optional<Semaphore> semaphore; private volatile Timeout sendTimeout = null; private long createProducerTimeout; @@ -251,8 +249,8 @@ public class ProducerImpl<T> extends ProducerBase<T> implements TimerTask, Conne grabCnx(); } - protected Queue<OpSendMsg> createPendingMessagesQueue() { - return new ArrayDeque<>(); + protected OpSendMsgQueue createPendingMessagesQueue() { + return new OpSendMsgQueue(); } public ConnectionHandler getConnectionHandler() { @@ -1281,6 +1279,72 @@ public class ProducerImpl<T> extends ProducerBase<T> implements TimerTask, Conne }; } + /** + * Queue implementation that is used as the pending messages queue. + * + * This implementation postpones adding of new OpSendMsg entries that happen + * while the forEach call is in progress. This is needed for preventing + * ConcurrentModificationExceptions that would occur when the forEach action + * calls the add method via a callback in user code. + * + * This queue is not thread safe. + */ + protected static class OpSendMsgQueue implements Iterable<OpSendMsg> { + private final Queue<OpSendMsg> delegate = new ArrayDeque<>(); + private int forEachDepth = 0; + private List<OpSendMsg> postponedOpSendMgs; + + @Override + public void forEach(Consumer<? super OpSendMsg> action) { + try { + // track any forEach call that is in progress in the current call stack + // so that adding a new item while iterating doesn't cause ConcurrentModificationException + forEachDepth++; + delegate.forEach(action); + } finally { + forEachDepth--; + // if this is the top-most forEach call and there are postponed items, add them + if (forEachDepth == 0 && postponedOpSendMgs != null && !postponedOpSendMgs.isEmpty()) { + delegate.addAll(postponedOpSendMgs); + postponedOpSendMgs.clear(); + } + } + } + + public boolean add(OpSendMsg o) { + // postpone adding to the queue while forEach iteration is in progress + if (forEachDepth > 0) { + if (postponedOpSendMgs == null) { + postponedOpSendMgs = new ArrayList<>(); + } + return postponedOpSendMgs.add(o); + } else { + return delegate.add(o); + } + } + + public void clear() { + delegate.clear(); + } + + public void remove() { + delegate.remove(); + } + + public OpSendMsg peek() { + return delegate.peek(); + } + + public int size() { + return delegate.size(); + } + + @Override + public Iterator<OpSendMsg> iterator() { + return delegate.iterator(); + } + } + @Override public void connectionOpened(final ClientCnx cnx) { // we set the cnx reference before registering the producer on the cnx, so if the cnx breaks before creating the diff --git a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/OpSendMsgQueueTest.java b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/OpSendMsgQueueTest.java new file mode 100644 index 0000000..bf45e87 --- /dev/null +++ b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/OpSendMsgQueueTest.java @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.client.impl; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import com.google.common.collect.Lists; +import java.util.Arrays; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +/** + * Contains unit tests for ProducerImpl.OpSendMsgQueue inner class. + */ +public class OpSendMsgQueueTest { + MessageImpl<?> message; + + @BeforeClass + public void createMockMessage() { + message = mock(MessageImpl.class); + when(message.getUncompressedSize()).thenReturn(0); + } + + private ProducerImpl.OpSendMsg createDummyOpSendMsg() { + return ProducerImpl.OpSendMsg.create(message, null, 0L, null); + } + + @Test + public void shouldPostponeAddsToPreventConcurrentModificationException() { + // given + ProducerImpl.OpSendMsgQueue queue = new ProducerImpl.OpSendMsgQueue(); + ProducerImpl.OpSendMsg opSendMsg = createDummyOpSendMsg(); + ProducerImpl.OpSendMsg opSendMsg2 = createDummyOpSendMsg(); + queue.add(opSendMsg); + + // when + queue.forEach(item -> { + queue.add(opSendMsg2); + }); + + // then + assertEquals(Lists.newArrayList(queue), Arrays.asList(opSendMsg, opSendMsg2)); + } + + @Test + public void shouldPostponeAddsAlsoInRecursiveCalls() { + // given + ProducerImpl.OpSendMsgQueue queue = new ProducerImpl.OpSendMsgQueue(); + ProducerImpl.OpSendMsg opSendMsg = createDummyOpSendMsg(); + ProducerImpl.OpSendMsg opSendMsg2 = createDummyOpSendMsg(); + ProducerImpl.OpSendMsg opSendMsg3 = createDummyOpSendMsg(); + ProducerImpl.OpSendMsg opSendMsg4 = createDummyOpSendMsg(); + queue.add(opSendMsg); + + // when + queue.forEach(item -> { + queue.add(opSendMsg2); + // recursive forEach + queue.forEach(item2 -> { + queue.add(opSendMsg3); + }); + queue.add(opSendMsg4); + }); + + // then + assertEquals(Lists.newArrayList(queue), Arrays.asList(opSendMsg, opSendMsg2, opSendMsg3, opSendMsg4)); + } +}
