This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 430046040 [#1708] feat(server): support use skip list to store 
shuffleBuffer in memory (#1763)
430046040 is described below

commit 430046040cf611dfed7cfce0ed2ce583032f3b06
Author: xianjingfeng <[email protected]>
AuthorDate: Mon Jun 17 15:17:42 2024 +0800

    [#1708] feat(server): support use skip list to store shuffleBuffer in 
memory (#1763)
    
    ### What changes were proposed in this pull request?
    Support using a skip list to store shuffle buffers in memory.
    
    ### Why are the changes needed?
    If a lot of memory is assigned to store shuffle data, using a skip list
    helps to improve performance (the system load of the shuffle server is reduced).
    Fix: #1708
    
    ### Does this PR introduce any user-facing change?
    Yes. Set rss.server.shuffleBuffer.type to SKIP_LIST to enable the skip list buffer (the default is LINKED_LIST).
    
    ### How was this patch tested?
    UTs and manual testing
---
 .../apache/uniffle/server/ShuffleServerConf.java   |  13 +
 .../server/buffer/AbstractShuffleBuffer.java       | 188 +++++++++++
 .../uniffle/server/buffer/ShuffleBuffer.java       | 343 +--------------------
 .../server/buffer/ShuffleBufferManager.java        |  13 +-
 .../uniffle/server/buffer/ShuffleBufferType.java   |  23 ++
 ...uffer.java => ShuffleBufferWithLinkedList.java} | 159 +---------
 .../server/buffer/ShuffleBufferWithSkipList.java   | 230 ++++++++++++++
 .../uniffle/server/buffer/BufferTestBase.java      |  19 +-
 ...t.java => ShuffleBufferWithLinkedListTest.java} |  34 +-
 .../buffer/ShuffleBufferWithSkipListTest.java      | 208 +++++++++++++
 10 files changed, 736 insertions(+), 494 deletions(-)

diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
index 064a6b28b..ebc4aaccb 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
@@ -25,6 +25,7 @@ import org.apache.uniffle.common.config.ConfigOption;
 import org.apache.uniffle.common.config.ConfigOptions;
 import org.apache.uniffle.common.config.ConfigUtils;
 import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.server.buffer.ShuffleBufferType;
 
 public class ShuffleServerConf extends RssBaseConf {
 
@@ -434,6 +435,18 @@ public class ShuffleServerConf extends RssBaseConf {
               "The interval of trigger shuffle buffer manager to flush data to 
persistent storage. If <= 0"
                   + ", then this flush check would be disabled.");
 
+  public static final ConfigOption<ShuffleBufferType> 
SERVER_SHUFFLE_BUFFER_TYPE =
+      ConfigOptions.key("rss.server.shuffleBuffer.type")
+          .enumType(ShuffleBufferType.class)
+          .defaultValue(ShuffleBufferType.LINKED_LIST)
+          .withDescription(
+              "The type for shuffle buffers. Setting as LINKED_LIST or 
SKIP_LIST."
+                  + " The default value is LINKED_LIST. SKIP_LIST will help to 
improve"
+                  + " the performance when there are a large number of blocks 
in memory"
+                  + " or when the memory occupied by the blocks is very large."
+                  + " The cpu usage of the shuffle server will be reduced."
+                  + " But SKIP_LIST doesn't support the slow-start feature of 
MR.");
+
   public static final ConfigOption<Long> SERVER_SHUFFLE_FLUSH_THRESHOLD =
       ConfigOptions.key("rss.server.shuffle.flush.threshold")
           .longType()
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/AbstractShuffleBuffer.java
 
b/server/src/main/java/org/apache/uniffle/server/buffer/AbstractShuffleBuffer.java
new file mode 100644
index 000000000..15197bc86
--- /dev/null
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/AbstractShuffleBuffer.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.buffer;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.function.Supplier;
+
+import com.google.common.collect.Lists;
+import io.netty.buffer.CompositeByteBuf;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.NettyUtils;
+import org.apache.uniffle.server.ShuffleDataFlushEvent;
+
+public abstract class AbstractShuffleBuffer implements ShuffleBuffer {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(AbstractShuffleBuffer.class);
+
+  private final long capacity;
+  protected long size;
+
+  public AbstractShuffleBuffer(long capacity) {
+    this.capacity = capacity;
+    this.size = 0;
+  }
+
+  /** Only for test */
+  @Override
+  public synchronized ShuffleDataFlushEvent toFlushEvent(
+      String appId,
+      int shuffleId,
+      int startPartition,
+      int endPartition,
+      Supplier<Boolean> isValid) {
+    return toFlushEvent(
+        appId,
+        shuffleId,
+        startPartition,
+        endPartition,
+        isValid,
+        ShuffleDataDistributionType.NORMAL);
+  }
+
+  @Override
+  public long getSize() {
+    return size;
+  }
+
+  @Override
+  public boolean isFull() {
+    return size > capacity;
+  }
+
+  @Override
+  public synchronized ShuffleDataResult getShuffleData(long lastBlockId, int 
readBufferSize) {
+    return getShuffleData(lastBlockId, readBufferSize, null);
+  }
+
+  // 1. generate buffer segments and other info: if blockId exist, start with 
which eventId
+  // 2. according to info from step 1, generate data
+  // todo: if block was flushed, it's possible to get duplicated data
+  @Override
+  public synchronized ShuffleDataResult getShuffleData(
+      long lastBlockId, int readBufferSize, Roaring64NavigableMap 
expectedTaskIds) {
+    try {
+      List<BufferSegment> bufferSegments = Lists.newArrayList();
+      List<ShufflePartitionedBlock> readBlocks = Lists.newArrayList();
+      updateBufferSegmentsAndResultBlocks(
+          lastBlockId, readBufferSize, bufferSegments, readBlocks, 
expectedTaskIds);
+      if (!bufferSegments.isEmpty()) {
+        CompositeByteBuf byteBuf =
+            new CompositeByteBuf(
+                NettyUtils.getNettyBufferAllocator(),
+                true,
+                Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
+        // copy result data
+        updateShuffleData(readBlocks, byteBuf);
+        return new ShuffleDataResult(byteBuf, bufferSegments);
+      }
+    } catch (Exception e) {
+      LOG.error("Exception happened when getShuffleData in buffer", e);
+    }
+    return new ShuffleDataResult();
+  }
+
+  // here is the rule to read data in memory:
+  // 1. read from inFlushBlockMap order by eventId asc, then from blocks
+  // 2. if can't find lastBlockId, means related data may be flushed to 
storage, repeat step 1
+  protected abstract void updateBufferSegmentsAndResultBlocks(
+      long lastBlockId,
+      long readBufferSize,
+      List<BufferSegment> bufferSegments,
+      List<ShufflePartitionedBlock> resultBlocks,
+      Roaring64NavigableMap expectedTaskIds);
+
+  protected int calculateDataLength(List<BufferSegment> bufferSegments) {
+    BufferSegment bufferSegment = bufferSegments.get(bufferSegments.size() - 
1);
+    return bufferSegment.getOffset() + bufferSegment.getLength();
+  }
+
+  private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks, 
CompositeByteBuf data) {
+    int offset = 0;
+    for (ShufflePartitionedBlock block : readBlocks) {
+      // fill shuffle data
+      try {
+        data.addComponent(true, block.getData().retain());
+      } catch (Exception e) {
+        LOG.error(
+            "Unexpected exception for System.arraycopy, length["
+                + block.getLength()
+                + "], offset["
+                + offset
+                + "], dataLength["
+                + data.capacity()
+                + "]",
+            e);
+        throw e;
+      }
+      offset += block.getLength();
+    }
+  }
+
+  protected List<Long> sortFlushingEventId(List<Long> eventIdList) {
+    eventIdList.sort(
+        (id1, id2) -> {
+          if (id1 > id2) {
+            return 1;
+          }
+          return -1;
+        });
+    return eventIdList;
+  }
+
+  protected void updateSegmentsWithoutBlockId(
+      int offset,
+      Collection<ShufflePartitionedBlock> cachedBlocks,
+      long readBufferSize,
+      List<BufferSegment> bufferSegments,
+      List<ShufflePartitionedBlock> readBlocks,
+      Roaring64NavigableMap expectedTaskIds) {
+    int currentOffset = offset;
+    // read from first block
+    for (ShufflePartitionedBlock block : cachedBlocks) {
+      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
+        continue;
+      }
+      // add bufferSegment with block
+      bufferSegments.add(
+          new BufferSegment(
+              block.getBlockId(),
+              currentOffset,
+              block.getLength(),
+              block.getUncompressLength(),
+              block.getCrc(),
+              block.getTaskAttemptId()));
+      readBlocks.add(block);
+      // update offset
+      currentOffset += block.getLength();
+      // check if length >= request buffer size
+      if (currentOffset >= readBufferSize) {
+        break;
+      }
+    }
+  }
+}
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
index 4a9a215a7..cb01bb0a9 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
@@ -17,359 +17,50 @@
 
 package org.apache.uniffle.server.buffer;
 
-import java.util.Comparator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
-import io.netty.buffer.CompositeByteBuf;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
-import org.apache.uniffle.common.util.Constants;
-import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.NettyUtils;
 import org.apache.uniffle.server.ShuffleDataFlushEvent;
-import org.apache.uniffle.server.ShuffleFlushManager;
 
-public class ShuffleBuffer {
+public interface ShuffleBuffer {
+  long append(ShufflePartitionedData data);
 
-  private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleBuffer.class);
-
-  private final long capacity;
-  private long size;
-  // blocks will be added to inFlushBlockMap as <eventId, blocks> pair
-  // it will be removed after flush to storage
-  // the strategy ensure that shuffle is in memory or storage
-  private List<ShufflePartitionedBlock> blocks;
-  private Map<Long, List<ShufflePartitionedBlock>> inFlushBlockMap;
-
-  public ShuffleBuffer(long capacity) {
-    this.capacity = capacity;
-    this.size = 0;
-    this.blocks = new LinkedList<>();
-    this.inFlushBlockMap = JavaUtils.newConcurrentMap();
-  }
-
-  public long append(ShufflePartitionedData data) {
-    long mSize = 0;
-
-    synchronized (this) {
-      for (ShufflePartitionedBlock block : data.getBlockList()) {
-        blocks.add(block);
-        mSize += block.getSize();
-      }
-      size += mSize;
-    }
-
-    return mSize;
-  }
-
-  public synchronized ShuffleDataFlushEvent toFlushEvent(
+  ShuffleDataFlushEvent toFlushEvent(
       String appId,
       int shuffleId,
       int startPartition,
       int endPartition,
       Supplier<Boolean> isValid,
-      ShuffleDataDistributionType dataDistributionType) {
-    if (blocks.isEmpty()) {
-      return null;
-    }
-    // buffer will be cleared, and new list must be created for async flush
-    List<ShufflePartitionedBlock> spBlocks = new LinkedList<>(blocks);
-    List<ShufflePartitionedBlock> inFlushedQueueBlocks = spBlocks;
-    if (dataDistributionType == ShuffleDataDistributionType.LOCAL_ORDER) {
-      /**
-       * When reordering the blocks, it will break down the original reads 
sequence to cause the
-       * data lost in some cases. So we should create a reference copy to 
avoid this.
-       */
-      inFlushedQueueBlocks = new LinkedList<>(spBlocks);
-      
spBlocks.sort(Comparator.comparingLong(ShufflePartitionedBlock::getTaskAttemptId));
-    }
-    long eventId = ShuffleFlushManager.ATOMIC_EVENT_ID.getAndIncrement();
-    final ShuffleDataFlushEvent event =
-        new ShuffleDataFlushEvent(
-            eventId, appId, shuffleId, startPartition, endPartition, size, 
spBlocks, isValid, this);
-    event.addCleanupCallback(
-        () -> {
-          this.clearInFlushBuffer(event.getEventId());
-          spBlocks.forEach(spb -> spb.getData().release());
-        });
-    inFlushBlockMap.put(eventId, inFlushedQueueBlocks);
-    blocks.clear();
-    size = 0;
-    return event;
-  }
+      ShuffleDataDistributionType dataDistributionType);
 
   /** Only for test */
-  public synchronized ShuffleDataFlushEvent toFlushEvent(
-      String appId,
-      int shuffleId,
-      int startPartition,
-      int endPartition,
-      Supplier<Boolean> isValid) {
-    return toFlushEvent(
-        appId,
-        shuffleId,
-        startPartition,
-        endPartition,
-        isValid,
-        ShuffleDataDistributionType.NORMAL);
-  }
-
-  public List<ShufflePartitionedBlock> getBlocks() {
-    return blocks;
-  }
-
-  public long getSize() {
-    return size;
-  }
+  ShuffleDataFlushEvent toFlushEvent(
+      String appId, int shuffleId, int startPartition, int endPartition, 
Supplier<Boolean> isValid);
 
-  public boolean isFull() {
-    return size > capacity;
-  }
+  ShuffleDataResult getShuffleData(long lastBlockId, int readBufferSize);
 
-  public synchronized void clearInFlushBuffer(long eventId) {
-    inFlushBlockMap.remove(eventId);
-  }
-
-  @VisibleForTesting
-  public Map<Long, List<ShufflePartitionedBlock>> getInFlushBlockMap() {
-    return inFlushBlockMap;
-  }
+  ShuffleDataResult getShuffleData(
+      long lastBlockId, int readBufferSize, Roaring64NavigableMap 
expectedTaskIds);
 
-  public synchronized ShuffleDataResult getShuffleData(long lastBlockId, int 
readBufferSize) {
-    return getShuffleData(lastBlockId, readBufferSize, null);
-  }
+  long getSize();
 
-  // 1. generate buffer segments and other info: if blockId exist, start with 
which eventId
-  // 2. according to info from step 1, generate data
-  // todo: if block was flushed, it's possible to get duplicated data
-  public synchronized ShuffleDataResult getShuffleData(
-      long lastBlockId, int readBufferSize, Roaring64NavigableMap 
expectedTaskIds) {
-    try {
-      List<BufferSegment> bufferSegments = Lists.newArrayList();
-      List<ShufflePartitionedBlock> readBlocks = Lists.newArrayList();
-      updateBufferSegmentsAndResultBlocks(
-          lastBlockId, readBufferSize, bufferSegments, readBlocks, 
expectedTaskIds);
-      if (!bufferSegments.isEmpty()) {
-        CompositeByteBuf byteBuf =
-            new CompositeByteBuf(
-                NettyUtils.getNettyBufferAllocator(),
-                true,
-                Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
-        // copy result data
-        updateShuffleData(readBlocks, byteBuf);
-        return new ShuffleDataResult(byteBuf, bufferSegments);
-      }
-    } catch (Exception e) {
-      LOG.error("Exception happened when getShuffleData in buffer", e);
-    }
-    return new ShuffleDataResult();
-  }
+  boolean isFull();
 
-  // here is the rule to read data in memory:
-  // 1. read from inFlushBlockMap order by eventId asc, then from blocks
-  // 2. if can't find lastBlockId, means related data may be flushed to 
storage, repeat step 1
-  private void updateBufferSegmentsAndResultBlocks(
-      long lastBlockId,
-      long readBufferSize,
-      List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> resultBlocks,
-      Roaring64NavigableMap expectedTaskIds) {
-    long nextBlockId = lastBlockId;
-    List<Long> sortedEventId = sortFlushingEventId();
-    int offset = 0;
-    boolean hasLastBlockId = false;
-    // read from inFlushBlockMap first to make sure the order of
-    // data read is according to the order of data received
-    // The number of events means how many batches are in flushing status,
-    // it should be less than 5, or there has some problem with storage
-    if (!inFlushBlockMap.isEmpty()) {
-      for (Long eventId : sortedEventId) {
-        // update bufferSegments with different strategy according to 
lastBlockId
-        if (nextBlockId == Constants.INVALID_BLOCK_ID) {
-          updateSegmentsWithoutBlockId(
-              offset,
-              inFlushBlockMap.get(eventId),
-              readBufferSize,
-              bufferSegments,
-              resultBlocks,
-              expectedTaskIds);
-          hasLastBlockId = true;
-        } else {
-          hasLastBlockId =
-              updateSegmentsWithBlockId(
-                  offset,
-                  inFlushBlockMap.get(eventId),
-                  readBufferSize,
-                  nextBlockId,
-                  bufferSegments,
-                  resultBlocks,
-                  expectedTaskIds);
-          // if last blockId is found, read from begin with next cached blocks
-          if (hasLastBlockId) {
-            // reset blockId to read from begin in next cached blocks
-            nextBlockId = Constants.INVALID_BLOCK_ID;
-          }
-        }
-        if (!bufferSegments.isEmpty()) {
-          offset = calculateDataLength(bufferSegments);
-        }
-        if (offset >= readBufferSize) {
-          break;
-        }
-      }
-    }
-    // try to read from cached blocks which is not in flush queue
-    if (blocks.size() > 0 && offset < readBufferSize) {
-      if (nextBlockId == Constants.INVALID_BLOCK_ID) {
-        updateSegmentsWithoutBlockId(
-            offset, blocks, readBufferSize, bufferSegments, resultBlocks, 
expectedTaskIds);
-        hasLastBlockId = true;
-      } else {
-        hasLastBlockId =
-            updateSegmentsWithBlockId(
-                offset,
-                blocks,
-                readBufferSize,
-                nextBlockId,
-                bufferSegments,
-                resultBlocks,
-                expectedTaskIds);
-      }
-    }
-    if ((!inFlushBlockMap.isEmpty() || blocks.size() > 0) && offset == 0 && 
!hasLastBlockId) {
-      // can't find lastBlockId, it should be flushed
-      // but there still has data in memory
-      // try read again with blockId = Constants.INVALID_BLOCK_ID
-      updateBufferSegmentsAndResultBlocks(
-          Constants.INVALID_BLOCK_ID,
-          readBufferSize,
-          bufferSegments,
-          resultBlocks,
-          expectedTaskIds);
-    }
-  }
-
-  private int calculateDataLength(List<BufferSegment> bufferSegments) {
-    BufferSegment bufferSegment = bufferSegments.get(bufferSegments.size() - 
1);
-    return bufferSegment.getOffset() + bufferSegment.getLength();
-  }
-
-  private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks, 
CompositeByteBuf data) {
-    int offset = 0;
-    for (ShufflePartitionedBlock block : readBlocks) {
-      // fill shuffle data
-      try {
-        data.addComponent(true, block.getData().retain());
-      } catch (Exception e) {
-        LOG.error(
-            "Unexpected exception for System.arraycopy, length["
-                + block.getLength()
-                + "], offset["
-                + offset
-                + "], dataLength["
-                + data.capacity()
-                + "]",
-            e);
-        throw e;
-      }
-      offset += block.getLength();
-    }
-  }
+  /** Only for test */
+  List<ShufflePartitionedBlock> getBlocks();
 
-  private List<Long> sortFlushingEventId() {
-    List<Long> eventIdList = Lists.newArrayList(inFlushBlockMap.keySet());
-    eventIdList.sort(
-        (id1, id2) -> {
-          if (id1 > id2) {
-            return 1;
-          }
-          return -1;
-        });
-    return eventIdList;
-  }
+  void release();
 
-  private void updateSegmentsWithoutBlockId(
-      int offset,
-      List<ShufflePartitionedBlock> cachedBlocks,
-      long readBufferSize,
-      List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> readBlocks,
-      Roaring64NavigableMap expectedTaskIds) {
-    int currentOffset = offset;
-    // read from first block
-    for (ShufflePartitionedBlock block : cachedBlocks) {
-      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
-        continue;
-      }
-      // add bufferSegment with block
-      bufferSegments.add(
-          new BufferSegment(
-              block.getBlockId(),
-              currentOffset,
-              block.getLength(),
-              block.getUncompressLength(),
-              block.getCrc(),
-              block.getTaskAttemptId()));
-      readBlocks.add(block);
-      // update offset
-      currentOffset += block.getLength();
-      // check if length >= request buffer size
-      if (currentOffset >= readBufferSize) {
-        break;
-      }
-    }
-  }
+  void clearInFlushBuffer(long eventId);
 
-  private boolean updateSegmentsWithBlockId(
-      int offset,
-      List<ShufflePartitionedBlock> cachedBlocks,
-      long readBufferSize,
-      long lastBlockId,
-      List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> readBlocks,
-      Roaring64NavigableMap expectedTaskIds) {
-    int currentOffset = offset;
-    // find lastBlockId, then read from next block
-    boolean foundBlockId = false;
-    for (ShufflePartitionedBlock block : cachedBlocks) {
-      if (!foundBlockId) {
-        // find lastBlockId
-        if (block.getBlockId() == lastBlockId) {
-          foundBlockId = true;
-        }
-        continue;
-      }
-      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
-        continue;
-      }
-      // add bufferSegment with block
-      bufferSegments.add(
-          new BufferSegment(
-              block.getBlockId(),
-              currentOffset,
-              block.getLength(),
-              block.getUncompressLength(),
-              block.getCrc(),
-              block.getTaskAttemptId()));
-      readBlocks.add(block);
-      // update offset
-      currentOffset += block.getLength();
-      if (currentOffset >= readBufferSize) {
-        break;
-      }
-    }
-    return foundBlockId;
-  }
+  @VisibleForTesting
+  Map<Long, List<ShufflePartitionedBlock>> getInFlushBlockMap();
 }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index e85d2eae4..c3506921d 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -57,6 +57,7 @@ public class ShuffleBufferManager {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleBufferManager.class);
 
+  private final ShuffleBufferType shuffleBufferType;
   private ShuffleTaskManager shuffleTaskManager;
   private final ShuffleFlushManager shuffleFlushManager;
   private long capacity;
@@ -137,6 +138,7 @@ public class ShuffleBufferManager {
             capacity * 
conf.get(ShuffleServerConf.HUGE_PARTITION_MEMORY_USAGE_LIMITATION_RATIO));
     appBlockSizeMetricEnabled =
         
conf.getBoolean(ShuffleServerConf.APP_LEVEL_SHUFFLE_BLOCK_SIZE_METRIC_ENABLED);
+    shuffleBufferType = conf.get(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE);
   }
 
   public void setShuffleTaskManager(ShuffleTaskManager taskManager) {
@@ -152,7 +154,13 @@ public class ShuffleBufferManager {
     if (bufferRangeMap.get(startPartition) == null) {
       ShuffleServerMetrics.counterTotalPartitionNum.inc();
       ShuffleServerMetrics.gaugeTotalPartitionNum.inc();
-      bufferRangeMap.put(Range.closed(startPartition, endPartition), new 
ShuffleBuffer(bufferSize));
+      ShuffleBuffer shuffleBuffer;
+      if (shuffleBufferType == ShuffleBufferType.SKIP_LIST) {
+        shuffleBuffer = new ShuffleBufferWithSkipList(bufferSize);
+      } else {
+        shuffleBuffer = new ShuffleBufferWithLinkedList(bufferSize);
+      }
+      bufferRangeMap.put(Range.closed(startPartition, endPartition), 
shuffleBuffer);
     } else {
       LOG.warn(
           "Already register for appId["
@@ -282,7 +290,6 @@ public class ShuffleBufferManager {
             buffer.getBlocks().size());
       }
       flushBuffer(buffer, appId, shuffleId, startPartition, endPartition, 
isHugePartition);
-      return;
     }
   }
 
@@ -712,7 +719,7 @@ public class ShuffleBufferManager {
       Collection<ShuffleBuffer> buffers = 
bufferRangeMap.asMapOfRanges().values();
       if (buffers != null) {
         for (ShuffleBuffer buffer : buffers) {
-          buffer.getBlocks().forEach(spb -> spb.getData().release());
+          buffer.release();
           ShuffleServerMetrics.gaugeTotalPartitionNum.dec();
           size += buffer.getSize();
         }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferType.java 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferType.java
new file mode 100644
index 000000000..ad1fa04ff
--- /dev/null
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferType.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.buffer;
+
+public enum ShuffleBufferType {
+  SKIP_LIST,
+  LINKED_LIST
+}
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferWithLinkedList.java
similarity index 63%
copy from 
server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
copy to 
server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferWithLinkedList.java
index 4a9a215a7..c01bb0823 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferWithLinkedList.java
@@ -23,43 +23,32 @@ import java.util.List;
 import java.util.Map;
 import java.util.function.Supplier;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
-import io.netty.buffer.CompositeByteBuf;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
-import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.NettyUtils;
 import org.apache.uniffle.server.ShuffleDataFlushEvent;
 import org.apache.uniffle.server.ShuffleFlushManager;
 
-public class ShuffleBuffer {
-
-  private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleBuffer.class);
-
-  private final long capacity;
-  private long size;
+public class ShuffleBufferWithLinkedList extends AbstractShuffleBuffer {
   // blocks will be added to inFlushBlockMap as <eventId, blocks> pair
   // it will be removed after flush to storage
   // the strategy ensure that shuffle is in memory or storage
   private List<ShufflePartitionedBlock> blocks;
   private Map<Long, List<ShufflePartitionedBlock>> inFlushBlockMap;
 
-  public ShuffleBuffer(long capacity) {
-    this.capacity = capacity;
-    this.size = 0;
+  public ShuffleBufferWithLinkedList(long capacity) {
+    super(capacity);
     this.blocks = new LinkedList<>();
     this.inFlushBlockMap = JavaUtils.newConcurrentMap();
   }
 
+  @Override
   public long append(ShufflePartitionedData data) {
     long mSize = 0;
 
@@ -74,6 +63,7 @@ public class ShuffleBuffer {
     return mSize;
   }
 
+  @Override
   public synchronized ShuffleDataFlushEvent toFlushEvent(
       String appId,
       int shuffleId,
@@ -110,84 +100,36 @@ public class ShuffleBuffer {
     return event;
   }
 
-  /** Only for test */
-  public synchronized ShuffleDataFlushEvent toFlushEvent(
-      String appId,
-      int shuffleId,
-      int startPartition,
-      int endPartition,
-      Supplier<Boolean> isValid) {
-    return toFlushEvent(
-        appId,
-        shuffleId,
-        startPartition,
-        endPartition,
-        isValid,
-        ShuffleDataDistributionType.NORMAL);
-  }
-
+  @Override
   public List<ShufflePartitionedBlock> getBlocks() {
     return blocks;
   }
 
-  public long getSize() {
-    return size;
-  }
-
-  public boolean isFull() {
-    return size > capacity;
+  @Override
+  public void release() {
+    blocks.forEach(spb -> spb.getData().release());
   }
 
+  @Override
   public synchronized void clearInFlushBuffer(long eventId) {
     inFlushBlockMap.remove(eventId);
   }
 
-  @VisibleForTesting
+  @Override
   public Map<Long, List<ShufflePartitionedBlock>> getInFlushBlockMap() {
     return inFlushBlockMap;
   }
 
-  public synchronized ShuffleDataResult getShuffleData(long lastBlockId, int 
readBufferSize) {
-    return getShuffleData(lastBlockId, readBufferSize, null);
-  }
-
-  // 1. generate buffer segments and other info: if blockId exist, start with 
which eventId
-  // 2. according to info from step 1, generate data
-  // todo: if block was flushed, it's possible to get duplicated data
-  public synchronized ShuffleDataResult getShuffleData(
-      long lastBlockId, int readBufferSize, Roaring64NavigableMap 
expectedTaskIds) {
-    try {
-      List<BufferSegment> bufferSegments = Lists.newArrayList();
-      List<ShufflePartitionedBlock> readBlocks = Lists.newArrayList();
-      updateBufferSegmentsAndResultBlocks(
-          lastBlockId, readBufferSize, bufferSegments, readBlocks, 
expectedTaskIds);
-      if (!bufferSegments.isEmpty()) {
-        CompositeByteBuf byteBuf =
-            new CompositeByteBuf(
-                NettyUtils.getNettyBufferAllocator(),
-                true,
-                Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
-        // copy result data
-        updateShuffleData(readBlocks, byteBuf);
-        return new ShuffleDataResult(byteBuf, bufferSegments);
-      }
-    } catch (Exception e) {
-      LOG.error("Exception happened when getShuffleData in buffer", e);
-    }
-    return new ShuffleDataResult();
-  }
-
-  // here is the rule to read data in memory:
-  // 1. read from inFlushBlockMap order by eventId asc, then from blocks
-  // 2. if can't find lastBlockId, means related data may be flushed to 
storage, repeat step 1
-  private void updateBufferSegmentsAndResultBlocks(
+  @Override
+  protected void updateBufferSegmentsAndResultBlocks(
       long lastBlockId,
       long readBufferSize,
       List<BufferSegment> bufferSegments,
       List<ShufflePartitionedBlock> resultBlocks,
       Roaring64NavigableMap expectedTaskIds) {
     long nextBlockId = lastBlockId;
-    List<Long> sortedEventId = sortFlushingEventId();
+    List<Long> eventIdList = Lists.newArrayList(inFlushBlockMap.keySet());
+    List<Long> sortedEventId = sortFlushingEventId(eventIdList);
     int offset = 0;
     boolean hasLastBlockId = false;
     // read from inFlushBlockMap first to make sure the order of
@@ -261,77 +203,6 @@ public class ShuffleBuffer {
     }
   }
 
-  private int calculateDataLength(List<BufferSegment> bufferSegments) {
-    BufferSegment bufferSegment = bufferSegments.get(bufferSegments.size() - 
1);
-    return bufferSegment.getOffset() + bufferSegment.getLength();
-  }
-
-  private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks, 
CompositeByteBuf data) {
-    int offset = 0;
-    for (ShufflePartitionedBlock block : readBlocks) {
-      // fill shuffle data
-      try {
-        data.addComponent(true, block.getData().retain());
-      } catch (Exception e) {
-        LOG.error(
-            "Unexpected exception for System.arraycopy, length["
-                + block.getLength()
-                + "], offset["
-                + offset
-                + "], dataLength["
-                + data.capacity()
-                + "]",
-            e);
-        throw e;
-      }
-      offset += block.getLength();
-    }
-  }
-
-  private List<Long> sortFlushingEventId() {
-    List<Long> eventIdList = Lists.newArrayList(inFlushBlockMap.keySet());
-    eventIdList.sort(
-        (id1, id2) -> {
-          if (id1 > id2) {
-            return 1;
-          }
-          return -1;
-        });
-    return eventIdList;
-  }
-
-  private void updateSegmentsWithoutBlockId(
-      int offset,
-      List<ShufflePartitionedBlock> cachedBlocks,
-      long readBufferSize,
-      List<BufferSegment> bufferSegments,
-      List<ShufflePartitionedBlock> readBlocks,
-      Roaring64NavigableMap expectedTaskIds) {
-    int currentOffset = offset;
-    // read from first block
-    for (ShufflePartitionedBlock block : cachedBlocks) {
-      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
-        continue;
-      }
-      // add bufferSegment with block
-      bufferSegments.add(
-          new BufferSegment(
-              block.getBlockId(),
-              currentOffset,
-              block.getLength(),
-              block.getUncompressLength(),
-              block.getCrc(),
-              block.getTaskAttemptId()));
-      readBlocks.add(block);
-      // update offset
-      currentOffset += block.getLength();
-      // check if length >= request buffer size
-      if (currentOffset >= readBufferSize) {
-        break;
-      }
-    }
-  }
-
   private boolean updateSegmentsWithBlockId(
       int offset,
       List<ShufflePartitionedBlock> cachedBlocks,
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferWithSkipList.java
 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferWithSkipList.java
new file mode 100644
index 000000000..4783b4f56
--- /dev/null
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferWithSkipList.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.buffer;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentNavigableMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Lists;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.server.ShuffleDataFlushEvent;
+import org.apache.uniffle.server.ShuffleFlushManager;
+
+public class ShuffleBufferWithSkipList extends AbstractShuffleBuffer {
+  private ConcurrentSkipListMap<Long, ShufflePartitionedBlock> blocksMap;
+  private final Map<Long, ConcurrentSkipListMap<Long, 
ShufflePartitionedBlock>> inFlushBlockMap;
+
+  public ShuffleBufferWithSkipList(long capacity) {
+    super(capacity);
+    this.blocksMap = newConcurrentSkipListMap();
+    this.inFlushBlockMap = JavaUtils.newConcurrentMap();
+  }
+
+  private ConcurrentSkipListMap<Long, ShufflePartitionedBlock> 
newConcurrentSkipListMap() {
+    // Blocks only need to be ordered by taskAttemptId here, because they must be sorted when
+    // flushing. taskAttemptId occupies the lowest bits of blockId, so we reverse the bits of the
+    // key when comparing.
+    return new 
ConcurrentSkipListMap<>(Comparator.comparingLong(Long::reverse));
+  }
+
+  @Override
+  public long append(ShufflePartitionedData data) {
+    long mSize = 0;
+
+    synchronized (this) {
+      for (ShufflePartitionedBlock block : data.getBlockList()) {
+        blocksMap.put(block.getBlockId(), block);
+        mSize += block.getSize();
+      }
+      size += mSize;
+    }
+
+    return mSize;
+  }
+
+  @Override
+  public synchronized ShuffleDataFlushEvent toFlushEvent(
+      String appId,
+      int shuffleId,
+      int startPartition,
+      int endPartition,
+      Supplier<Boolean> isValid,
+      ShuffleDataDistributionType dataDistributionType) {
+    if (blocksMap.isEmpty()) {
+      return null;
+    }
+    List<ShufflePartitionedBlock> spBlocks = new 
LinkedList<>(blocksMap.values());
+    long eventId = ShuffleFlushManager.ATOMIC_EVENT_ID.getAndIncrement();
+    final ShuffleDataFlushEvent event =
+        new ShuffleDataFlushEvent(
+            eventId, appId, shuffleId, startPartition, endPartition, size, 
spBlocks, isValid, this);
+    event.addCleanupCallback(
+        () -> {
+          this.clearInFlushBuffer(event.getEventId());
+          spBlocks.forEach(spb -> spb.getData().release());
+        });
+    inFlushBlockMap.put(eventId, blocksMap);
+    blocksMap = newConcurrentSkipListMap();
+    size = 0;
+    return event;
+  }
+
+  @Override
+  public List<ShufflePartitionedBlock> getBlocks() {
+    return new LinkedList<>(blocksMap.values());
+  }
+
+  @Override
+  public void release() {
+    blocksMap.values().forEach(spb -> spb.getData().release());
+  }
+
+  @Override
+  public synchronized void clearInFlushBuffer(long eventId) {
+    inFlushBlockMap.remove(eventId);
+  }
+
+  @Override
+  public Map<Long, List<ShufflePartitionedBlock>> getInFlushBlockMap() {
+    return inFlushBlockMap.entrySet().stream()
+        .collect(Collectors.toMap(Map.Entry::getKey, e -> new 
ArrayList<>(e.getValue().values())));
+  }
+
+  @Override
+  protected void updateBufferSegmentsAndResultBlocks(
+      long lastBlockId,
+      long readBufferSize,
+      List<BufferSegment> bufferSegments,
+      List<ShufflePartitionedBlock> resultBlocks,
+      Roaring64NavigableMap expectedTaskIds) {
+    long nextBlockId = lastBlockId;
+    List<Long> eventIdList = Lists.newArrayList(inFlushBlockMap.keySet());
+    List<Long> sortedEventId = sortFlushingEventId(eventIdList);
+    int offset = 0;
+    boolean hasLastBlockId = false;
+    // read from inFlushBlockMap first to make sure the order of
+    // data read is according to the order of data received
+    // The number of events is the number of batches currently being flushed;
+    // it should be less than 5, otherwise there is likely a problem with the storage
+    if (!inFlushBlockMap.isEmpty()) {
+      for (Long eventId : sortedEventId) {
+        hasLastBlockId =
+            updateSegments(
+                offset,
+                inFlushBlockMap.get(eventId),
+                readBufferSize,
+                nextBlockId,
+                bufferSegments,
+                resultBlocks,
+                expectedTaskIds);
+        // if last blockId is found, read from begin with next cached blocks
+        if (hasLastBlockId) {
+          // reset blockId to read from begin in next cached blocks
+          nextBlockId = Constants.INVALID_BLOCK_ID;
+        }
+        if (!bufferSegments.isEmpty()) {
+          offset = calculateDataLength(bufferSegments);
+        }
+        if (offset >= readBufferSize) {
+          break;
+        }
+      }
+    }
+    // try to read from cached blocks which are not in the flush queue
+    if (!blocksMap.isEmpty() && offset < readBufferSize) {
+      hasLastBlockId =
+          updateSegments(
+              offset,
+              blocksMap,
+              readBufferSize,
+              nextBlockId,
+              bufferSegments,
+              resultBlocks,
+              expectedTaskIds);
+    }
+    if ((!inFlushBlockMap.isEmpty() || !blocksMap.isEmpty()) && offset == 0 && 
!hasLastBlockId) {
+      // can't find lastBlockId, so it should have been flushed already,
+      // but there is still data in memory;
+      // try to read again with blockId = Constants.INVALID_BLOCK_ID
+      updateBufferSegmentsAndResultBlocks(
+          Constants.INVALID_BLOCK_ID,
+          readBufferSize,
+          bufferSegments,
+          resultBlocks,
+          expectedTaskIds);
+    }
+  }
+
+  private boolean updateSegments(
+      int offset,
+      ConcurrentSkipListMap<Long, ShufflePartitionedBlock> cachedBlocks,
+      long readBufferSize,
+      long lastBlockId,
+      List<BufferSegment> bufferSegments,
+      List<ShufflePartitionedBlock> readBlocks,
+      Roaring64NavigableMap expectedTaskIds) {
+    int currentOffset = offset;
+    ConcurrentNavigableMap<Long, ShufflePartitionedBlock> remainingBlocks;
+    boolean hasLastBlockId;
+    if (lastBlockId == Constants.INVALID_BLOCK_ID) {
+      remainingBlocks = cachedBlocks;
+    } else {
+      if (cachedBlocks.get(lastBlockId) == null) {
+        return false;
+      }
+      remainingBlocks = cachedBlocks.tailMap(lastBlockId, false);
+    }
+
+    hasLastBlockId = true;
+    for (ShufflePartitionedBlock block : remainingBlocks.values()) {
+      if (expectedTaskIds != null && 
!expectedTaskIds.contains(block.getTaskAttemptId())) {
+        continue;
+      }
+      // add bufferSegment with block
+      bufferSegments.add(
+          new BufferSegment(
+              block.getBlockId(),
+              currentOffset,
+              block.getLength(),
+              block.getUncompressLength(),
+              block.getCrc(),
+              block.getTaskAttemptId()));
+      readBlocks.add(block);
+      // update offset
+      currentOffset += block.getLength();
+      if (currentOffset >= readBufferSize) {
+        break;
+      }
+    }
+    return hasLastBlockId;
+  }
+}
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java 
b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
index 5e53a80b2..2314a1b6c 100644
--- a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
+++ b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
@@ -18,13 +18,14 @@
 package org.apache.uniffle.server.buffer;
 
 import java.util.Random;
-import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.server.ShuffleServerMetrics;
 
@@ -40,7 +41,7 @@ public abstract class BufferTestBase {
     ShuffleServerMetrics.clear();
   }
 
-  private static AtomicLong atomBlockId = new AtomicLong(0);
+  private static AtomicInteger atomSequenceNo = new AtomicInteger(0);
 
   protected ShufflePartitionedData createData(int len) {
     return createData(1, len);
@@ -53,16 +54,18 @@ public abstract class BufferTestBase {
   protected ShufflePartitionedData createData(int partitionId, int 
taskAttemptId, int len) {
     byte[] buf = new byte[len];
     new Random().nextBytes(buf);
+    long blockId =
+        BlockIdLayout.DEFAULT.getBlockId(
+            getAtomSequenceNo().incrementAndGet(), partitionId, taskAttemptId);
     ShufflePartitionedBlock block =
         new ShufflePartitionedBlock(
-            len,
-            len,
-            ChecksumUtils.getCrc32(buf),
-            atomBlockId.incrementAndGet(),
-            taskAttemptId,
-            buf);
+            len, len, ChecksumUtils.getCrc32(buf), blockId, taskAttemptId, 
buf);
     ShufflePartitionedData data =
         new ShufflePartitionedData(partitionId, new ShufflePartitionedBlock[] 
{block});
     return data;
   }
+
+  protected AtomicInteger getAtomSequenceNo() {
+    return atomSequenceNo;
+  }
 }
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferWithLinkedListTest.java
similarity index 96%
rename from 
server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
rename to 
server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferWithLinkedListTest.java
index 5eb1bcfa7..90ac86c9f 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferWithLinkedListTest.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.server.buffer;
 
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.collect.Lists;
 import org.junit.jupiter.api.Test;
@@ -38,11 +39,13 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
-public class ShuffleBufferTest extends BufferTestBase {
+public class ShuffleBufferWithLinkedListTest extends BufferTestBase {
+
+  private static AtomicInteger atomSequenceNo = new AtomicInteger(0);
 
   @Test
   public void appendTest() {
-    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(100);
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithLinkedList(100);
     shuffleBuffer.append(createData(10));
     // ShufflePartitionedBlock has constant 32 bytes overhead
     assertEquals(42, shuffleBuffer.getSize());
@@ -59,7 +62,7 @@ public class ShuffleBufferTest extends BufferTestBase {
 
   @Test
   public void appendMultiBlocksTest() {
-    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(100);
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithLinkedList(100);
     ShufflePartitionedData data1 = createData(10);
     ShufflePartitionedData data2 = createData(10);
     ShufflePartitionedBlock[] dataCombine = new ShufflePartitionedBlock[2];
@@ -71,7 +74,7 @@ public class ShuffleBufferTest extends BufferTestBase {
 
   @Test
   public void toFlushEventTest() {
-    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(100);
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithLinkedList(100);
     ShuffleDataFlushEvent event = shuffleBuffer.toFlushEvent("appId", 0, 0, 1, 
null);
     assertNull(event);
     shuffleBuffer.append(createData(10));
@@ -85,7 +88,7 @@ public class ShuffleBufferTest extends BufferTestBase {
   @Test
   public void getShuffleDataWithExpectedTaskIdsFilterTest() {
     /** case1: all blocks in cached(or in flushed map) and size < 
readBufferSize */
-    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(100);
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithLinkedList(100);
     ShufflePartitionedData spd1 = createData(1, 1, 15);
     ShufflePartitionedData spd2 = createData(1, 0, 15);
     ShufflePartitionedData spd3 = createData(1, 2, 55);
@@ -197,7 +200,7 @@ public class ShuffleBufferTest extends BufferTestBase {
 
   @Test
   public void getShuffleDataWithLocalOrderTest() {
-    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(200);
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     ShufflePartitionedData spd1 = createData(1, 1, 15);
     ShufflePartitionedData spd2 = createData(1, 0, 15);
     ShufflePartitionedData spd3 = createData(1, 2, 15);
@@ -235,7 +238,7 @@ public class ShuffleBufferTest extends BufferTestBase {
 
   @Test
   public void getShuffleDataTest() {
-    ShuffleBuffer shuffleBuffer = new ShuffleBuffer(200);
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     // case1: cached data only, blockId = -1, readBufferSize > buffer size
     ShufflePartitionedData spd1 = createData(10);
     ShufflePartitionedData spd2 = createData(20);
@@ -247,7 +250,7 @@ public class ShuffleBufferTest extends BufferTestBase {
     assertArrayEquals(expectedData, sdr.getData());
 
     // case2: cached data only, blockId = -1, readBufferSize = buffer size
-    shuffleBuffer = new ShuffleBuffer(200);
+    shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     spd1 = createData(20);
     spd2 = createData(20);
     shuffleBuffer.append(spd1);
@@ -258,7 +261,7 @@ public class ShuffleBufferTest extends BufferTestBase {
     assertArrayEquals(expectedData, sdr.getData());
 
     // case3-1: cached data only, blockId = -1, readBufferSize < buffer size
-    shuffleBuffer = new ShuffleBuffer(200);
+    shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     spd1 = createData(20);
     spd2 = createData(21);
     shuffleBuffer.append(spd1);
@@ -269,7 +272,7 @@ public class ShuffleBufferTest extends BufferTestBase {
     assertArrayEquals(expectedData, sdr.getData());
 
     // case3-2: cached data only, blockId = -1, readBufferSize < buffer size
-    shuffleBuffer = new ShuffleBuffer(200);
+    shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     spd1 = createData(15);
     spd2 = createData(15);
     ShufflePartitionedData spd3 = createData(15);
@@ -289,7 +292,7 @@ public class ShuffleBufferTest extends BufferTestBase {
     assertArrayEquals(expectedData, sdr.getData());
 
     // case5: flush data only, blockId = -1, readBufferSize < buffer size
-    shuffleBuffer = new ShuffleBuffer(200);
+    shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     spd1 = createData(15);
     spd2 = createData(15);
     shuffleBuffer.append(spd1);
@@ -307,13 +310,13 @@ public class ShuffleBufferTest extends BufferTestBase {
     assertEquals(0, sdr.getBufferSegments().size());
 
     // case6: no data in buffer & flush buffer
-    shuffleBuffer = new ShuffleBuffer(200);
+    shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     sdr = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 10);
     assertEquals(0, sdr.getBufferSegments().size());
     assertEquals(0, sdr.getDataLength());
 
     // case7: get data with multiple flush buffer and cached buffer
-    shuffleBuffer = new ShuffleBuffer(200);
+    shuffleBuffer = new ShuffleBufferWithLinkedList(200);
     spd1 = createData(15);
     spd2 = createData(15);
     spd3 = createData(15);
@@ -596,4 +599,9 @@ public class ShuffleBufferTest extends BufferTestBase {
       segmentIndex++;
     }
   }
+
+  @Override
+  protected AtomicInteger getAtomSequenceNo() {
+    return atomSequenceNo;
+  }
 }
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferWithSkipListTest.java
 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferWithSkipListTest.java
new file mode 100644
index 000000000..bf5040304
--- /dev/null
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferWithSkipListTest.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.buffer;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.server.ShuffleDataFlushEvent;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ShuffleBufferWithSkipListTest extends BufferTestBase {
+  private static AtomicInteger atomSequenceNo = new AtomicInteger(0);
+
+  @Test
+  public void appendTest() {
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithSkipList(100);
+    shuffleBuffer.append(createData(10));
+    // ShufflePartitionedBlock has constant 32 bytes overhead
+    assertEquals(42, shuffleBuffer.getSize());
+    assertFalse(shuffleBuffer.isFull());
+
+    shuffleBuffer.append(createData(26));
+    assertEquals(100, shuffleBuffer.getSize());
+    assertFalse(shuffleBuffer.isFull());
+
+    shuffleBuffer.append(createData(1));
+    assertEquals(133, shuffleBuffer.getSize());
+    assertTrue(shuffleBuffer.isFull());
+  }
+
+  @Test
+  public void appendMultiBlocksTest() {
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithSkipList(100);
+    ShufflePartitionedData data1 = createData(10);
+    ShufflePartitionedData data2 = createData(10);
+    ShufflePartitionedBlock[] dataCombine = new ShufflePartitionedBlock[2];
+    dataCombine[0] = data1.getBlockList()[0];
+    dataCombine[1] = data2.getBlockList()[0];
+    shuffleBuffer.append(new ShufflePartitionedData(1, dataCombine));
+    assertEquals(84, shuffleBuffer.getSize());
+  }
+
+  @Test
+  public void toFlushEventTest() {
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithSkipList(100);
+    ShuffleDataFlushEvent event = shuffleBuffer.toFlushEvent("appId", 0, 0, 1, 
null);
+    assertNull(event);
+    shuffleBuffer.append(createData(10));
+    assertEquals(42, shuffleBuffer.getSize());
+    event = shuffleBuffer.toFlushEvent("appId", 0, 0, 1, null);
+    assertEquals(42, event.getSize());
+    assertEquals(0, shuffleBuffer.getSize());
+    assertEquals(0, shuffleBuffer.getBlocks().size());
+  }
+
+  @Test
+  public void getShuffleDataWithExpectedTaskIdsFilterTest() {
+    /** case1: all blocks in cached(or in flushed map) and size < 
readBufferSize */
+    ShuffleBuffer shuffleBuffer = new ShuffleBufferWithSkipList(100);
+    ShufflePartitionedData spd1 = createData(1, 1, 15);
+    ShufflePartitionedData spd2 = createData(1, 0, 15);
+    ShufflePartitionedData spd3 = createData(1, 2, 55);
+    ShufflePartitionedData spd4 = createData(1, 1, 45);
+    shuffleBuffer.append(spd1);
+    shuffleBuffer.append(spd2);
+    shuffleBuffer.append(spd3);
+    shuffleBuffer.append(spd4);
+
+    Roaring64NavigableMap expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    ShuffleDataResult result =
+        shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 1000, 
expectedTasks);
+    assertEquals(3, result.getBufferSegments().size());
+    for (BufferSegment segment : result.getBufferSegments()) {
+      assertTrue(expectedTasks.contains(segment.getTaskAttemptId()));
+    }
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    // Currently, with skip_list we can't guarantee that the reading order is the same as the
+    // writing order, so only check the total segment size of taskAttempt 1.
+    assertEquals(
+        60,
+        result.getBufferSegments().get(0).getLength()
+            + result.getBufferSegments().get(1).getLength());
+    assertEquals(60, result.getBufferSegments().get(2).getOffset());
+    assertEquals(55, result.getBufferSegments().get(2).getLength());
+
+    expectedTasks = Roaring64NavigableMap.bitmapOf(0);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 1000, 
expectedTasks);
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(15, result.getBufferSegments().get(0).getLength());
+
+    /**
+     * case2: all blocks in cached(or in flushed map) and size > 
readBufferSize, so it will read
+     * multiple times.
+     *
+     * <p>required blocks size list: 15+45, 55
+     */
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 60, 
expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(
+        60,
+        result.getBufferSegments().get(0).getLength()
+            + result.getBufferSegments().get(1).getLength());
+
+    // 2nd read
+    long lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(55, result.getBufferSegments().get(0).getLength());
+
+    /** case3: all blocks in flushed map and size < readBufferSize */
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    ShuffleDataFlushEvent event1 =
+        shuffleBuffer.toFlushEvent("appId", 0, 0, 1, null, 
ShuffleDataDistributionType.LOCAL_ORDER);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 1000, 
expectedTasks);
+    assertEquals(3, result.getBufferSegments().size());
+    for (BufferSegment segment : result.getBufferSegments()) {
+      assertTrue(expectedTasks.contains(segment.getTaskAttemptId()));
+    }
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    // Currently, with skip_list we can't guarantee that the reading order is the same as the
+    // writing order, so only check the total segment size of taskAttempt 1.
+    assertEquals(
+        60,
+        result.getBufferSegments().get(0).getLength()
+            + result.getBufferSegments().get(1).getLength());
+    assertEquals(60, result.getBufferSegments().get(2).getOffset());
+    assertEquals(55, result.getBufferSegments().get(2).getLength());
+
+    /** case4: all blocks in flushed map and size > readBufferSize, it will 
read multiple times */
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 60, 
expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(
+        60,
+        result.getBufferSegments().get(0).getLength()
+            + result.getBufferSegments().get(1).getLength());
+
+    // 2nd read
+    lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(1, result.getBufferSegments().size());
+    assertEquals(0, result.getBufferSegments().get(0).getOffset());
+    assertEquals(55, result.getBufferSegments().get(0).getLength());
+
+    /**
+     * case5: partial blocks in cache and another in flushedMap, and it will 
read multiple times.
+     *
+     * <p>required size: 15, 55, 45 (in flushed map) 55, 45, 5, 25(in cached)
+     */
+    ShufflePartitionedData spd5 = createData(1, 2, 55);
+    ShufflePartitionedData spd6 = createData(1, 1, 45);
+    ShufflePartitionedData spd7 = createData(1, 1, 5);
+    ShufflePartitionedData spd8 = createData(1, 1, 25);
+    shuffleBuffer.append(spd5);
+    shuffleBuffer.append(spd6);
+    shuffleBuffer.append(spd7);
+    shuffleBuffer.append(spd8);
+
+    expectedTasks = Roaring64NavigableMap.bitmapOf(1, 2);
+    result = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 60, 
expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+
+    // 2nd read
+    lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(2, result.getBufferSegments().size());
+    // 3rd read
+    lastBlockId = result.getBufferSegments().get(1).getBlockId();
+    result = shuffleBuffer.getShuffleData(lastBlockId, 60, expectedTasks);
+    assertEquals(3, result.getBufferSegments().size());
+  }
+
+  @Override
+  protected AtomicInteger getAtomSequenceNo() {
+    return atomSequenceNo;
+  }
+}

Reply via email to