This is an automated email from the ASF dual-hosted git repository.
mmerli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar.git
The following commit(s) were added to refs/heads/master by this push:
new c39f9f82b42 [fix][ml] Fix race conditions in RangeCache (#22789)
c39f9f82b42 is described below
commit c39f9f82b425c66c899f818583714c9c98d3e213
Author: Lari Hotari <[email protected]>
AuthorDate: Fri May 31 03:25:52 2024 +0300
[fix][ml] Fix race conditions in RangeCache (#22789)
---
.../apache/bookkeeper/mledger/impl/EntryImpl.java | 7 +-
.../apache/bookkeeper/mledger/util/RangeCache.java | 278 ++++++++++++++++-----
.../bookkeeper/mledger/util/RangeCacheTest.java | 63 +++--
3 files changed, 254 insertions(+), 94 deletions(-)
diff --git
a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
index 80397931357..48a79a4ac52 100644
---
a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
+++
b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
@@ -27,9 +27,10 @@ import io.netty.util.ReferenceCounted;
import org.apache.bookkeeper.client.api.LedgerEntry;
import org.apache.bookkeeper.mledger.Entry;
import org.apache.bookkeeper.mledger.util.AbstractCASReferenceCounted;
+import org.apache.bookkeeper.mledger.util.RangeCache;
public final class EntryImpl extends AbstractCASReferenceCounted implements
Entry, Comparable<EntryImpl>,
- ReferenceCounted {
+ RangeCache.ValueWithKeyValidation<PositionImpl> {
private static final Recycler<EntryImpl> RECYCLER = new
Recycler<EntryImpl>() {
@Override
@@ -205,4 +206,8 @@ public final class EntryImpl extends
AbstractCASReferenceCounted implements Entr
recyclerHandle.recycle(this);
}
+ @Override
+ public boolean matchesKey(PositionImpl key) {
+ return key.compareTo(ledgerId, entryId) == 0;
+ }
}
diff --git
a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
index d34857e5e51..46d03bea1b5 100644
---
a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
+++
b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
@@ -19,31 +19,134 @@
package org.apache.bookkeeper.mledger.util;
import static com.google.common.base.Preconditions.checkArgument;
+import com.google.common.base.Predicate;
+import io.netty.util.IllegalReferenceCountException;
+import io.netty.util.Recycler;
+import io.netty.util.Recycler.Handle;
import io.netty.util.ReferenceCounted;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicLong;
+import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation;
import org.apache.commons.lang3.tuple.Pair;
/**
* Special type of cache where get() and delete() operations can be done over
a range of keys.
+ * The implementation avoids locks and synchronization and relies on
ConcurrentSkipListMap for storing the entries.
+ * Since there are no locks, there is a need to have a way to ensure that a
single entry in the cache is removed
+ * exactly once. Removing an entry multiple times would result in the entries
of the cache getting released too early,
+ * while they could still be in use.
*
* @param <Key>
* Cache key. Needs to be Comparable
* @param <Value>
* Cache value
*/
-public class RangeCache<Key extends Comparable<Key>, Value extends
ReferenceCounted> {
+public class RangeCache<Key extends Comparable<Key>, Value extends
ValueWithKeyValidation<Key>> {
+ public interface ValueWithKeyValidation<T> extends ReferenceCounted {
+ boolean matchesKey(T key);
+ }
+
// Map from key to nodes inside the linked list
- private final ConcurrentNavigableMap<Key, Value> entries;
+ private final ConcurrentNavigableMap<Key, IdentityWrapper<Key, Value>>
entries;
private AtomicLong size; // Total size of values stored in cache
private final Weighter<Value> weighter; // Weighter object used to extract
the size from values
private final TimestampExtractor<Value> timestampExtractor; // Extract the
timestamp associated with a value
+ /**
+ * Wrapper around the value to store in Map. This is needed to ensure that
a specific instance can be removed from
+ * the map by calling the {@link Map#remove(Object, Object)} method.
Certain race conditions could result in the
+ * wrong value being removed from the map. The instances of this class are
recycled to avoid creating new objects.
+ */
+ private static class IdentityWrapper<K, V> {
+ private final Handle<IdentityWrapper> recyclerHandle;
+ private static final Recycler<IdentityWrapper> RECYCLER = new
Recycler<IdentityWrapper>() {
+ @Override
+ protected IdentityWrapper newObject(Handle<IdentityWrapper>
recyclerHandle) {
+ return new IdentityWrapper(recyclerHandle);
+ }
+ };
+ private K key;
+ private V value;
+
+ private IdentityWrapper(Handle<IdentityWrapper> recyclerHandle) {
+ this.recyclerHandle = recyclerHandle;
+ }
+
+ static <K, V> IdentityWrapper<K, V> create(K key, V value) {
+ IdentityWrapper<K, V> identityWrapper = RECYCLER.get();
+ identityWrapper.key = key;
+ identityWrapper.value = value;
+ return identityWrapper;
+ }
+
+ K getKey() {
+ return key;
+ }
+
+ V getValue() {
+ return value;
+ }
+
+ void recycle() {
+ value = null;
+ recyclerHandle.recycle(this);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ // only match exact identity of the wrapper instance
+ return this == o;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(key);
+ }
+ }
+
+ /**
+ * Mutable object to store the number of entries and the total size
removed from the cache. The instances
+ * are recycled to avoid creating new instances.
+ */
+ private static class RemovalCounters {
+ private final Handle<RemovalCounters> recyclerHandle;
+ private static final Recycler<RemovalCounters> RECYCLER = new
Recycler<RemovalCounters>() {
+ @Override
+ protected RemovalCounters newObject(Handle<RemovalCounters>
recyclerHandle) {
+ return new RemovalCounters(recyclerHandle);
+ }
+ };
+ int removedEntries;
+ long removedSize;
+ private RemovalCounters(Handle<RemovalCounters> recyclerHandle) {
+ this.recyclerHandle = recyclerHandle;
+ }
+
+ static <T> RemovalCounters create() {
+ RemovalCounters results = RECYCLER.get();
+ results.removedEntries = 0;
+ results.removedSize = 0;
+ return results;
+ }
+
+ void recycle() {
+ removedEntries = 0;
+ removedSize = 0;
+ recyclerHandle.recycle(this);
+ }
+
+ public void entryRemoved(long size) {
+ removedSize += size;
+ removedEntries++;
+ }
+ }
+
/**
* Construct a new RangeLruCache with default Weighter.
*/
@@ -68,18 +171,23 @@ public class RangeCache<Key extends Comparable<Key>, Value
extends ReferenceCoun
* Insert.
*
* @param key
- * @param value
- * ref counted value with at least 1 ref to pass on the cache
+ * @param value ref counted value with at least 1 ref to pass on the cache
* @return whether the entry was inserted in the cache
*/
public boolean put(Key key, Value value) {
// retain value so that it's not released before we put it in the
cache and calculate the weight
value.retain();
try {
- if (entries.putIfAbsent(key, value) == null) {
+ if (!value.matchesKey(key)) {
+ throw new IllegalArgumentException("Value '" + value + "' does
not match key '" + key + "'");
+ }
+ IdentityWrapper<Key, Value> newWrapper =
IdentityWrapper.create(key, value);
+ if (entries.putIfAbsent(key, newWrapper) == null) {
size.addAndGet(weighter.getSize(value));
return true;
} else {
+ // recycle the new wrapper as it was not used
+ newWrapper.recycle();
return false;
}
} finally {
@@ -91,16 +199,37 @@ public class RangeCache<Key extends Comparable<Key>, Value
extends ReferenceCoun
return key != null ? entries.containsKey(key) : true;
}
+ /**
+ * Get the value associated with the key and increment its reference
count.
+ * The caller is responsible for releasing the reference.
+ */
public Value get(Key key) {
- Value value = entries.get(key);
- if (value == null) {
+ return getValue(key, entries.get(key));
+ }
+
+ private Value getValue(Key key, IdentityWrapper<Key, Value> valueWrapper)
{
+ if (valueWrapper == null) {
return null;
} else {
+ if (valueWrapper.getKey() != key) {
+ // the wrapper has been recycled and contains another key
+ return null;
+ }
+ Value value = valueWrapper.getValue();
try {
value.retain();
+ } catch (IllegalReferenceCountException e) {
+ // Value was already deallocated
+ return null;
+ }
+ // check that the value matches the key and that there are at least
2 references to it since
+ // the cache should be holding one reference and a new reference
was just added in this method
+ if (value.refCnt() > 1 && value.matchesKey(key)) {
return value;
- } catch (Throwable t) {
- // Value was already destroyed between get() and retain()
+ } else {
+ // Value or IdentityWrapper was recycled and already contains
another value
+ // release the reference added in this method
+ value.release();
return null;
}
}
@@ -118,12 +247,10 @@ public class RangeCache<Key extends Comparable<Key>,
Value extends ReferenceCoun
List<Value> values = new ArrayList();
// Return the values of the entries found in cache
- for (Value value : entries.subMap(first, true, last, true).values()) {
- try {
- value.retain();
+ for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry :
entries.subMap(first, true, last, true).entrySet()) {
+ Value value = getValue(entry.getKey(), entry.getValue());
+ if (value != null) {
values.add(value);
- } catch (Throwable t) {
- // Value was already destroyed between get() and retain()
}
}
@@ -138,25 +265,65 @@ public class RangeCache<Key extends Comparable<Key>,
Value extends ReferenceCoun
* @return an pair of ints, containing the number of removed entries and
the total size
*/
public Pair<Integer, Long> removeRange(Key first, Key last, boolean
lastInclusive) {
- Map<Key, Value> subMap = entries.subMap(first, true, last,
lastInclusive);
+ RemovalCounters counters = RemovalCounters.create();
+ Map<Key, IdentityWrapper<Key, Value>> subMap = entries.subMap(first,
true, last, lastInclusive);
+ for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry :
subMap.entrySet()) {
+ removeEntry(entry, counters);
+ }
+ return handleRemovalResult(counters);
+ }
- int removedEntries = 0;
- long removedSize = 0;
+ enum RemoveEntryResult {
+ ENTRY_REMOVED,
+ CONTINUE_LOOP,
+ BREAK_LOOP;
+ }
- for (Key key : subMap.keySet()) {
- Value value = entries.remove(key);
- if (value == null) {
- continue;
- }
+ private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key,
Value>> entry, RemovalCounters counters) {
+ return removeEntry(entry, counters, (x) -> true);
+ }
- removedSize += weighter.getSize(value);
+ private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key,
Value>> entry, RemovalCounters counters,
+ Predicate<Value> removeCondition) {
+ Key key = entry.getKey();
+ IdentityWrapper<Key, Value> identityWrapper = entry.getValue();
+ if (identityWrapper.getKey() != key) {
+ // the wrapper has been recycled and contains another key
+ return RemoveEntryResult.CONTINUE_LOOP;
+ }
+ Value value = identityWrapper.getValue();
+ try {
+ // add extra retain to avoid value being released while we are
removing it
+ value.retain();
+ } catch (IllegalReferenceCountException e) {
+ // Value was already released
+ return RemoveEntryResult.CONTINUE_LOOP;
+ }
+ try {
+ if (!removeCondition.test(value)) {
+ return RemoveEntryResult.BREAK_LOOP;
+ }
+ // check that the value hasn't been recycled in between
+ // there should be at least 2 references since this method adds
one and the cache should have one
+ // it is valid that the value contains references even after the
key has been removed from the cache
+ if (value.refCnt() > 1 && value.matchesKey(key) &&
entries.remove(key, identityWrapper)) {
+ identityWrapper.recycle();
+ counters.entryRemoved(weighter.getSize(value));
+ // remove the cache reference
+ value.release();
+ }
+ } finally {
+ // remove the extra retain
value.release();
- ++removedEntries;
}
+ return RemoveEntryResult.ENTRY_REMOVED;
+ }
- size.addAndGet(-removedSize);
-
- return Pair.of(removedEntries, removedSize);
+ private Pair<Integer, Long> handleRemovalResult(RemovalCounters counters) {
+ size.addAndGet(-counters.removedSize);
+ Pair<Integer, Long> result = Pair.of(counters.removedEntries,
counters.removedSize);
+ counters.recycle();
+ return result;
}
/**
@@ -166,24 +333,15 @@ public class RangeCache<Key extends Comparable<Key>,
Value extends ReferenceCoun
*/
public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
checkArgument(minSize > 0);
-
- long removedSize = 0;
- int removedEntries = 0;
-
- while (removedSize < minSize) {
- Map.Entry<Key, Value> entry = entries.pollFirstEntry();
+ RemovalCounters counters = RemovalCounters.create();
+ while (counters.removedSize < minSize) {
+ Map.Entry<Key, IdentityWrapper<Key, Value>> entry =
entries.firstEntry();
if (entry == null) {
break;
}
-
- Value value = entry.getValue();
- ++removedEntries;
- removedSize += weighter.getSize(value);
- value.release();
+ removeEntry(entry, counters);
}
-
- size.addAndGet(-removedSize);
- return Pair.of(removedEntries, removedSize);
+ return handleRemovalResult(counters);
}
/**
@@ -192,27 +350,18 @@ public class RangeCache<Key extends Comparable<Key>,
Value extends ReferenceCoun
* @return the tota
*/
public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
- long removedSize = 0;
- int removedCount = 0;
-
+ RemovalCounters counters = RemovalCounters.create();
while (true) {
- Map.Entry<Key, Value> entry = entries.firstEntry();
- if (entry == null ||
timestampExtractor.getTimestamp(entry.getValue()) > maxTimestamp) {
+ Map.Entry<Key, IdentityWrapper<Key, Value>> entry =
entries.firstEntry();
+ if (entry == null) {
break;
}
- Value value = entry.getValue();
- boolean removeHits = entries.remove(entry.getKey(), value);
- if (!removeHits) {
+ if (removeEntry(entry, counters, value ->
timestampExtractor.getTimestamp(value) <= maxTimestamp)
+ == RemoveEntryResult.BREAK_LOOP) {
break;
}
-
- removedSize += weighter.getSize(value);
- removedCount++;
- value.release();
}
-
- size.addAndGet(-removedSize);
- return Pair.of(removedCount, removedSize);
+ return handleRemovalResult(counters);
}
/**
@@ -231,23 +380,16 @@ public class RangeCache<Key extends Comparable<Key>,
Value extends ReferenceCoun
*
* @return size of removed entries
*/
- public synchronized Pair<Integer, Long> clear() {
- long removedSize = 0;
- int removedCount = 0;
-
+ public Pair<Integer, Long> clear() {
+ RemovalCounters counters = RemovalCounters.create();
while (true) {
- Map.Entry<Key, Value> entry = entries.pollFirstEntry();
+ Map.Entry<Key, IdentityWrapper<Key, Value>> entry =
entries.firstEntry();
if (entry == null) {
break;
}
- Value value = entry.getValue();
- removedSize += weighter.getSize(value);
- removedCount++;
- value.release();
+ removeEntry(entry, counters);
}
-
- size.getAndAdd(-removedSize);
- return Pair.of(removedCount, removedSize);
+ return handleRemovalResult(counters);
}
/**
diff --git
a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
index 8ce0db4ac4c..01b3c67bf11 100644
---
a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
+++
b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
@@ -23,25 +23,30 @@ import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
-
import com.google.common.collect.Lists;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted;
-import org.apache.commons.lang3.tuple.Pair;
-import org.testng.annotations.Test;
-import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
+import lombok.Cleanup;
+import org.apache.commons.lang3.tuple.Pair;
+import org.testng.annotations.Test;
public class RangeCacheTest {
- class RefString extends AbstractReferenceCounted implements
ReferenceCounted {
+ class RefString extends AbstractReferenceCounted implements
RangeCache.ValueWithKeyValidation<Integer> {
String s;
+ Integer matchingKey;
RefString(String s) {
+ this(s, null);
+ }
+
+ RefString(String s, Integer matchingKey) {
super();
this.s = s;
+ this.matchingKey = matchingKey != null ? matchingKey :
Integer.parseInt(s);
setRefCnt(1);
}
@@ -65,6 +70,11 @@ public class RangeCacheTest {
return false;
}
+
+ @Override
+ public boolean matchesKey(Integer key) {
+ return matchingKey.equals(key);
+ }
}
@Test
@@ -119,8 +129,8 @@ public class RangeCacheTest {
public void customWeighter() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value ->
value.s.length(), x -> 0);
- cache.put(0, new RefString("zero"));
- cache.put(1, new RefString("one"));
+ cache.put(0, new RefString("zero", 0));
+ cache.put(1, new RefString("one", 1));
assertEquals(cache.getSize(), 7);
assertEquals(cache.getNumberOfEntries(), 2);
@@ -132,9 +142,9 @@ public class RangeCacheTest {
RangeCache<Integer, RefString> cache = new RangeCache<>(value ->
value.s.length(), x -> x.s.length());
cache.put(1, new RefString("1"));
- cache.put(2, new RefString("22"));
- cache.put(3, new RefString("333"));
- cache.put(4, new RefString("4444"));
+ cache.put(22, new RefString("22"));
+ cache.put(333, new RefString("333"));
+ cache.put(4444, new RefString("4444"));
assertEquals(cache.getSize(), 10);
assertEquals(cache.getNumberOfEntries(), 4);
@@ -151,12 +161,12 @@ public class RangeCacheTest {
public void doubleInsert() {
RangeCache<Integer, RefString> cache = new RangeCache<>();
- RefString s0 = new RefString("zero");
+ RefString s0 = new RefString("zero", 0);
assertEquals(s0.refCnt(), 1);
assertTrue(cache.put(0, s0));
assertEquals(s0.refCnt(), 1);
- cache.put(1, new RefString("one"));
+ cache.put(1, new RefString("one", 1));
assertEquals(cache.getSize(), 2);
assertEquals(cache.getNumberOfEntries(), 2);
@@ -164,7 +174,7 @@ public class RangeCacheTest {
assertEquals(s.s, "one");
assertEquals(s.refCnt(), 2);
- RefString s1 = new RefString("uno");
+ RefString s1 = new RefString("uno", 1);
assertEquals(s1.refCnt(), 1);
assertFalse(cache.put(1, s1));
assertEquals(s1.refCnt(), 1);
@@ -201,10 +211,10 @@ public class RangeCacheTest {
public void eviction() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value ->
value.s.length(), x -> 0);
- cache.put(0, new RefString("zero"));
- cache.put(1, new RefString("one"));
- cache.put(2, new RefString("two"));
- cache.put(3, new RefString("three"));
+ cache.put(0, new RefString("zero", 0));
+ cache.put(1, new RefString("one", 1));
+ cache.put(2, new RefString("two", 2));
+ cache.put(3, new RefString("three", 3));
// This should remove the LRU entries: 0, 1 whose combined size is 7
assertEquals(cache.evictLeastAccessedEntries(5), Pair.of(2, (long) 7));
@@ -276,20 +286,23 @@ public class RangeCacheTest {
}
@Test
- public void testInParallel() {
- RangeCache<String, RefString> cache = new RangeCache<>(value ->
value.s.length(), x -> 0);
- ScheduledExecutorService executor =
Executors.newSingleThreadScheduledExecutor();
- executor.scheduleWithFixedDelay(cache::clear, 10, 10,
TimeUnit.MILLISECONDS);
- for (int i = 0; i < 1000; i++) {
- cache.put(UUID.randomUUID().toString(), new RefString("zero"));
+ public void testPutWhileClearIsCalledConcurrently() {
+ RangeCache<Integer, RefString> cache = new RangeCache<>(value ->
value.s.length(), x -> 0);
+ int numberOfThreads = 4;
+ @Cleanup("shutdownNow")
+ ScheduledExecutorService executor =
Executors.newScheduledThreadPool(numberOfThreads);
+ for (int i = 0; i < numberOfThreads; i++) {
+ executor.scheduleWithFixedDelay(cache::clear, 0, 1,
TimeUnit.MILLISECONDS);
+ }
+ for (int i = 0; i < 100000; i++) {
+ cache.put(i, new RefString(String.valueOf(i)));
}
- executor.shutdown();
}
@Test
public void testPutSameObj() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value ->
value.s.length(), x -> 0);
- RefString s0 = new RefString("zero");
+ RefString s0 = new RefString("zero", 0);
assertEquals(s0.refCnt(), 1);
assertTrue(cache.put(0, s0));
assertFalse(cache.put(0, s0));