This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 917e45e17 [#2369] fix(spark): Potential race condition on reading prefetch (#2475)
917e45e17 is described below

commit 917e45e1734dac04d01303895bacee7db6e86a9a
Author: Junfan Zhang <zus...@apache.org>
AuthorDate: Thu May 8 15:34:21 2025 +0800

    [#2369] fix(spark): Potential race condition on reading prefetch (#2475)
    
    ### What changes were proposed in this pull request?
    
    Fix a potential race condition when reading prefetched shuffle data.
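    
    In brief, as the diff suggests: the processed-block tracking previously used a
    `Roaring64NavigableMap`, which is not thread-safe, so the prefetch thread marking
    blocks as processed while the reader thread queried the same structure could
    observe an inconsistent state. The patch switches it to `ConcurrentHashMap.newKeySet()`.
    A minimal sketch of the two shapes (this class and its method names are
    illustrative only, not part of the Uniffle codebase):
    
    ```java
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;
    
    import org.roaringbitmap.longlong.Roaring64NavigableMap;
    
    class ProcessedBlocksSketch {
      // Before: a plain bitmap; mutating it from the prefetch thread while the
      // reader thread queries it is not safe.
      private final Roaring64NavigableMap bitmap = Roaring64NavigableMap.bitmapOf();
    
      // After: a concurrent set, so add()/contains() are safe across threads.
      private final Set<Long> processed = ConcurrentHashMap.newKeySet();
    
      void markProcessed(long blockId) {
        processed.add(blockId); // called from the prefetch thread
      }
    
      boolean isProcessed(long blockId) {
        return processed.contains(blockId); // called from the reader thread
      }
    }
    ```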
    
    ### Why are the changes needed?
    
    fix #2369
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests
---
 .../uniffle/client/impl/ShuffleReadClientImpl.java | 10 +++--
 .../client/impl/ShuffleReadClientImplTest.java     |  8 ++--
 .../org/apache/uniffle/common/util/RssUtils.java   | 45 ++++++++++++++++++++++
 .../test/ShuffleServerFaultToleranceTest.java      |  4 +-
 .../test/ShuffleServerWithMemLocalHadoopTest.java  | 20 +++++-----
 .../uniffle/test/ShuffleServerWithMemoryTest.java  | 30 ++++++++-------
 .../uniffle/server/ShuffleFlushManagerTest.java    |  3 +-
 .../uniffle/server/ShuffleTaskManagerTest.java     |  3 +-
 .../storage/factory/ShuffleHandlerFactory.java     |  2 +-
 .../handler/impl/DataSkippableReadHandler.java     | 14 ++++---
 .../handler/impl/HadoopClientReadHandler.java      |  7 ++--
 .../handler/impl/HadoopShuffleReadHandler.java     |  5 ++-
 .../handler/impl/LocalFileClientReadHandler.java   |  5 ++-
 .../impl/MultiReplicaClientReadHandler.java        |  5 ++-
 .../request/CreateShuffleReadHandlerRequest.java   |  7 ++--
 .../handler/impl/HadoopClientReadHandlerTest.java  |  3 +-
 .../storage/handler/impl/HadoopHandlerTest.java    |  5 ++-
 .../handler/impl/HadoopShuffleReadHandlerTest.java |  5 ++-
 .../impl/LocalFileServerReadHandlerTest.java       |  4 +-
 19 files changed, 126 insertions(+), 59 deletions(-)

diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 8383d5b66..44e8a3f67 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.client.impl;
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -62,7 +64,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
   private Roaring64NavigableMap blockIdBitmap;
   private Roaring64NavigableMap taskIdBitmap;
   private Roaring64NavigableMap pendingBlockIds;
-  private Roaring64NavigableMap processedBlockIds = 
Roaring64NavigableMap.bitmapOf();
+  private Set<Long> processedBlockIds = ConcurrentHashMap.newKeySet();
   private Queue<BufferSegment> bufferSegmentQueue = 
Queues.newLinkedBlockingQueue();
   private AtomicLong readDataTime = new AtomicLong(0);
   private AtomicLong copyTime = new AtomicLong(0);
@@ -270,7 +272,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
           }
 
           // mark block as processed
-          processedBlockIds.addLong(bs.getBlockId());
+          processedBlockIds.add(bs.getBlockId());
           pendingBlockIds.removeLong(bs.getBlockId());
           // only update the statistics of necessary blocks
           clientReadHandler.updateConsumedBlockInfo(bs, false);
@@ -278,7 +280,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
         }
         clientReadHandler.updateConsumedBlockInfo(bs, true);
         // mark block as processed
-        processedBlockIds.addLong(bs.getBlockId());
+        processedBlockIds.add(bs.getBlockId());
         pendingBlockIds.removeLong(bs.getBlockId());
       }
 
@@ -293,7 +295,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
   }
 
   @VisibleForTesting
-  protected Roaring64NavigableMap getProcessedBlockIds() {
+  protected Set<Long> getProcessedBlockIds() {
     return processedBlockIds;
   }
 
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java
index 5d0d4b410..396657f7a 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java
@@ -512,7 +512,7 @@ public class ShuffleReadClientImplTest extends 
HadoopTestBase {
             .taskIdBitmap(taskIdBitmap)
             .build();
     TestUtils.validateResult(readClient, expectedData);
-    assertEquals(15, readClient.getProcessedBlockIds().getLongCardinality());
+    assertEquals(15, readClient.getProcessedBlockIds().size());
     readClient.checkProcessedBlockIds();
     readClient.close();
   }
@@ -569,7 +569,7 @@ public class ShuffleReadClientImplTest extends 
HadoopTestBase {
     // note that skipped block ids in blockIdBitmap will be removed by 
`build()`
     assertEquals(10, blockIdBitmap.getIntCardinality());
     TestUtils.validateResult(readClient, expectedData);
-    assertEquals(20, readClient.getProcessedBlockIds().getLongCardinality());
+    assertEquals(20, readClient.getProcessedBlockIds().size());
     readClient.checkProcessedBlockIds();
     readClient.close();
 
@@ -612,7 +612,7 @@ public class ShuffleReadClientImplTest extends 
HadoopTestBase {
             .taskIdBitmap(taskIdBitmap)
             .build();
     TestUtils.validateResult(readClient, expectedData);
-    assertEquals(15, readClient.getProcessedBlockIds().getLongCardinality());
+    assertEquals(15, readClient.getProcessedBlockIds().size());
     readClient.checkProcessedBlockIds();
     readClient.close();
   }
@@ -641,7 +641,7 @@ public class ShuffleReadClientImplTest extends 
HadoopTestBase {
             .taskIdBitmap(taskIdBitmap)
             .build();
     TestUtils.validateResult(readClient, expectedData);
-    assertEquals(25, readClient.getProcessedBlockIds().getLongCardinality());
+    assertEquals(25, readClient.getProcessedBlockIds().size());
     readClient.checkProcessedBlockIds();
     readClient.close();
   }
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
index eb0bc4a4c..925f4601c 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
@@ -36,6 +36,7 @@ import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Enumeration;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -272,6 +273,29 @@ public class RssUtils {
     return clone;
   }
 
+  public static Roaring64NavigableMap toBitmap(Set<Long> sets) {
+    Roaring64NavigableMap bitmap = new Roaring64NavigableMap();
+
+    for (Long value : sets) {
+      if (value != null) {
+        bitmap.addLong(value);
+      }
+    }
+
+    return bitmap;
+  }
+
+  public static Set<Long> toSet(Roaring64NavigableMap bitmap) {
+    Set<Long> result = new HashSet<>();
+    Iterator<Long> it = bitmap.iterator();
+
+    while (it.hasNext()) {
+      result.add(it.next());
+    }
+
+    return result;
+  }
+
   public static String generateShuffleKey(String appId, int shuffleId) {
     return String.join(Constants.KEY_SPLIT_CHAR, appId, 
String.valueOf(shuffleId));
   }
@@ -379,6 +403,27 @@ public class RssUtils {
     return serverToPartitions;
   }
 
+  public static void checkProcessedBlockIds(
+      Roaring64NavigableMap exceptedBlockIds, Set<Long> processedBlockIds) {
+    Iterator<Long> it = exceptedBlockIds.iterator();
+    int expectedCount = 0;
+    int actualCount = 0;
+    while (it.hasNext()) {
+      expectedCount++;
+      if (processedBlockIds.contains(it.next())) {
+        actualCount++;
+      }
+    }
+    if (expectedCount != actualCount) {
+      throw new RssException(
+          "Blocks read inconsistent: expected "
+              + expectedCount
+              + " blocks, actual "
+              + actualCount
+              + " blocks");
+    }
+  }
+
   public static void checkProcessedBlockIds(
       Roaring64NavigableMap blockIdBitmap, Roaring64NavigableMap 
processedBlockIds) {
     // processedBlockIds can be a superset of blockIdBitmap,
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
index 596dd8dd6..0554e7d51 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
@@ -21,6 +21,8 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Stream;
 
 import com.google.common.collect.Lists;
@@ -265,7 +267,7 @@ public class ShuffleServerFaultToleranceTest extends 
ShuffleReadWriteBase {
     request.setShuffleServerInfoList(shuffleServerInfoList);
     request.setHadoopConf(conf);
     request.setExpectBlockIds(expectBlockIds);
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     request.setProcessBlockIds(processBlockIds);
     request.setDistributionType(ShuffleDataDistributionType.NORMAL);
     Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java
index 065dc0af2..c0fa5a222 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHadoopTest.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.test;
 import java.io.File;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Stream;
 
 import com.google.common.collect.Lists;
@@ -169,7 +171,7 @@ public class ShuffleServerWithMemLocalHadoopTest extends 
ShuffleReadWriteBase {
     RssSendShuffleDataResponse response = 
shuffleServerClient.sendShuffleData(rssdr);
     assertSame(StatusCode.SUCCESS, response.getStatusCode());
 
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0);
     // read the 1-th segment from memory
     MemoryClientReadHandler memoryClientReadHandler =
@@ -220,9 +222,9 @@ public class ShuffleServerWithMemLocalHadoopTest extends 
ShuffleReadWriteBase {
     expectedData.put(blocks.get(2).getBlockId(), 
ByteBufUtils.readBytes(blocks.get(2).getData()));
     ShuffleDataResult sdr = composedClientReadHandler.readShuffleData();
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks.get(0).getBlockId());
-    processBlockIds.addLong(blocks.get(1).getBlockId());
-    processBlockIds.addLong(blocks.get(2).getBlockId());
+    processBlockIds.add(blocks.get(0).getBlockId());
+    processBlockIds.add(blocks.get(1).getBlockId());
+    processBlockIds.add(blocks.get(2).getBlockId());
     sdr.getBufferSegments()
         .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, 
checkSkippedMetrics));
 
@@ -245,8 +247,8 @@ public class ShuffleServerWithMemLocalHadoopTest extends 
ShuffleReadWriteBase {
     expectedData.put(blocks2.get(0).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(0).getData()));
     expectedData.put(blocks2.get(1).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(1).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks2.get(0).getBlockId());
-    processBlockIds.addLong(blocks2.get(1).getBlockId());
+    processBlockIds.add(blocks2.get(0).getBlockId());
+    processBlockIds.add(blocks2.get(1).getBlockId());
     sdr.getBufferSegments()
         .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, 
checkSkippedMetrics));
 
@@ -255,7 +257,7 @@ public class ShuffleServerWithMemLocalHadoopTest extends 
ShuffleReadWriteBase {
     expectedData.clear();
     expectedData.put(blocks2.get(2).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(2).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks2.get(2).getBlockId());
+    processBlockIds.add(blocks2.get(2).getBlockId());
     sdr.getBufferSegments()
         .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, 
checkSkippedMetrics));
 
@@ -277,8 +279,8 @@ public class ShuffleServerWithMemLocalHadoopTest extends 
ShuffleReadWriteBase {
     expectedData.put(blocks3.get(0).getBlockId(), 
ByteBufUtils.readBytes(blocks3.get(0).getData()));
     expectedData.put(blocks3.get(1).getBlockId(), 
ByteBufUtils.readBytes(blocks3.get(1).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks3.get(0).getBlockId());
-    processBlockIds.addLong(blocks3.get(1).getBlockId());
+    processBlockIds.add(blocks3.get(0).getBlockId());
+    processBlockIds.add(blocks3.get(1).getBlockId());
     sdr.getBufferSegments()
         .forEach(bs -> composedClientReadHandler.updateConsumedBlockInfo(bs, 
checkSkippedMetrics));
 
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
index e86fe1c10..c31d8ab33 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.test;
 import java.io.File;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Stream;
 
 import com.google.common.collect.Lists;
@@ -193,7 +195,7 @@ public class ShuffleServerWithMemoryTest extends 
ShuffleReadWriteBase {
         new MemoryClientReadHandler(
             testAppId, shuffleId, partitionId, 50, shuffleServerClient, 
exceptTaskIds);
 
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     LocalFileClientReadHandler localFileQuorumClientReadHandler =
         new LocalFileClientReadHandler(
             testAppId,
@@ -226,8 +228,8 @@ public class ShuffleServerWithMemoryTest extends 
ShuffleReadWriteBase {
     validateResult(expectedData, sdr);
 
     // send data to shuffle server, flush should happen
-    processBlockIds.addLong(blocks.get(0).getBlockId());
-    processBlockIds.addLong(blocks.get(1).getBlockId());
+    processBlockIds.add(blocks.get(0).getBlockId());
+    processBlockIds.add(blocks.get(1).getBlockId());
 
     List<ShuffleBlockInfo> blocks2 =
         createShuffleBlockList(shuffleId, partitionId, 0, 3, 50, 
expectBlockIds, dataMap, mockSSI);
@@ -260,20 +262,20 @@ public class ShuffleServerWithMemoryTest extends 
ShuffleReadWriteBase {
     expectedData.put(blocks.get(2).getBlockId(), 
ByteBufUtils.readBytes(blocks.get(2).getData()));
     expectedData.put(blocks2.get(0).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(0).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks.get(2).getBlockId());
-    processBlockIds.addLong(blocks2.get(0).getBlockId());
+    processBlockIds.add(blocks.get(2).getBlockId());
+    processBlockIds.add(blocks2.get(0).getBlockId());
 
     sdr = composedClientReadHandler.readShuffleData();
     expectedData.clear();
     expectedData.put(blocks2.get(1).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(1).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks2.get(1).getBlockId());
+    processBlockIds.add(blocks2.get(1).getBlockId());
 
     sdr = composedClientReadHandler.readShuffleData();
     expectedData.clear();
     expectedData.put(blocks2.get(2).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(2).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks2.get(2).getBlockId());
+    processBlockIds.add(blocks2.get(2).getBlockId());
 
     sdr = composedClientReadHandler.readShuffleData();
     assertNull(sdr);
@@ -398,7 +400,7 @@ public class ShuffleServerWithMemoryTest extends 
ShuffleReadWriteBase {
     assertSame(StatusCode.SUCCESS, response.getStatusCode());
 
     // read the 1-th segment from memory
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     Roaring64NavigableMap exceptTaskIds = Roaring64NavigableMap.bitmapOf(0);
     MemoryClientReadHandler memoryClientReadHandler =
         new MemoryClientReadHandler(
@@ -434,9 +436,9 @@ public class ShuffleServerWithMemoryTest extends 
ShuffleReadWriteBase {
     expectedData.put(blocks.get(2).getBlockId(), 
ByteBufUtils.readBytes(blocks.get(2).getData()));
     ShuffleDataResult sdr = composedClientReadHandler.readShuffleData();
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks.get(0).getBlockId());
-    processBlockIds.addLong(blocks.get(1).getBlockId());
-    processBlockIds.addLong(blocks.get(2).getBlockId());
+    processBlockIds.add(blocks.get(0).getBlockId());
+    processBlockIds.add(blocks.get(1).getBlockId());
+    processBlockIds.add(blocks.get(2).getBlockId());
 
     // send data to shuffle server, and wait until flush finish
     List<ShuffleBlockInfo> blocks2 =
@@ -470,15 +472,15 @@ public class ShuffleServerWithMemoryTest extends 
ShuffleReadWriteBase {
     expectedData.put(blocks2.get(0).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(0).getData()));
     expectedData.put(blocks2.get(1).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(1).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks2.get(0).getBlockId());
-    processBlockIds.addLong(blocks2.get(1).getBlockId());
+    processBlockIds.add(blocks2.get(0).getBlockId());
+    processBlockIds.add(blocks2.get(1).getBlockId());
 
     // read the 3-th segment from localFile
     sdr = composedClientReadHandler.readShuffleData();
     expectedData.clear();
     expectedData.put(blocks2.get(2).getBlockId(), 
ByteBufUtils.readBytes(blocks2.get(2).getData()));
     validateResult(expectedData, sdr);
-    processBlockIds.addLong(blocks2.get(2).getBlockId());
+    processBlockIds.add(blocks2.get(2).getBlockId());
 
     // all segments are processed
     sdr = composedClientReadHandler.readShuffleData();
diff --git 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
index ad3faf210..5d4bffb2a 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
@@ -30,6 +30,7 @@ import java.util.EnumSet;
 import java.util.List;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
@@ -740,7 +741,7 @@ public class ShuffleFlushManagerTest extends HadoopTestBase 
{
       int partitionNumPerRange,
       String basePath) {
     Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf();
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     Set<Long> remainIds = Sets.newHashSet();
     for (ShufflePartitionedBlock spb : blocks) {
       expectBlockIds.addLong(spb.getBlockId());
diff --git 
a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java 
b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java
index c6484b175..b9cdf1e24 100644
--- a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java
@@ -25,6 +25,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -1134,7 +1135,7 @@ public class ShuffleTaskManagerTest extends 
HadoopTestBase {
       List<ShufflePartitionedBlock> blocks,
       String basePath) {
     Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf();
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     Set<Long> remainIds = Sets.newHashSet();
     for (ShufflePartitionedBlock spb : blocks) {
       expectBlockIds.addLong(spb.getBlockId());
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 84550ab1a..356cfc1dd 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -126,7 +126,7 @@ public class ShuffleHandlerFactory {
     Roaring64NavigableMap expectTaskIds = null;
     if (request.isExpectedTaskIdsBitmapFilterEnable()) {
       Roaring64NavigableMap realExceptBlockIds = 
RssUtils.cloneBitMap(request.getExpectBlockIds());
-      realExceptBlockIds.xor(request.getProcessBlockIds());
+      realExceptBlockIds.xor(RssUtils.toBitmap(request.getProcessBlockIds()));
       expectTaskIds = RssUtils.generateTaskIdBitMap(realExceptBlockIds, 
request.getIdHelper());
     }
     ClientReadHandler memoryClientReadHandler =
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index 700e99f93..58288d53c 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -17,8 +17,10 @@
 
 package org.apache.uniffle.storage.handler.impl;
 
+import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 
 import com.google.common.collect.Lists;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -38,7 +40,7 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
   protected int segmentIndex = 0;
 
   protected Roaring64NavigableMap expectBlockIds;
-  protected Roaring64NavigableMap processBlockIds;
+  protected Set<Long> processBlockIds;
 
   protected ShuffleDataDistributionType distributionType;
   protected Roaring64NavigableMap expectTaskIds;
@@ -49,7 +51,7 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
       int partitionId,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
       Optional<PrefetchOption> prefetchOption) {
@@ -90,13 +92,13 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
     ShuffleDataResult result = null;
     while (segmentIndex < shuffleDataSegments.size()) {
       ShuffleDataSegment segment = shuffleDataSegments.get(segmentIndex);
-      Roaring64NavigableMap blocksOfSegment = Roaring64NavigableMap.bitmapOf();
-      segment.getBufferSegments().forEach(block -> 
blocksOfSegment.addLong(block.getBlockId()));
+      Set<Long> blocksOfSegment = new HashSet<>();
+      segment.getBufferSegments().forEach(block -> 
blocksOfSegment.add(block.getBlockId()));
       // skip unexpected blockIds
-      blocksOfSegment.and(expectBlockIds);
+      blocksOfSegment.removeIf(blockId -> !expectBlockIds.contains(blockId));
       if (!blocksOfSegment.isEmpty()) {
         // skip processed blockIds
-        blocksOfSegment.andNot(processBlockIds);
+        blocksOfSegment.removeAll(processBlockIds);
         if (!blocksOfSegment.isEmpty()) {
           result = readShuffleData(segment);
           segmentIndex++;
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
index a316a3028..161740786 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
@@ -21,6 +21,7 @@ import java.io.FileNotFoundException;
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
@@ -49,7 +50,7 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
   protected final int readBufferSize;
   private final String shuffleServerId;
   protected Roaring64NavigableMap expectBlockIds;
-  protected Roaring64NavigableMap processBlockIds;
+  protected Set<Long> processBlockIds;
   protected final String storageBasePath;
   protected final Configuration hadoopConf;
   protected final List<HadoopShuffleReadHandler> readHandlers = 
Lists.newArrayList();
@@ -69,7 +70,7 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
       int partitionNum,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       String storageBasePath,
       Configuration hadoopConf,
       ShuffleDataDistributionType distributionType,
@@ -107,7 +108,7 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
       int partitionNum,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       String storageBasePath,
       Configuration hadoopConf) {
     this(
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
index f3ecc16cf..f2153d2be 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
@@ -54,7 +55,7 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
       String filePrefix,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       Configuration conf,
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
@@ -87,7 +88,7 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
       String filePrefix,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       Configuration conf)
       throws Exception {
     this(
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
index d06675808..cfdf5fdb7 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.storage.handler.impl;
 
 import java.util.Optional;
+import java.util.Set;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -55,7 +56,7 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
       int partitionNum,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       ShuffleServerClient shuffleServerClient,
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
@@ -91,7 +92,7 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
       int partitionNum,
       int readBufferSize,
       Roaring64NavigableMap expectBlockIds,
-      Roaring64NavigableMap processBlockIds,
+      Set<Long> processBlockIds,
       ShuffleServerClient shuffleServerClient) {
     this(
         appId,
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java
index 262052927..025ea8c82 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MultiReplicaClientReadHandler.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.storage.handler.impl;
 
 import java.util.List;
+import java.util.Set;
 
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
@@ -37,7 +38,7 @@ public class MultiReplicaClientReadHandler extends 
AbstractClientReadHandler {
   private final List<ClientReadHandler> handlers;
   private final List<ShuffleServerInfo> shuffleServerInfos;
   private final Roaring64NavigableMap blockIdBitmap;
-  private final Roaring64NavigableMap processedBlockIds;
+  private final Set<Long> processedBlockIds;
 
   private int readHandlerIndex;
 
@@ -45,7 +46,7 @@ public class MultiReplicaClientReadHandler extends 
AbstractClientReadHandler {
       List<ClientReadHandler> handlers,
       List<ShuffleServerInfo> shuffleServerInfos,
       Roaring64NavigableMap blockIdBitmap,
-      Roaring64NavigableMap processedBlockIds) {
+      Set<Long> processedBlockIds) {
     this.handlers = handlers;
     this.blockIdBitmap = blockIdBitmap;
     this.processedBlockIds = processedBlockIds;
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
 
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 145f93cab..a0a6b6c6e 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -19,6 +19,7 @@ package org.apache.uniffle.storage.request;
 
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 
 import org.apache.hadoop.conf.Configuration;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -50,7 +51,7 @@ public class CreateShuffleReadHandlerRequest {
   private Configuration hadoopConf;
   private List<ShuffleServerInfo> shuffleServerInfoList;
   private Roaring64NavigableMap expectBlockIds;
-  private Roaring64NavigableMap processBlockIds;
+  private Set<Long> processBlockIds;
   private ShuffleDataDistributionType distributionType;
   private Roaring64NavigableMap expectTaskIds;
   private boolean expectedTaskIdsBitmapFilterEnable;
@@ -184,11 +185,11 @@ public class CreateShuffleReadHandlerRequest {
     return expectBlockIds;
   }
 
-  public void setProcessBlockIds(Roaring64NavigableMap processBlockIds) {
+  public void setProcessBlockIds(Set<Long> processBlockIds) {
     this.processBlockIds = processBlockIds;
   }
 
-  public Roaring64NavigableMap getProcessBlockIds() {
+  public Set<Long> getProcessBlockIds() {
     return processBlockIds;
   }
 
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java
index fa684b840..e4e43f186 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandlerTest.java
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
@@ -75,7 +76,7 @@ public class HadoopClientReadHandlerTest extends 
HadoopTestBase {
     indexWriter.writeData(ByteBuffer.allocate(4).putInt(999).array());
     indexWriter.close();
 
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
 
     HadoopShuffleReadHandler indexReader =
         new HadoopShuffleReadHandler(
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java
index 94c5f3c33..82f80bbdc 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopHandlerTest.java
@@ -22,6 +22,7 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
@@ -98,9 +99,9 @@ public class HadoopHandlerTest extends HadoopTestBase {
       List<Long> expectedBlockId)
       throws IllegalStateException {
     Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf();
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     for (long blockId : expectedBlockId) {
-      expectBlockIds.addLong(blockId);
+      expectBlockIds.add(blockId);
     }
     // read directly and compare
     HadoopClientReadHandler readHandler =
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
index 1b2b7e93b..0d3b5b80e 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
@@ -22,6 +22,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 
@@ -66,7 +67,7 @@ public class HadoopShuffleReadHandlerTest extends 
HadoopTestBase {
         HadoopShuffleHandlerTestBase.calcExpectedSegmentNum(
             expectTotalBlockNum, blockSize, readBufferSize);
     Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf();
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     expectedData.forEach((id, block) -> expectBlockIds.addLong(id));
     String fileNamePrefix =
         ShuffleStorageUtils.getFullShuffleDataFolder(
@@ -130,7 +131,7 @@ public class HadoopShuffleReadHandlerTest extends 
HadoopTestBase {
         HadoopShuffleHandlerTestBase.calcExpectedSegmentNum(
             expectTotalBlockNum, blockSize, readBufferSize);
     Roaring64NavigableMap expectBlockIds = Roaring64NavigableMap.bitmapOf();
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     expectedData.forEach((id, block) -> expectBlockIds.addLong(id));
     String fileNamePrefix =
         ShuffleStorageUtils.getFullShuffleDataFolder(
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
index 5b2db2daf..854c4757d 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
@@ -21,6 +21,8 @@ import java.nio.ByteBuffer;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Maps;
@@ -122,7 +124,7 @@ public class LocalFileServerReadHandlerTest {
         .when(mockShuffleServerClient)
         .getShuffleData(Mockito.argThat(segment2Match));
 
-    Roaring64NavigableMap processBlockIds = Roaring64NavigableMap.bitmapOf();
+    Set<Long> processBlockIds = ConcurrentHashMap.newKeySet();
     LocalFileClientReadHandler handler =
         new LocalFileClientReadHandler(
             appId,

