This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e45f2de [Improvement][AQE] Support getting memory data skip by 
upstream task ids (#358)
0e45f2de is described below

commit 0e45f2dedf22c02c8fb2cbacf68f7b75bb38d8ac
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Nov 29 14:13:00 2022 +0800

    [Improvement][AQE] Support getting memory data skip by upstream task ids 
(#358)
    
    ### What changes were proposed in this pull request?
    
    Support skipping unneeded in-memory shuffle data by filtering on upstream task IDs.
    
    ### Why are the changes needed?
    
    In the current codebase, when the shuffle-server memory is large and
    the job is optimized by the AQE skew rule, multiple readers of the same
    partition fetch their shuffle data from the same shuffle server.
    
    To avoid reading unused LOCALFILE/HDFS data, PR #137 introduced the
    LOCAL_ORDER mechanism to filter out most of the unneeded data.
    
    However, MEMORY storage still suffers from this problem. This PR avoids
    reading unused data for each reader by filtering blocks with the
    expectedTaskIds bitmap.
    
    This optimization is only enabled when the AQE skew optimization is applied,
    i.e. when a reader covers only a sub-range of the upstream map tasks.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    1. UTs
    
    ### Benchmark
    
    #### Table
    Table1: 100g, dtypes: Array[(String, String)] = Array((v1,StringType), (k1,IntegerType)).
    All rows have the same k1 value (k1 = 10).
    
    Table2: 10 records, dtypes: Array[(String, String)] = Array((k2,IntegerType), (v2,StringType)).
    Exactly one record has k2 = 10.
    
    #### Env
    Spark resource profile: 10 executors (1 core and 2g memory each)
    Shuffle-server environment: 10 shuffle servers, 10g for read and write buffers.
    Spark Shuffle Client Config: storage type: MEMORY_LOCALFILE with LOCAL_ORDER
    SQL: spark.sql("select * from Table1,Table2 where k1 = k2").write.mode("overwrite").parquet("xxxxxx")
    
    #### Result
    __ESS__: cost `3min`
    __Uniffle without patch__: cost `11.6min` (2.1 + 9.5)
    __Uniffle with patch__: cost `3.5min` (2.1 + 1.4)
    
    Co-authored-by: xianjingfeng <[email protected]>
---
 .../spark/shuffle/reader/RssShuffleReader.java     |   6 +-
 .../client/factory/ShuffleClientFactory.java       |  23 +++-
 .../uniffle/client/impl/ShuffleReadClientImpl.java |   8 +-
 .../request/CreateShuffleReadClientRequest.java    |   9 +-
 .../client/impl/grpc/ShuffleServerGrpcClient.java  |  13 +++
 .../request/RssGetInMemoryShuffleDataRequest.java  |  11 +-
 proto/src/main/proto/Rss.proto                     |   1 +
 .../uniffle/server/ShuffleServerGrpcService.java   |  22 +++-
 .../apache/uniffle/server/ShuffleTaskManager.java  |   5 +-
 .../uniffle/server/buffer/ShuffleBuffer.java       |  35 ++++--
 .../server/buffer/ShuffleBufferManager.java        |  16 ++-
 .../server/buffer/ShuffleBufferManagerTest.java    |  44 ++++++++
 .../uniffle/server/buffer/ShuffleBufferTest.java   | 124 +++++++++++++++++++++
 .../storage/factory/ShuffleHandlerFactory.java     |   5 +-
 .../handler/impl/MemoryClientReadHandler.java      |  35 +++++-
 .../request/CreateShuffleReadHandlerRequest.java   |   9 ++
 16 files changed, 339 insertions(+), 27 deletions(-)

diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 353255f8..89e17664 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -78,6 +78,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
   private ShuffleReadMetrics readMetrics;
   private RssConf rssConf;
   private ShuffleDataDistributionType dataDistributionType;
+  private boolean expectedTaskIdsBitmapFilterEnable;
 
   public RssShuffleReader(
       int startPartition,
@@ -119,6 +120,9 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
     this.rssConf = rssConf;
     this.dataDistributionType = dataDistributionType;
+    // This mechanism of expectedTaskIdsBitmap filter is to filter out the 
most of data.
+    // especially for AQE skew optimization
+    this.expectedTaskIdsBitmapFilterEnable = !(mapStartIndex == 0 && 
mapEndIndex == Integer.MAX_VALUE);
   }
 
   @Override
@@ -206,7 +210,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
         CreateShuffleReadClientRequest request = new 
CreateShuffleReadClientRequest(
             appId, shuffleId, partition, storageType, basePath, 
indexReadLimit, readBufferSize,
             1, partitionNum, partitionToExpectBlocks.get(partition), 
taskIdBitmap, shuffleServerInfoList,
-            hadoopConf, dataDistributionType);
+            hadoopConf, dataDistributionType, 
expectedTaskIdsBitmapFilterEnable);
         ShuffleReadClient shuffleReadClient = 
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
         RssShuffleDataIterator iterator = new RssShuffleDataIterator<K, C>(
             shuffleDependency.serializer(), shuffleReadClient,
diff --git 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 7cbbb37f..4229fc77 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -71,10 +71,23 @@ public class ShuffleClientFactory {
   }
 
   public ShuffleReadClient 
createShuffleReadClient(CreateShuffleReadClientRequest request) {
-    return new ShuffleReadClientImpl(request.getStorageType(), 
request.getAppId(), request.getShuffleId(),
-        request.getPartitionId(), request.getIndexReadLimit(), 
request.getPartitionNumPerRange(),
-        request.getPartitionNum(), request.getReadBufferSize(), 
request.getBasePath(),
-        request.getBlockIdBitmap(), request.getTaskIdBitmap(), 
request.getShuffleServerInfoList(),
-        request.getHadoopConf(), request.getIdHelper(), 
request.getShuffleDataDistributionType());
+    return new ShuffleReadClientImpl(
+        request.getStorageType(),
+        request.getAppId(),
+        request.getShuffleId(),
+        request.getPartitionId(),
+        request.getIndexReadLimit(),
+        request.getPartitionNumPerRange(),
+        request.getPartitionNum(),
+        request.getReadBufferSize(),
+        request.getBasePath(),
+        request.getBlockIdBitmap(),
+        request.getTaskIdBitmap(),
+        request.getShuffleServerInfoList(),
+        request.getHadoopConf(),
+        request.getIdHelper(),
+        request.getShuffleDataDistributionType(),
+        request.isExpectedTaskIdsBitmapFilterEnable()
+    );
   }
 }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 8a57cbb0..8b6b0d06 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -76,7 +76,8 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
       List<ShuffleServerInfo> shuffleServerInfoList,
       Configuration hadoopConf,
       IdHelper idHelper,
-      ShuffleDataDistributionType dataDistributionType) {
+      ShuffleDataDistributionType dataDistributionType,
+      boolean expectedTaskIdsBitmapFilterEnable) {
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
     this.blockIdBitmap = blockIdBitmap;
@@ -99,6 +100,9 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
     request.setProcessBlockIds(processedBlockIds);
     request.setDistributionType(dataDistributionType);
     request.setExpectTaskIds(taskIdBitmap);
+    if (expectedTaskIdsBitmapFilterEnable) {
+      request.useExpectedTaskIdsBitmapFilter();
+    }
 
     List<Long> removeBlockIds = Lists.newArrayList();
     blockIdBitmap.forEach(bid -> {
@@ -135,7 +139,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
     this(storageType, appId, shuffleId, partitionId, indexReadLimit,
         partitionNumPerRange, partitionNum, readBufferSize, storageBasePath,
         blockIdBitmap, taskIdBitmap, shuffleServerInfoList, hadoopConf,
-        idHelper, ShuffleDataDistributionType.NORMAL);
+        idHelper, ShuffleDataDistributionType.NORMAL, false);
   }
 
   @Override
diff --git 
a/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
 
b/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
index 2cfd021d..db050304 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
@@ -44,6 +44,7 @@ public class CreateShuffleReadClientRequest {
   private Configuration hadoopConf;
   private IdHelper idHelper;
   private ShuffleDataDistributionType shuffleDataDistributionType = 
ShuffleDataDistributionType.NORMAL;
+  private boolean expectedTaskIdsBitmapFilterEnable = false;
 
   public CreateShuffleReadClientRequest(
       String appId,
@@ -59,11 +60,13 @@ public class CreateShuffleReadClientRequest {
       Roaring64NavigableMap taskIdBitmap,
       List<ShuffleServerInfo> shuffleServerInfoList,
       Configuration hadoopConf,
-      ShuffleDataDistributionType dataDistributionType) {
+      ShuffleDataDistributionType dataDistributionType,
+      boolean expectedTaskIdsBitmapFilterEnable) {
     this(appId, shuffleId, partitionId, storageType, basePath, indexReadLimit, 
readBufferSize,
         partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap, 
shuffleServerInfoList,
         hadoopConf, new DefaultIdHelper());
     this.shuffleDataDistributionType = dataDistributionType;
+    this.expectedTaskIdsBitmapFilterEnable = expectedTaskIdsBitmapFilterEnable;
   }
 
   public CreateShuffleReadClientRequest(
@@ -175,4 +178,8 @@ public class CreateShuffleReadClientRequest {
   public ShuffleDataDistributionType getShuffleDataDistributionType() {
     return shuffleDataDistributionType;
   }
+
+  public boolean isExpectedTaskIdsBitmapFilterEnable() {
+    return expectedTaskIdsBitmapFilterEnable;
+  }
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 543ce1f6..74fe2767 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.UnsafeByteOperations;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -61,6 +62,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.exception.NotRetryException;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.RetryUtils;
+import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.proto.RssProtos.AppHeartBeatRequest;
 import org.apache.uniffle.proto.RssProtos.AppHeartBeatResponse;
@@ -599,6 +601,16 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
   public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData(
       RssGetInMemoryShuffleDataRequest request) {
     long start = System.currentTimeMillis();
+    ByteString serializedTaskIdsBytes = ByteString.EMPTY;
+    try {
+      if (request.getExpectedTaskIds() != null) {
+        serializedTaskIdsBytes =
+            
UnsafeByteOperations.unsafeWrap(RssUtils.serializeBitMap(request.getExpectedTaskIds()));
+      }
+    } catch (Exception e) {
+      throw new RssException("Errors on serializing task ids bitmap.", e);
+    }
+
     GetMemoryShuffleDataRequest rpcRequest = GetMemoryShuffleDataRequest
         .newBuilder()
         .setAppId(request.getAppId())
@@ -606,6 +618,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
         .setPartitionId(request.getPartitionId())
         .setLastBlockId(request.getLastBlockId())
         .setReadBufferSize(request.getReadBufferSize())
+        .setSerializedExpectedTaskIdsBitmap(serializedTaskIdsBytes)
         .setTimestamp(start)
         .build();
 
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
index 25f067db..87c3d2f1 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
@@ -17,20 +17,25 @@
 
 package org.apache.uniffle.client.request;
 
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
 public class RssGetInMemoryShuffleDataRequest {
   private final String appId;
   private final int shuffleId;
   private final int partitionId;
   private final long lastBlockId;
   private final int readBufferSize;
+  private final Roaring64NavigableMap expectedTaskIds;
 
   public RssGetInMemoryShuffleDataRequest(
-      String appId, int shuffleId, int partitionId, long lastBlockId, int 
readBufferSize) {
+      String appId, int shuffleId, int partitionId, long lastBlockId, int 
readBufferSize,
+      Roaring64NavigableMap expectedTaskIds) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
     this.lastBlockId = lastBlockId;
     this.readBufferSize = readBufferSize;
+    this.expectedTaskIds = expectedTaskIds;
   }
 
   public String getAppId() {
@@ -52,4 +57,8 @@ public class RssGetInMemoryShuffleDataRequest {
   public int getReadBufferSize() {
     return readBufferSize;
   }
+
+  public Roaring64NavigableMap getExpectedTaskIds() {
+    return expectedTaskIds;
+  }
 }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 5789a952..db3c6a07 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -92,6 +92,7 @@ message GetMemoryShuffleDataRequest {
   int64 lastBlockId = 4;
   int32 readBufferSize = 5;
   int64 timestamp = 6;
+  optional bytes serializedExpectedTaskIdsBitmap = 7;
 }
 
 message GetMemoryShuffleDataResponse {
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index 055e5d3e..3de560ce 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -30,6 +30,7 @@ import com.google.protobuf.UnsafeByteOperations;
 import io.grpc.Context;
 import io.grpc.Status;
 import io.grpc.stub.StreamObserver;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -43,6 +44,7 @@ import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.exception.FileNotFoundException;
+import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.proto.RssProtos.AppHeartBeatRequest;
 import org.apache.uniffle.proto.RssProtos.AppHeartBeatResponse;
@@ -635,6 +637,7 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     long blockId = request.getLastBlockId();
     int readBufferSize = request.getReadBufferSize();
     long timestamp = request.getTimestamp();
+
     if (timestamp > 0) {
       long transportTime = System.currentTimeMillis() - timestamp;
       if (transportTime > 0) {
@@ -652,8 +655,23 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     // todo: if can get the exact memory size?
     if 
(shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize))
 {
       try {
-        ShuffleDataResult shuffleDataResult = 
shuffleServer.getShuffleTaskManager()
-            .getInMemoryShuffleData(appId, shuffleId, partitionId, blockId, 
readBufferSize);
+        Roaring64NavigableMap expectedTaskIds = null;
+        if (request.getSerializedExpectedTaskIdsBitmap() != null
+            && !request.getSerializedExpectedTaskIdsBitmap().isEmpty()) {
+          expectedTaskIds = RssUtils.deserializeBitMap(
+              request.getSerializedExpectedTaskIdsBitmap().toByteArray()
+          );
+        }
+        ShuffleDataResult shuffleDataResult = shuffleServer
+            .getShuffleTaskManager()
+            .getInMemoryShuffleData(
+                appId,
+                shuffleId,
+                partitionId,
+                blockId,
+                readBufferSize,
+                expectedTaskIds
+            );
         byte[] data = new byte[]{};
         List<BufferSegment> bufferSegments = Lists.newArrayList();
         if (shuffleDataResult != null) {
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 6fe1e412..9b3c64e2 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -363,9 +363,10 @@ public class ShuffleTaskManager {
   }
 
   public ShuffleDataResult getInMemoryShuffleData(
-      String appId, Integer shuffleId, Integer partitionId, long blockId, int 
readBufferSize) {
+      String appId, Integer shuffleId, Integer partitionId, long blockId, int 
readBufferSize,
+      Roaring64NavigableMap expectedTaskIds) {
     return shuffleBufferManager.getShuffleData(appId,
-        shuffleId, partitionId, blockId, readBufferSize);
+        shuffleId, partitionId, blockId, readBufferSize, expectedTaskIds);
   }
 
   public ShuffleDataResult getShuffleData(
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
index ad0b31af..37e497a8 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
@@ -25,6 +25,7 @@ import java.util.function.Supplier;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -142,16 +143,21 @@ public class ShuffleBuffer {
     return inFlushBlockMap;
   }
 
+  public synchronized ShuffleDataResult getShuffleData(
+      long lastBlockId, int readBufferSize) {
+    return getShuffleData(lastBlockId, readBufferSize, null);
+  }
+
   // 1. generate buffer segments and other info: if blockId exist, start with 
which eventId
   // 2. according to info from step 1, generate data
   // todo: if block was flushed, it's possible to get duplicated data
   public synchronized ShuffleDataResult getShuffleData(
-      long lastBlockId, int readBufferSize) {
+      long lastBlockId, int readBufferSize, Roaring64NavigableMap 
expectedTaskIds) {
     try {
       List<BufferSegment> bufferSegments = Lists.newArrayList();
       List<ShufflePartitionedBlock> readBlocks = Lists.newArrayList();
       updateBufferSegmentsAndResultBlocks(
-          lastBlockId, readBufferSize, bufferSegments, readBlocks);
+          lastBlockId, readBufferSize, bufferSegments, readBlocks, 
expectedTaskIds);
       if (!bufferSegments.isEmpty()) {
         int length = calculateDataLength(bufferSegments);
         byte[] data = new byte[length];
@@ -172,7 +178,8 @@ public class ShuffleBuffer {
       long lastBlockId,
       long readBufferSize,
       List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> resultBlocks) {
+      List<ShufflePartitionedBlock> resultBlocks,
+      Roaring64NavigableMap expectedTaskIds) {
     long nextBlockId = lastBlockId;
     List<Long> sortedEventId = sortFlushingEventId();
     int offset = 0;
@@ -186,11 +193,11 @@ public class ShuffleBuffer {
         // update bufferSegments with different strategy according to 
lastBlockId
         if (nextBlockId == Constants.INVALID_BLOCK_ID) {
           updateSegmentsWithoutBlockId(offset, inFlushBlockMap.get(eventId), 
readBufferSize,
-              bufferSegments, resultBlocks);
+              bufferSegments, resultBlocks, expectedTaskIds);
           hasLastBlockId = true;
         } else {
           hasLastBlockId = updateSegmentsWithBlockId(offset, 
inFlushBlockMap.get(eventId),
-              readBufferSize, nextBlockId, bufferSegments, resultBlocks);
+              readBufferSize, nextBlockId, bufferSegments, resultBlocks, 
expectedTaskIds);
           // if last blockId is found, read from begin with next cached blocks
           if (hasLastBlockId) {
             // reset blockId to read from begin in next cached blocks
@@ -208,11 +215,11 @@ public class ShuffleBuffer {
     // try to read from cached blocks which is not in flush queue
     if (blocks.size() > 0 && offset < readBufferSize) {
       if (nextBlockId == Constants.INVALID_BLOCK_ID) {
-        updateSegmentsWithoutBlockId(offset, blocks, readBufferSize, 
bufferSegments, resultBlocks);
+        updateSegmentsWithoutBlockId(offset, blocks, readBufferSize, 
bufferSegments, resultBlocks, expectedTaskIds);
         hasLastBlockId = true;
       } else {
         hasLastBlockId = updateSegmentsWithBlockId(offset, blocks,
-            readBufferSize, nextBlockId, bufferSegments, resultBlocks);
+            readBufferSize, nextBlockId, bufferSegments, resultBlocks, 
expectedTaskIds);
       }
     }
     if ((!inFlushBlockMap.isEmpty() || blocks.size() > 0) && offset == 0 && 
!hasLastBlockId) {
@@ -220,7 +227,7 @@ public class ShuffleBuffer {
       // but there still has data in memory
       // try read again with blockId = Constants.INVALID_BLOCK_ID
       updateBufferSegmentsAndResultBlocks(
-          Constants.INVALID_BLOCK_ID, readBufferSize, bufferSegments, 
resultBlocks);
+          Constants.INVALID_BLOCK_ID, readBufferSize, bufferSegments, 
resultBlocks, expectedTaskIds);
     }
   }
 
@@ -261,10 +268,14 @@ public class ShuffleBuffer {
       List<ShufflePartitionedBlock> cachedBlocks,
       long readBufferSize,
       List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> readBlocks) {
+      List<ShufflePartitionedBlock> readBlocks,
+      Roaring64NavigableMap expectedTaskIds) {
     int currentOffset = offset;
     // read from first block
     for (ShufflePartitionedBlock block : cachedBlocks) {
+      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
+        continue;
+      }
       // add bufferSegment with block
       bufferSegments.add(new BufferSegment(block.getBlockId(), currentOffset, 
block.getLength(),
           block.getUncompressLength(), block.getCrc(), 
block.getTaskAttemptId()));
@@ -284,7 +295,8 @@ public class ShuffleBuffer {
       long readBufferSize,
       long lastBlockId,
       List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> readBlocks) {
+      List<ShufflePartitionedBlock> readBlocks,
+      Roaring64NavigableMap expectedTaskIds) {
     int currentOffset = offset;
     // find lastBlockId, then read from next block
     boolean foundBlockId = false;
@@ -296,6 +308,9 @@ public class ShuffleBuffer {
         }
         continue;
       }
+      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
+        continue;
+      }
       // add bufferSegment with block
       bufferSegments.add(new BufferSegment(block.getBlockId(), currentOffset, 
block.getLength(),
           block.getUncompressLength(), block.getCrc(), 
block.getTaskAttemptId()));
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index f87ae896..c3606a08 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -33,6 +33,7 @@ import com.google.common.collect.Range;
 import com.google.common.collect.RangeMap;
 import com.google.common.collect.Sets;
 import com.google.common.collect.TreeRangeMap;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -151,6 +152,19 @@ public class ShuffleBufferManager {
   public ShuffleDataResult getShuffleData(
       String appId, int shuffleId, int partitionId, long blockId,
       int readBufferSize) {
+    return getShuffleData(
+        appId,
+        shuffleId,
+        partitionId,
+        blockId,
+        readBufferSize,
+        null
+    );
+  }
+
+  public ShuffleDataResult getShuffleData(
+      String appId, int shuffleId, int partitionId, long blockId,
+      int readBufferSize, Roaring64NavigableMap expectedTaskIds) {
     Map.Entry<Range<Integer>, ShuffleBuffer> entry = getShuffleBufferEntry(
         appId, shuffleId, partitionId);
     if (entry == null) {
@@ -161,7 +175,7 @@ public class ShuffleBufferManager {
     if (buffer == null) {
       return null;
     }
-    return buffer.getShuffleData(blockId, readBufferSize);
+    return buffer.getShuffleData(blockId, readBufferSize, expectedTaskIds);
   }
 
   void flushSingleBufferIfNecessary(ShuffleBuffer buffer, String appId,
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index 1bee2d9a..1e6d056d 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -26,6 +26,7 @@ import com.google.common.collect.RangeMap;
 import com.google.common.io.Files;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedData;
@@ -99,6 +100,49 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     assertEquals(buffer, bufferPool.get(appId).get(shuffleId).get(0));
   }
 
+  @Test
+  public void getShuffleDataWithExpectedTaskIdsTest() {
+    String appId = "getShuffleDataWithExpectedTaskIdsTest";
+    shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
+    ShufflePartitionedData spd1 = createData(0, 1, 68);
+    ShufflePartitionedData spd2 = createData(0, 2, 68);
+    ShufflePartitionedData spd3 = createData(0, 1, 68);
+    ShufflePartitionedData spd4 = createData(0, 3, 68);
+    shuffleBufferManager.cacheShuffleData(appId, 1, false, spd1);
+    shuffleBufferManager.cacheShuffleData(appId, 1, false, spd2);
+    shuffleBufferManager.cacheShuffleData(appId, 1, false, spd3);
+    shuffleBufferManager.cacheShuffleData(appId, 1, false, spd4);
+
+    /**
+     * case1: all blocks in cached and read multiple times
+     */
+    ShuffleDataResult result = shuffleBufferManager.getShuffleData(
+        appId,
+        1,
+        0,
+        Constants.INVALID_BLOCK_ID,
+        60,
+        Roaring64NavigableMap.bitmapOf(1)
+    );
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(68, result.getBufferSegments().get(0).getLength());
+
+    // 2th read
+    long lastBlockId = result.getBufferSegments().get(0).getBlockId();
+    result = shuffleBufferManager.getShuffleData(
+        appId,
+        1,
+        0,
+        lastBlockId,
+        60,
+        Roaring64NavigableMap.bitmapOf(1)
+    );
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(68, result.getBufferSegments().get(0).getLength());
+  }
+
   @Test
   public void getShuffleDataTest() {
     String appId = "getShuffleDataTest";
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
index d275449a..a45f1a55 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong;
 
 import com.google.common.collect.Lists;
 import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -82,6 +83,129 @@ public class ShuffleBufferTest extends BufferTestBase {
     assertEquals(0, shuffleBuffer.getBlocks().size());
   }
 
+  @Test
+  public void getShuffleDataWithExpectedTaskIdsFilterTest() {
+    /**
+     * case1: all blocks are still cached (nothing flushed yet) and total size < readBufferSize
+     */
+    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(100);
+    ShufflePartitionedData spd1 = createData(1, 1, 15);
+    ShufflePartitionedData spd2 = createData(1, 0, 15);
+    ShufflePartitionedData spd3 = createData(1, 2, 55);
+    ShufflePartitionedData spd4 = createData(1, 1, 45);
+    shuffleBuffer.append(spd1);
+    shuffleBuffer.append(spd2);
+    shuffleBuffer.append(spd3);
+    shuffleBuffer.append(spd4);
+
+    Roaring64NavigableMap expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    ShuffleDataResult result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 1000, expectedTasks);
+    assertEquals(3, result.getBufferSegments().size());
+    for (BufferSegment segment : result.getBufferSegments()) {
+      assertTrue(expectedTasks.contains(segment.getTaskAttemptId()));
+    }
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(15, result.getBufferSegments().get(0).getLength());
+    assertEquals(15, result.getBufferSegments().get(1).getOffset());
+    assertEquals(55, result.getBufferSegments().get(1).getLength());
+    assertEquals(70, result.getBufferSegments().get(2).getOffset());
+    assertEquals(45, result.getBufferSegments().get(2).getLength());
+
+    expectedTasks = Roaring64NavigableMap.bitmapOf(0);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 1000, expectedTasks);
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(15, result.getBufferSegments().get(0).getLength());
+
+    /**
+     * case2: all blocks are still cached and total size > readBufferSize, so it takes multiple reads.
+     *
+     * sizes of the expected blocks, in order: 15, 55, 45
+     */
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 60, expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(15, result.getBufferSegments().get(0).getLength());
+    assertEquals(15, result.getBufferSegments().get(1).getOffset());
+    assertEquals(55, result.getBufferSegments().get(1).getLength());
+
+    // 2nd read: only the remaining 45-byte block is left
+    long lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(45, result.getBufferSegments().get(0).getLength());
+
+    /**
+     * case3: after toFlushEvent, all blocks are in the flushed map and total size < readBufferSize
+     */
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    ShuffleDataFlushEvent event1 = shuffleBuffer.toFlushEvent(
+        "appId",
+        0,
+        0,
+        1,
+        null,
+        ShuffleDataDistributionType.LOCAL_ORDER
+    );
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 1000, expectedTasks);
+    assertEquals(3, result.getBufferSegments().size());
+    for (BufferSegment segment : result.getBufferSegments()) {
+      assertTrue(expectedTasks.contains(segment.getTaskAttemptId()));
+    }
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(15, result.getBufferSegments().get(0).getLength());
+    assertEquals(15, result.getBufferSegments().get(1).getOffset());
+    assertEquals(55, result.getBufferSegments().get(1).getLength());
+    assertEquals(70, result.getBufferSegments().get(2).getOffset());
+    assertEquals(45, result.getBufferSegments().get(2).getLength());
+
+    /**
+     * case4: all blocks are in the flushed map and total size > readBufferSize, so it takes multiple reads
+     */
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 60, expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(15, result.getBufferSegments().get(0).getLength());
+    assertEquals(15, result.getBufferSegments().get(1).getOffset());
+    assertEquals(55, result.getBufferSegments().get(1).getLength());
+
+    // 2nd read: only the remaining 45-byte block is left
+    lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(45, result.getBufferSegments().get(0).getLength());
+
+    /**
+     * case5: some blocks are in the flushed map and the rest are cached, so it takes multiple reads.
+     *
+     * expected block sizes: 15, 55, 45 (flushed map), then 55, 45, 5, 25 (cache)
+     */
+    ShufflePartitionedData spd5 = createData(1, 2, 55);
+    ShufflePartitionedData spd6 = createData(1, 1, 45);
+    ShufflePartitionedData spd7 = createData(1, 1, 5);
+    ShufflePartitionedData spd8 = createData(1, 1, 25);
+    shuffleBuffer.append(spd5);
+    shuffleBuffer.append(spd6);
+    shuffleBuffer.append(spd7);
+    shuffleBuffer.append(spd8);
+
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 60, expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+
+    // 2nd read: 45 (flushed map) + 55 (cache)
+    lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+    // 3rd read: 45 + 5 + 25 (cache)
+    lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(3, result.getBufferSegments().size());
+  }
+
   @Test
   public void getShuffleDataWithLocalOrderTest() {
     ShuffleBuffer shuffleBuffer = new ShuffleBuffer(200);
diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 3f4dd135..f3af0d46 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -121,7 +121,10 @@ public class ShuffleHandlerFactory {
         request.getShuffleId(),
         request.getPartitionId(),
         request.getReadBufferSize(),
-        shuffleServerClient);
+        shuffleServerClient,
+        request.getExpectTaskIds(),
+        request.isExpectedTaskIdsBitmapFilterEnable()
+    );
     return memoryClientReadHandler;
   }
 
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
index 73bcc46d..3cc2e6ba 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
@@ -19,6 +19,8 @@ package org.apache.uniffle.storage.handler.impl;
 
 import java.util.List;
 
+import com.google.common.annotations.VisibleForTesting;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -35,18 +37,43 @@ public class MemoryClientReadHandler extends AbstractClientReadHandler {
   private static final Logger LOG = LoggerFactory.getLogger(MemoryClientReadHandler.class);
   private long lastBlockId = Constants.INVALID_BLOCK_ID;
   private ShuffleServerClient shuffleServerClient;
+  private final Roaring64NavigableMap expectTaskIds;
+  private final boolean expectedTaskIdsBitmapFilterEnable;
 
+  // Test-only convenience constructor: no expected task ids and the bitmap filter disabled.
+  @VisibleForTesting
   public MemoryClientReadHandler(
       String appId,
       int shuffleId,
       int partitionId,
       int readBufferSize,
       ShuffleServerClient shuffleServerClient) {
+    this(
+        appId,
+        shuffleId,
+        partitionId,
+        readBufferSize,
+        shuffleServerClient,
+        null,
+        false
+    );
+  }
+
+  public MemoryClientReadHandler(
+      String appId,
+      int shuffleId,
+      int partitionId,
+      int readBufferSize,
+      ShuffleServerClient shuffleServerClient,
+      Roaring64NavigableMap expectTaskIds,
+      boolean expectedTaskIdsBitmapFilterEnable) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
     this.readBufferSize = readBufferSize;
     this.shuffleServerClient = shuffleServerClient;
+    this.expectTaskIds = expectTaskIds;
+    this.expectedTaskIdsBitmapFilterEnable = expectedTaskIdsBitmapFilterEnable;
   }
 
   @Override
@@ -54,7 +81,13 @@ public class MemoryClientReadHandler extends AbstractClientReadHandler {
     ShuffleDataResult result = null;
 
     RssGetInMemoryShuffleDataRequest request = new RssGetInMemoryShuffleDataRequest(
-        appId,shuffleId, partitionId, lastBlockId, readBufferSize);
+        appId,
+        shuffleId,
+        partitionId,
+        lastBlockId,
+        readBufferSize,
+        expectedTaskIdsBitmapFilterEnable ? expectTaskIds : null
+    );
 
     try {
       RssGetInMemoryShuffleDataResponse response =
diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 75a1f146..58729485 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -44,6 +44,7 @@ public class CreateShuffleReadHandlerRequest {
   private Roaring64NavigableMap processBlockIds;
   private ShuffleDataDistributionType distributionType;
   private Roaring64NavigableMap expectTaskIds;
+  private boolean expectedTaskIdsBitmapFilterEnable; // false until useExpectedTaskIdsBitmapFilter() is called
 
   public CreateShuffleReadHandlerRequest() {
   }
@@ -175,4 +176,20 @@ public class CreateShuffleReadHandlerRequest {
   public void setExpectTaskIds(Roaring64NavigableMap expectTaskIds) {
     this.expectTaskIds = expectTaskIds;
   }
+
+  /**
+   * Returns whether the memory read handler built from this request should send
+   * {@code expectTaskIds} along with its in-memory shuffle data requests. Defaults to {@code false}.
+   */
+  public boolean isExpectedTaskIdsBitmapFilterEnable() {
+    return expectedTaskIdsBitmapFilterEnable;
+  }
+
+  /**
+   * Turns on the expected-task-ids bitmap filter for this request. This is a one-way switch:
+   * there is no corresponding method to turn it off again.
+   */
+  public void useExpectedTaskIdsBitmapFilter() {
+    this.expectedTaskIdsBitmapFilterEnable = true;
+  }
 }


Reply via email to