This is an automated email from the ASF dual-hosted git repository. zuston pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push: new 917e45e17 [#2369] fix(spark): Potential race condition on reading prefetch (#2475) 917e45e17 is described below commit 917e45e1734dac04d01303895bacee7db6e86a9a Author: Junfan Zhang <zus...@apache.org> AuthorDate: Thu May 8 15:34:21 2025 +0800 [#2369] fix(spark): Potential race condition on reading prefetch (#2475) ### What changes were proposed in this pull request? fix potential race condition on reading prefetch ### Why are the changes needed? fix #2369 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests --- .../uniffle/client/impl/ShuffleReadClientImpl.java | 10 +++-- .../client/impl/ShuffleReadClientImplTest.java | 8 ++-- .../org/apache/uniffle/common/util/RssUtils.java | 45 ++++++++++++++++++++++ .../test/ShuffleServerFaultToleranceTest.java | 4 +- .../test/ShuffleServerWithMemLocalHadoopTest.java | 20 +++++----- .../uniffle/test/ShuffleServerWithMemoryTest.java | 30 ++++++++------- .../uniffle/server/ShuffleFlushManagerTest.java | 3 +- .../uniffle/server/ShuffleTaskManagerTest.java | 3 +- .../storage/factory/ShuffleHandlerFactory.java | 2 +- .../handler/impl/DataSkippableReadHandler.java | 14 ++++--- .../handler/impl/HadoopClientReadHandler.java | 7 ++-- .../handler/impl/HadoopShuffleReadHandler.java | 5 ++- .../handler/impl/LocalFileClientReadHandler.java | 5 ++- .../impl/MultiReplicaClientReadHandler.java | 5 ++- .../request/CreateShuffleReadHandlerRequest.java | 7 ++-- .../handler/impl/HadoopClientReadHandlerTest.java | 3 +- .../storage/handler/impl/HadoopHandlerTest.java | 5 ++- .../handler/impl/HadoopShuffleReadHandlerTest.java | 5 ++- .../impl/LocalFileServerReadHandlerTest.java | 4 +- 19 files changed, 126 insertions(+), 59 deletions(-) diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java index 
8383d5b66..44e8a3f67 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java @@ -20,6 +20,8 @@ package org.apache.uniffle.client.impl; import java.nio.ByteBuffer; import java.util.List; import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import com.google.common.annotations.VisibleForTesting; @@ -62,7 +64,7 @@ public class ShuffleReadClientImpl implements ShuffleReadClient { private Roaring64NavigableMap blockIdBitmap; private Roaring64NavigableMap taskIdBitmap; private Roaring64NavigableMap pendingBlockIds; - private Roaring64NavigableMap processedBlockIds = Roaring64NavigableMap.bitmapOf(); + private Set<Long> processedBlockIds = ConcurrentHashMap.newKeySet(); private Queue<BufferSegment> bufferSegmentQueue = Queues.newLinkedBlockingQueue(); private AtomicLong readDataTime = new AtomicLong(0); private AtomicLong copyTime = new AtomicLong(0); @@ -270,7 +272,7 @@ public class ShuffleReadClientImpl implements ShuffleReadClient { } // mark block as processed - processedBlockIds.addLong(bs.getBlockId()); + processedBlockIds.add(bs.getBlockId()); pendingBlockIds.removeLong(bs.getBlockId()); // only update the statistics of necessary blocks clientReadHandler.updateConsumedBlockInfo(bs, false); @@ -278,7 +280,7 @@ public class ShuffleReadClientImpl implements ShuffleReadClient { } clientReadHandler.updateConsumedBlockInfo(bs, true); // mark block as processed - processedBlockIds.addLong(bs.getBlockId()); + processedBlockIds.add(bs.getBlockId()); pendingBlockIds.removeLong(bs.getBlockId()); } @@ -293,7 +295,7 @@ public class ShuffleReadClientImpl implements ShuffleReadClient { } @VisibleForTesting - protected Roaring64NavigableMap getProcessedBlockIds() { + protected Set<Long> getProcessedBlockIds() { return processedBlockIds; } diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java index 5d0d4b410..396657f7a 100644 --- a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java +++ b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java @@ -512,7 +512,7 @@ public class ShuffleReadClientImplTest extends HadoopTestBase { .taskIdBitmap(taskIdBitmap) .build(); TestUtils.validateResult(readClient, expectedData); - assertEquals(15, readClient.getProcessedBlockIds().getLongCardinality()); + assertEquals(15, readClient.getProcessedBlockIds().size()); readClient.checkProcessedBlockIds(); readClient.close(); } @@ -569,7 +569,7 @@ public class ShuffleReadClientImplTest extends HadoopTestBase { // note that skipped block ids in blockIdBitmap will be removed by `build()` assertEquals(10, blockIdBitmap.getIntCardinality()); TestUtils.validateResult(readClient, expectedData); - assertEquals(20, readClient.getProcessedBlockIds().getLongCardinality()); + assertEquals(20, readClient.getProcessedBlockIds().size()); readClient.checkProcessedBlockIds(); readClient.close(); @@ -612,7 +612,7 @@ public class ShuffleReadClientImplTest extends HadoopTestBase { .taskIdBitmap(taskIdBitmap) .build(); TestUtils.validateResult(readClient, expectedData); - assertEquals(15, readClient.getProcessedBlockIds().getLongCardinality()); + assertEquals(15, readClient.getProcessedBlockIds().size()); readClient.checkProcessedBlockIds(); readClient.close(); } @@ -641,7 +641,7 @@ public class ShuffleReadClientImplTest extends HadoopTestBase { .taskIdBitmap(taskIdBitmap) .build(); TestUtils.validateResult(readClient, expectedData); - assertEquals(25, readClient.getProcessedBlockIds().getLongCardinality()); + assertEquals(25, readClient.getProcessedBlockIds().size()); readClient.checkProcessedBlockIds(); readClient.close(); } diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java index eb0bc4a4c..925f4601c 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java +++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java @@ -36,6 +36,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Enumeration; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -272,6 +273,29 @@ public class RssUtils { return clone; } + public static Roaring64NavigableMap toBitmap(Set<Long> sets) { + Roaring64NavigableMap bitmap = new Roaring64NavigableMap(); + + for (Long value : sets) { + if (value != null) { + bitmap.addLong(value); + } + } + + return bitmap; + } + + public static Set<Long> toSet(Roaring64NavigableMap bitmap) { + Set<Long> result = new HashSet<>(); + Iterator<Long> it = bitmap.iterator(); + + while (it.hasNext()) { + result.add(it.next()); + } + + return result; + } + public static String generateShuffleKey(String appId, int shuffleId) { return String.join(Constants.KEY_SPLIT_CHAR, appId, String.valueOf(shuffleId)); } @@ -379,6 +403,27 @@ public class RssUtils { return serverToPartitions; } + public static void checkProcessedBlockIds( + Roaring64NavigableMap exceptedBlockIds, Set<Long> processedBlockIds) { + Iterator<Long> it = exceptedBlockIds.iterator(); + int expectedCount = 0; + int actualCount = 0; + while (it.hasNext()) { + expectedCount++; + if (processedBlockIds.contains(it.next())) { + actualCount++; + } + } + if (expectedCount != actualCount) { + throw new RssException( + "Blocks read inconsistent: expected " + + expectedCount + + " blocks, actual " + + actualCount + + " blocks"); + } + } + public static void checkProcessedBlockIds( Roaring64NavigableMap blockIdBitmap, Roaring64NavigableMap processedBlockIds) { // processedBlockIds can be a superset of 
blockIdBitmap, diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java index 596dd8dd6..0554e7d51 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java @@ -21,6 +21,8 @@ import java.io.File; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Stream; import com.google.common.collect.Lists; @@ -265,7 +267,7 @@ public class ShuffleServerFaultToleranceTest extends ShuffleReadWriteBase { request.setShuffleServerInfoList(shuffleServerInfoList); request.setHadoopConf(conf); request.setExpectBlockIds(expectBlockIds); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); request.setProcessBlockIds(processBlockIds); request.setDistributionType(ShuffleDataDistributionType.NORMAL); Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java index 065dc0af2..c0fa5a222 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java @@ -20,6 +20,8 @@ package org.apache.uniffle.test; import java.io.File; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Stream; import 
com.google.common.collect.Lists; @@ -169,7 +171,7 @@ public class ShuffleServerWithMemLocalHadoopTest extends ShuffleReadWriteBase { RssSendShuffleDataResponse response = shuffleServerClient.sendShuffleData(rssdr); assertSame(StatusCode.SUCCESS, response.getStatusCode()); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0); // read the 1-th segment from memory MemoryClientReadHandler memoryClientReadHandler = @@ -220,9 +222,9 @@ public class ShuffleServerWithMemLocalHadoopTest extends ShuffleReadWriteBase { expectedData.put(blocks.get(2).getBlockId(), ByteBufUtils.readBytes(blocks.get(2).getData())); ShuffleDataResult sdr = composedClientReadHandler.readShuffleData(); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks.get(0).getBlockId()); - processBlockIds.addLong(blocks.get(1).getBlockId()); - processBlockIds.addLong(blocks.get(2).getBlockId()); + processBlockIds.add(blocks.get(0).getBlockId()); + processBlockIds.add(blocks.get(1).getBlockId()); + processBlockIds.add(blocks.get(2).getBlockId()); sdr.getBufferSegments() .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, checkSkippedMetrics)); @@ -245,8 +247,8 @@ public class ShuffleServerWithMemLocalHadoopTest extends ShuffleReadWriteBase { expectedData.put(blocks2.get(0).getBlockId(), ByteBufUtils.readBytes(blocks2.get(0).getData())); expectedData.put(blocks2.get(1).getBlockId(), ByteBufUtils.readBytes(blocks2.get(1).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks2.get(0).getBlockId()); - processBlockIds.addLong(blocks2.get(1).getBlockId()); + processBlockIds.add(blocks2.get(0).getBlockId()); + processBlockIds.add(blocks2.get(1).getBlockId()); sdr.getBufferSegments() .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, checkSkippedMetrics)); @@ -255,7 +257,7 @@ public class 
ShuffleServerWithMemLocalHadoopTest extends ShuffleReadWriteBase { expectedData.clear(); expectedData.put(blocks2.get(2).getBlockId(), ByteBufUtils.readBytes(blocks2.get(2).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks2.get(2).getBlockId()); + processBlockIds.add(blocks2.get(2).getBlockId()); sdr.getBufferSegments() .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, checkSkippedMetrics)); @@ -277,8 +279,8 @@ public class ShuffleServerWithMemLocalHadoopTest extends ShuffleReadWriteBase { expectedData.put(blocks3.get(0).getBlockId(), ByteBufUtils.readBytes(blocks3.get(0).getData())); expectedData.put(blocks3.get(1).getBlockId(), ByteBufUtils.readBytes(blocks3.get(1).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks3.get(0).getBlockId()); - processBlockIds.addLong(blocks3.get(1).getBlockId()); + processBlockIds.add(blocks3.get(0).getBlockId()); + processBlockIds.add(blocks3.get(1).getBlockId()); sdr.getBufferSegments() .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, checkSkippedMetrics)); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java index e86fe1c10..c31d8ab33 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java @@ -20,6 +20,8 @@ package org.apache.uniffle.test; import java.io.File; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Stream; import com.google.common.collect.Lists; @@ -193,7 +195,7 @@ public class ShuffleServerWithMemoryTest extends ShuffleReadWriteBase { new MemoryClientReadHandler( testAppId, shuffleId, partitionId, 50, shuffleServerClient, 
exceptTaskIds); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); LocalFileClientReadHandler localFileQuorumClientReadHandler = new LocalFileClientReadHandler( testAppId, @@ -226,8 +228,8 @@ public class ShuffleServerWithMemoryTest extends ShuffleReadWriteBase { validateResult(expectedData, sdr); // send data to shuffle server, flush should happen - processBlockIds.addLong(blocks.get(0).getBlockId()); - processBlockIds.addLong(blocks.get(1).getBlockId()); + processBlockIds.add(blocks.get(0).getBlockId()); + processBlockIds.add(blocks.get(1).getBlockId()); List<ShuffleBlockInfo> blocks2 = createShuffleBlockList(shuffleId, partitionId, 0, 3, 50, expectBlockIds, dataMap, mockSSI); @@ -260,20 +262,20 @@ public class ShuffleServerWithMemoryTest extends ShuffleReadWriteBase { expectedData.put(blocks.get(2).getBlockId(), ByteBufUtils.readBytes(blocks.get(2).getData())); expectedData.put(blocks2.get(0).getBlockId(), ByteBufUtils.readBytes(blocks2.get(0).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks.get(2).getBlockId()); - processBlockIds.addLong(blocks2.get(0).getBlockId()); + processBlockIds.add(blocks.get(2).getBlockId()); + processBlockIds.add(blocks2.get(0).getBlockId()); sdr = composedClientReadHandler.readShuffleData(); expectedData.clear(); expectedData.put(blocks2.get(1).getBlockId(), ByteBufUtils.readBytes(blocks2.get(1).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks2.get(1).getBlockId()); + processBlockIds.add(blocks2.get(1).getBlockId()); sdr = composedClientReadHandler.readShuffleData(); expectedData.clear(); expectedData.put(blocks2.get(2).getBlockId(), ByteBufUtils.readBytes(blocks2.get(2).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks2.get(2).getBlockId()); + processBlockIds.add(blocks2.get(2).getBlockId()); sdr = composedClientReadHandler.readShuffleData(); 
assertNull(sdr); @@ -398,7 +400,7 @@ public class ShuffleServerWithMemoryTest extends ShuffleReadWriteBase { assertSame(StatusCode.SUCCESS, response.getStatusCode()); // read the 1-th segment from memory - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0); MemoryClientReadHandler memoryClientReadHandler = new MemoryClientReadHandler( @@ -434,9 +436,9 @@ public class ShuffleServerWithMemoryTest extends ShuffleReadWriteBase { expectedData.put(blocks.get(2).getBlockId(), ByteBufUtils.readBytes(blocks.get(2).getData())); ShuffleDataResult sdr = composedClientReadHandler.readShuffleData(); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks.get(0).getBlockId()); - processBlockIds.addLong(blocks.get(1).getBlockId()); - processBlockIds.addLong(blocks.get(2).getBlockId()); + processBlockIds.add(blocks.get(0).getBlockId()); + processBlockIds.add(blocks.get(1).getBlockId()); + processBlockIds.add(blocks.get(2).getBlockId()); // send data to shuffle server, and wait until flush finish List<ShuffleBlockInfo> blocks2 = @@ -470,15 +472,15 @@ public class ShuffleServerWithMemoryTest extends ShuffleReadWriteBase { expectedData.put(blocks2.get(0).getBlockId(), ByteBufUtils.readBytes(blocks2.get(0).getData())); expectedData.put(blocks2.get(1).getBlockId(), ByteBufUtils.readBytes(blocks2.get(1).getData())); validateResult(expectedData, sdr); - processBlockIds.addLong(blocks2.get(0).getBlockId()); - processBlockIds.addLong(blocks2.get(1).getBlockId()); + processBlockIds.add(blocks2.get(0).getBlockId()); + processBlockIds.add(blocks2.get(1).getBlockId()); // read the 3-th segment from localFile sdr = composedClientReadHandler.readShuffleData(); expectedData.clear(); expectedData.put(blocks2.get(2).getBlockId(), ByteBufUtils.readBytes(blocks2.get(2).getData())); validateResult(expectedData, sdr); - 
processBlockIds.addLong(blocks2.get(2).getBlockId()); + processBlockIds.add(blocks2.get(2).getBlockId()); // all segments are processed sdr = composedClientReadHandler.readShuffleData(); diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java index ad3faf210..5d4bffb2a 100644 --- a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java @@ -30,6 +30,7 @@ import java.util.EnumSet; import java.util.List; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.ReentrantReadWriteLock; @@ -740,7 +741,7 @@ public class ShuffleFlushManagerTest extends HadoopTestBase { int partitionNumPerRange, String basePath) { Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf(); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); Set<Long> remainIds = Sets.newHashSet(); for (ShufflePartitionedBlock spb : blocks) { expectBlockIds.addLong(spb.getBlockId()); diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java index c6484b175..b9cdf1e24 100644 --- a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -1134,7 
+1135,7 @@ public class ShuffleTaskManagerTest extends HadoopTestBase { List<ShufflePartitionedBlock> blocks, String basePath) { Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf(); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); Set<Long> remainIds = Sets.newHashSet(); for (ShufflePartitionedBlock spb : blocks) { expectBlockIds.addLong(spb.getBlockId()); diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java index 84550ab1a..356cfc1dd 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java +++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java @@ -126,7 +126,7 @@ public class ShuffleHandlerFactory { Roaring64NavigableMap expectTaskIds = null; if (request.isExpectedTaskIdsBitmapFilterEnable()) { Roaring64NavigableMap realExceptBlockIds = RssUtils.cloneBitMap(request.getExpectBlockIds()); - realExceptBlockIds.xor(request.getProcessBlockIds()); + realExceptBlockIds.xor(RssUtils.toBitmap(request.getProcessBlockIds())); expectTaskIds = RssUtils.generateTaskIdBitMap(realExceptBlockIds, request.getIdHelper()); } ClientReadHandler memoryClientReadHandler = diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java index 700e99f93..58288d53c 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java @@ -17,8 +17,10 @@ package org.apache.uniffle.storage.handler.impl; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import 
com.google.common.collect.Lists; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -38,7 +40,7 @@ public abstract class DataSkippableReadHandler extends PrefetchableClientReadHan protected int segmentIndex = 0; protected Roaring64NavigableMap expectBlockIds; - protected Roaring64NavigableMap processBlockIds; + protected Set<Long> processBlockIds; protected ShuffleDataDistributionType distributionType; protected Roaring64NavigableMap expectTaskIds; @@ -49,7 +51,7 @@ public abstract class DataSkippableReadHandler extends PrefetchableClientReadHan int partitionId, int readBufferSize, Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, ShuffleDataDistributionType distributionType, Roaring64NavigableMap expectTaskIds, Optional<PrefetchOption> prefetchOption) { @@ -90,13 +92,13 @@ public abstract class DataSkippableReadHandler extends PrefetchableClientReadHan ShuffleDataResult result = null; while (segmentIndex < shuffleDataSegments.size()) { ShuffleDataSegment segment = shuffleDataSegments.get(segmentIndex); - Roaring64NavigableMap blocksOfSegment = Roaring64NavigableMap.bitmapOf(); - segment.getBufferSegments().forEach(block -> blocksOfSegment.addLong(block.getBlockId())); + Set<Long> blocksOfSegment = new HashSet<>(); + segment.getBufferSegments().forEach(block -> blocksOfSegment.add(block.getBlockId())); // skip unexpected blockIds - blocksOfSegment.and(expectBlockIds); + blocksOfSegment.removeIf(blockId -> !expectBlockIds.contains(blockId)); if (!blocksOfSegment.isEmpty()) { // skip processed blockIds - blocksOfSegment.andNot(processBlockIds); + blocksOfSegment.removeAll(processBlockIds); if (!blocksOfSegment.isEmpty()) { result = readShuffleData(segment); segmentIndex++; diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java index a316a3028..161740786 100644 
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java @@ -21,6 +21,7 @@ import java.io.FileNotFoundException; import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import com.google.common.collect.Lists; @@ -49,7 +50,7 @@ public class HadoopClientReadHandler extends AbstractClientReadHandler { protected final int readBufferSize; private final String shuffleServerId; protected Roaring64NavigableMap expectBlockIds; - protected Roaring64NavigableMap processBlockIds; + protected Set<Long> processBlockIds; protected final String storageBasePath; protected final Configuration hadoopConf; protected final List<HadoopShuffleReadHandler> readHandlers = Lists.newArrayList(); @@ -69,7 +70,7 @@ public class HadoopClientReadHandler extends AbstractClientReadHandler { int partitionNum, int readBufferSize, Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, String storageBasePath, Configuration hadoopConf, ShuffleDataDistributionType distributionType, @@ -107,7 +108,7 @@ public class HadoopClientReadHandler extends AbstractClientReadHandler { int partitionNum, int readBufferSize, Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, String storageBasePath, Configuration hadoopConf) { this( diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java index f3ecc16cf..f2153d2be 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java @@ -21,6 +21,7 @@ import 
java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; +import java.util.Set; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -54,7 +55,7 @@ public class HadoopShuffleReadHandler extends DataSkippableReadHandler { String filePrefix, int readBufferSize, Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, Configuration conf, ShuffleDataDistributionType distributionType, Roaring64NavigableMap expectTaskIds, @@ -87,7 +88,7 @@ public class HadoopShuffleReadHandler extends DataSkippableReadHandler { String filePrefix, int readBufferSize, Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, Configuration conf) throws Exception { this( diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java index d06675808..cfdf5fdb7 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java @@ -18,6 +18,7 @@ package org.apache.uniffle.storage.handler.impl; import java.util.Optional; +import java.util.Set; import com.google.common.annotations.VisibleForTesting; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -55,7 +56,7 @@ public class LocalFileClientReadHandler extends DataSkippableReadHandler { int partitionNum, int readBufferSize, Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, ShuffleServerClient shuffleServerClient, ShuffleDataDistributionType distributionType, Roaring64NavigableMap expectTaskIds, @@ -91,7 +92,7 @@ public class LocalFileClientReadHandler extends DataSkippableReadHandler { int partitionNum, int readBufferSize, 
Roaring64NavigableMap expectBlockIds, - Roaring64NavigableMap processBlockIds, + Set<Long> processBlockIds, ShuffleServerClient shuffleServerClient) { this( appId, diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java index 262052927..025ea8c82 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java @@ -18,6 +18,7 @@ package org.apache.uniffle.storage.handler.impl; import java.util.List; +import java.util.Set; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; @@ -37,7 +38,7 @@ public class MultiReplicaClientReadHandler extends AbstractClientReadHandler { private final List<ClientReadHandler> handlers; private final List<ShuffleServerInfo> shuffleServerInfos; private final Roaring64NavigableMap blockIdBitmap; - private final Roaring64NavigableMap processedBlockIds; + private final Set<Long> processedBlockIds; private int readHandlerIndex; @@ -45,7 +46,7 @@ public class MultiReplicaClientReadHandler extends AbstractClientReadHandler { List<ClientReadHandler> handlers, List<ShuffleServerInfo> shuffleServerInfos, Roaring64NavigableMap blockIdBitmap, - Roaring64NavigableMap processedBlockIds) { + Set<Long> processedBlockIds) { this.handlers = handlers; this.blockIdBitmap = blockIdBitmap; this.processedBlockIds = processedBlockIds; diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java index 145f93cab..a0a6b6c6e 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java +++ 
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java @@ -19,6 +19,7 @@ package org.apache.uniffle.storage.request; import java.util.List; import java.util.Optional; +import java.util.Set; import org.apache.hadoop.conf.Configuration; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -50,7 +51,7 @@ public class CreateShuffleReadHandlerRequest { private Configuration hadoopConf; private List<ShuffleServerInfo> shuffleServerInfoList; private Roaring64NavigableMap expectBlockIds; - private Roaring64NavigableMap processBlockIds; + private Set<Long> processBlockIds; private ShuffleDataDistributionType distributionType; private Roaring64NavigableMap expectTaskIds; private boolean expectedTaskIdsBitmapFilterEnable; @@ -184,11 +185,11 @@ public class CreateShuffleReadHandlerRequest { return expectBlockIds; } - public void setProcessBlockIds(Roaring64NavigableMap processBlockIds) { + public void setProcessBlockIds(Set<Long> processBlockIds) { this.processBlockIds = processBlockIds; } - public Roaring64NavigableMap getProcessBlockIds() { + public Set<Long> getProcessBlockIds() { return processBlockIds; } diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java index fa684b840..e4e43f186 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -75,7 +76,7 @@ public class HadoopClientReadHandlerTest extends HadoopTestBase { indexWriter.writeData(ByteBuffer.allocate(4).putInt(999).array()); 
indexWriter.close(); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); HadoopShuffleReadHandler indexReader = new HadoopShuffleReadHandler( diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java index 94c5f3c33..82f80bbdc 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java @@ -22,6 +22,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -98,9 +99,9 @@ public class HadoopHandlerTest extends HadoopTestBase { List<Long> expectedBlockId) throws IllegalStateException { Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf(); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); for (long blockId : expectedBlockId) { - expectBlockIds.addLong(blockId); + expectBlockIds.add(blockId); } // read directly and compare HadoopClientReadHandler readHandler = diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java index 1b2b7e93b..0d3b5b80e 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import 
java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -66,7 +67,7 @@ public class HadoopShuffleReadHandlerTest extends HadoopTestBase { HadoopShuffleHandlerTestBase.calcExpectedSegmentNum( expectTotalBlockNum, blockSize, readBufferSize); Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf(); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); expectedData.forEach((id, block) -> expectBlockIds.addLong(id)); String fileNamePrefix = ShuffleStorageUtils.getFullShuffleDataFolder( @@ -130,7 +131,7 @@ public class HadoopShuffleReadHandlerTest extends HadoopTestBase { HadoopShuffleHandlerTestBase.calcExpectedSegmentNum( expectTotalBlockNum, blockSize, readBufferSize); Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf(); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); expectedData.forEach((id, block) -> expectBlockIds.addLong(id)); String fileNamePrefix = ShuffleStorageUtils.getFullShuffleDataFolder( diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java index 5b2db2daf..854c4757d 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java @@ -21,6 +21,8 @@ import java.nio.ByteBuffer; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import com.google.common.collect.Maps; @@ -122,7 +124,7 @@ public class LocalFileServerReadHandlerTest { .when(mockShuffleServerClient) 
.getShuffleData(Mockito.argThat(segment2Match)); - Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf(); + Set<Long> processBlockIds = ConcurrentHashMap.newKeySet(); LocalFileClientReadHandler handler = new LocalFileClientReadHandler( appId,