This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 03b04b90 [#808] feat(spark): ensure thread safe and data consistency 
when spill (#848)
03b04b90 is described below

commit 03b04b903c46132359dc4b8f5e47d1ac645b3828
Author: Junfan Zhang <[email protected]>
AuthorDate: Sat Jul 22 21:10:22 2023 +0800

    [#808] feat(spark): ensure thread safe and data consistency when spill 
(#848)
    
    ### What changes were proposed in this pull request?
    
    1. Guarantees thread safety by only allowing spills to be triggered by the 
current thread
    2. Using the same logic of processing blocks in the `RssShuffleWriter` and 
`WriteBufferManager` to ensure the data consistency
    
    ### Why are the changes needed?
    
    Fix: #808
    
    In this PR, we use the two ways to solve the concurrent problem for 
`addRecord` and `spill` function
    1. For the same thread, the spill will be invoked when adding records and 
memory is insufficient. This case ensures
    thread safety. So it will do the spill synchronously.
    2. When spill is invoked by other consumers, it will do nothing in this 
thread and just set a signal to let the owner release memory when adding a record.
    
    After this, we can avoid locks (which may cause performance regression, as #811 
    did) while keeping thread safety
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    1. UTs
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |   7 +
 .../spark/shuffle/writer/WriteBufferManager.java   | 102 ++++++++--
 .../shuffle/writer/WriteBufferManagerTest.java     | 216 ++++++++++++++++++---
 .../apache/spark/shuffle/RssShuffleManager.java    |  21 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  70 +++++--
 .../apache/spark/shuffle/RssShuffleManager.java    |  21 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  81 ++++++--
 .../spark/shuffle/writer/RssShuffleWriterTest.java | 103 +++++++++-
 8 files changed, 502 insertions(+), 119 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index ea5b629b..283122f0 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -57,6 +57,13 @@ public class RssSparkConfig {
           .defaultValue(true)
           .withDescription("indicates row based shuffle, set false when use in 
columnar shuffle");
 
+  public static final ConfigOption<Boolean> RSS_MEMORY_SPILL_ENABLED =
+      ConfigOptions.key("rss.client.memory.spill.enabled")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription(
+              "The memory spill switch triggered by Spark TaskMemoryManager, 
default value is false.");
+
   public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";
 
   public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index c4356510..49795555 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -23,6 +23,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 
@@ -87,8 +89,9 @@ public class WriteBufferManager extends MemoryConsumer {
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
   private Codec codec;
-  private Function<AddBlockEvent, CompletableFuture<Long>> spillFunc;
+  private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> 
spillFunc;
   private long sendSizeLimit;
+  private boolean memorySpillEnabled;
   private int memorySpillTimeoutSec;
   private boolean isRowBased;
 
@@ -124,7 +127,7 @@ public class WriteBufferManager extends MemoryConsumer {
       TaskMemoryManager taskMemoryManager,
       ShuffleWriteMetrics shuffleWriteMetrics,
       RssConf rssConf,
-      Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
+      Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> 
spillFunc) {
     super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), 
MemoryMode.ON_HEAP);
     this.bufferSize = bufferManagerOptions.getBufferSize();
     this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
@@ -155,6 +158,7 @@ public class WriteBufferManager extends MemoryConsumer {
     this.spillFunc = spillFunc;
     this.sendSizeLimit = 
rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
     this.memorySpillTimeoutSec = 
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
+    this.memorySpillEnabled = 
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
   }
 
   /** add serialized columnar data directly when integrate with gluten */
@@ -165,15 +169,51 @@ public class WriteBufferManager extends MemoryConsumer {
 
   public List<ShuffleBlockInfo> addPartitionData(
       int partitionId, byte[] serializedData, int serializedDataLength, long 
start) {
-    List<ShuffleBlockInfo> result = Lists.newArrayList();
+    List<ShuffleBlockInfo> candidateSendingBlocks =
+        insertIntoBuffer(partitionId, serializedData, serializedDataLength);
+
+    // check buffer size > spill threshold
+    if (usedBytes.get() - inSendListBytes.get() > spillSize) {
+      candidateSendingBlocks.addAll(clear());
+    }
+    writeTime += System.currentTimeMillis() - start;
+    return candidateSendingBlocks;
+  }
+
+  /**
+   * Before inserting a record into its corresponding buffer, the system 
should check if there is
+   * sufficient buffer memory available. If there isn't enough memory, it will 
request additional
+   * memory from the {@link TaskMemoryManager}. In the event that the JVM is 
low on memory, a spill
+   * operation will be triggered. If any memory consumer managed by the {@link 
TaskMemoryManager}
+   * fails to meet its memory requirements, it will also be triggered one by 
one.
+   *
+   * <p>If the current buffer manager requests memory and triggers a spill 
operation, the buffer
+   * that is currently being held should be dropped, and then re-inserted.
+   */
+  private List<ShuffleBlockInfo> insertIntoBuffer(
+      int partitionId, byte[] serializedData, int serializedDataLength) {
+    List<ShuffleBlockInfo> sentBlocks = new ArrayList<>();
+    long required = Math.max(bufferSegmentSize, serializedDataLength);
+    // Asking memory from task memory manager for the existing writer buffer,
+    // this may trigger current WriteBufferManager spill method, which will
+    // make the current write buffer discard. So we have to recheck the buffer 
existence.
+    boolean hasRequested = false;
     if (buffers.containsKey(partitionId)) {
       WriterBuffer wb = buffers.get(partitionId);
       if (wb.askForMemory(serializedDataLength)) {
-        requestMemory(Math.max(bufferSegmentSize, serializedDataLength));
+        requestMemory(required);
+        hasRequested = true;
+      }
+    }
+
+    if (buffers.containsKey(partitionId)) {
+      if (hasRequested) {
+        usedBytes.addAndGet(required);
       }
+      WriterBuffer wb = buffers.get(partitionId);
       wb.addRecord(serializedData, serializedDataLength);
       if (wb.getMemoryUsed() > bufferSize) {
-        result.add(createShuffleBlock(partitionId, wb));
+        sentBlocks.add(createShuffleBlock(partitionId, wb));
         copyTime += wb.getCopyTime();
         buffers.remove(partitionId);
         LOG.debug(
@@ -188,18 +228,19 @@ public class WriteBufferManager extends MemoryConsumer {
                 + "]");
       }
     } else {
-      requestMemory(Math.max(bufferSegmentSize, serializedDataLength));
+      // The true of hasRequested means the former partitioned buffer has been 
flushed, that is
+      // triggered by the spill operation caused by asking for memory. So it 
needn't to re-request
+      // the memory.
+      if (!hasRequested) {
+        requestMemory(required);
+      }
+      usedBytes.addAndGet(required);
+
       WriterBuffer wb = new WriterBuffer(bufferSegmentSize);
       wb.addRecord(serializedData, serializedDataLength);
       buffers.put(partitionId, wb);
     }
-
-    // check buffer size > spill threshold
-    if (usedBytes.get() - inSendListBytes.get() > spillSize) {
-      result.addAll(clear());
-    }
-    writeTime += System.currentTimeMillis() - start;
-    return result;
+    return sentBlocks;
   }
 
   public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object 
value) {
@@ -304,7 +345,6 @@ public class WriteBufferManager extends MemoryConsumer {
     if (allocatedBytes.get() - usedBytes.get() < requiredMem) {
       requestExecutorMemory(requiredMem);
     }
-    usedBytes.addAndGet(requiredMem);
     requireMemoryTime += System.currentTimeMillis() - start;
   }
 
@@ -395,7 +435,36 @@ public class WriteBufferManager extends MemoryConsumer {
 
   @Override
   public long spill(long size, MemoryConsumer trigger) {
-    return 0L;
+    // Only for the MemoryConsumer of this instance, it will flush buffer
+    if (!memorySpillEnabled || trigger != this) {
+      return 0L;
+    }
+
+    List<CompletableFuture<Long>> futures = spillFunc.apply(clear());
+    CompletableFuture<Void> allOfFutures =
+        CompletableFuture.allOf(futures.toArray(new 
CompletableFuture[futures.size()]));
+    try {
+      allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
+    } catch (TimeoutException timeoutException) {
+      // A best effort strategy to wait.
+      // If timeout exception occurs, the underlying tasks won't be cancelled.
+    } finally {
+      long releasedSize =
+          futures.stream()
+              .filter(x -> x.isDone())
+              .mapToLong(
+                  x -> {
+                    try {
+                      return x.get();
+                    } catch (Exception e) {
+                      return 0;
+                    }
+                  })
+              .sum();
+      LOG.info(
+          "[taskId: {}] Spill triggered by own, released memory size: {}", 
taskId, releasedSize);
+      return releasedSize;
+    }
   }
 
   @VisibleForTesting
@@ -470,7 +539,8 @@ public class WriteBufferManager extends MemoryConsumer {
   }
 
   @VisibleForTesting
-  public void setSpillFunc(Function<AddBlockEvent, CompletableFuture<Long>> 
spillFunc) {
+  public void setSpillFunc(
+      Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> 
spillFunc) {
     this.spillFunc = spillFunc;
   }
 
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index e5a9c650..0e1f5e3f 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
@@ -27,6 +29,8 @@ import com.google.common.collect.Maps;
 import org.apache.commons.lang3.reflect.FieldUtils;
 import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.MemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
@@ -34,6 +38,8 @@ import org.apache.spark.shuffle.RssSparkConfig;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.config.RssConf;
@@ -266,18 +272,45 @@ public class WriteBufferManagerTest {
     assertEquals(3, events.size());
   }
 
-  public void spillTest() {
+  @Test
+  public void spillByOthersTest() {
     SparkConf conf = getConf();
-    conf.set("spark.rss.client.send.size.limit", "1000");
+    conf.set("spark.rss.client.memory.spill.enabled", "true");
     TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
 
-    Function<AddBlockEvent, CompletableFuture<Long>> spillFunc =
-        event -> {
-          event.getProcessedCallbackChain().stream().forEach(x -> x.run());
-          return CompletableFuture.completedFuture(
-              event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum());
-        };
+    WriteBufferManager wbm =
+        new WriteBufferManager(
+            0,
+            "taskId_spillByOthersTest",
+            0,
+            bufferOptions,
+            new KryoSerializer(conf),
+            Maps.newHashMap(),
+            mockTaskMemoryManager,
+            new ShuffleWriteMetrics(),
+            RssSparkConfig.toRssConf(conf),
+            null);
+
+    WriteBufferManager spyManager = spy(wbm);
+    doReturn(512L).when(spyManager).acquireMemory(anyLong());
+
+    String testKey = "Key";
+    String testValue = "Value";
+    spyManager.addRecord(0, testKey, testValue);
+    spyManager.addRecord(1, testKey, testValue);
+
+    // case1. if one thread wants to spill other consumers data, it will 
return 0
+    assertEquals(0, spyManager.spill(1000, mock(WriteBufferManager.class)));
+  }
+
+  @Test
+  public void spillByOwnTest() {
+    SparkConf conf = getConf();
+    conf.set("spark.rss.client.send.size.limit", "1000");
+    conf.set("spark.rss.client.memory.spill.enabled", "true");
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
 
     WriteBufferManager wbm =
         new WriteBufferManager(
@@ -290,7 +323,20 @@ public class WriteBufferManagerTest {
             mockTaskMemoryManager,
             new ShuffleWriteMetrics(),
             RssSparkConfig.toRssConf(conf),
-            spillFunc);
+            null);
+
+    Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
+        blocks -> {
+          long sum = 0L;
+          List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
+          for (AddBlockEvent event : events) {
+            event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+            sum += event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum();
+          }
+          return Arrays.asList(CompletableFuture.completedFuture(sum));
+        };
+    wbm.setSpillFunc(spillFunc);
+
     WriteBufferManager spyManager = spy(wbm);
     doReturn(512L).when(spyManager).acquireMemory(anyLong());
 
@@ -300,8 +346,9 @@ public class WriteBufferManagerTest {
     spyManager.addRecord(1, testKey, testValue);
 
     // case1. all events are flushed within normal time.
-    long releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
+    long releasedSize = spyManager.spill(1000, spyManager);
     assertEquals(64, releasedSize);
+    assertEquals(0, spyManager.getUsedBytes());
 
     // case2. partial events are not flushed within normal time.
     // when calling spill func, 2 events should be spilled. But
@@ -309,26 +356,133 @@ public class WriteBufferManagerTest {
     spyManager.setSendSizeLimit(30);
     spyManager.addRecord(0, testKey, testValue);
     spyManager.addRecord(1, testKey, testValue);
-    spyManager.setSpillFunc(
-        event ->
-            CompletableFuture.supplyAsync(
-                () -> {
-                  int partitionId = 
event.getShuffleDataInfoList().get(0).getPartitionId();
-                  if (partitionId == 1) {
-                    try {
-                      Thread.sleep(2000);
-                    } catch (InterruptedException interruptedException) {
-                      // ignore.
-                    }
-                  }
-                  event.getProcessedCallbackChain().stream().forEach(x -> 
x.run());
-                  return event.getShuffleDataInfoList().stream()
-                      .mapToLong(x -> x.getFreeMemory())
-                      .sum();
-                }));
-    releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
-    assertEquals(32, releasedSize);
-    assertEquals(32, spyManager.getUsedBytes());
-    Awaitility.await().timeout(3, TimeUnit.SECONDS).until(() -> 
spyManager.getUsedBytes() == 0);
+    spillFunc =
+        shuffleBlockInfos ->
+            Arrays.asList(
+                CompletableFuture.supplyAsync(
+                    () -> {
+                      List<AddBlockEvent> events = 
spyManager.buildBlockEvents(shuffleBlockInfos);
+                      long sum = 0L;
+                      for (AddBlockEvent event : events) {
+                        int partitionId = 
event.getShuffleDataInfoList().get(0).getPartitionId();
+                        if (partitionId == 1) {
+                          try {
+                            Thread.sleep(2000);
+                          } catch (InterruptedException interruptedException) {
+                            // ignore.
+                          }
+                        }
+                        event.getProcessedCallbackChain().stream().forEach(x 
-> x.run());
+                        sum +=
+                            event.getShuffleDataInfoList().stream()
+                                .mapToLong(x -> x.getFreeMemory())
+                                .sum();
+                      }
+                      return sum;
+                    }));
+    spyManager.setSpillFunc(spillFunc);
+    releasedSize = spyManager.spill(1000, spyManager);
+    assertEquals(0, releasedSize);
+    Awaitility.await().timeout(5, TimeUnit.SECONDS).until(() -> 
spyManager.getUsedBytes() == 0);
+  }
+
+  public static class FakedTaskMemoryManager extends TaskMemoryManager {
+    private static final Logger LOGGER = 
LoggerFactory.getLogger(FakedTaskMemoryManager.class);
+    private int invokedCnt = 0;
+    private int spilledCnt = 0;
+    private int bytesReturnFirstTime = 32;
+
+    public FakedTaskMemoryManager() {
+      super(mock(MemoryManager.class), 1);
+    }
+
+    public FakedTaskMemoryManager(int bytesReturnFirstTime) {
+      this();
+      this.bytesReturnFirstTime = bytesReturnFirstTime;
+    }
+
+    public long acquireExecutionMemory(long required, MemoryConsumer consumer) 
{
+      if (invokedCnt++ == 0) {
+        LOGGER.info("Return existing memory: {}", bytesReturnFirstTime);
+        return bytesReturnFirstTime;
+      }
+      try {
+        spilledCnt++;
+        long size = consumer.spill(required, consumer);
+        LOGGER.info("Return spilled memory: {}", size);
+        return size;
+      } catch (IOException e) {
+        return 0L;
+      }
+    }
+
+    public int getInvokedCnt() {
+      return invokedCnt;
+    }
+
+    public int getSpilledCnt() {
+      return spilledCnt;
+    }
+  }
+
+  @Test
+  public void spillByOwnWithSparkTaskMemoryManagerTest() {
+    SparkConf conf = getConf();
+    conf.set(RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(), "32");
+    conf.set("spark.rss.client.send.size.limit", "1000");
+    conf.set("spark.rss.client.memory.spill.enabled", "true");
+    FakedTaskMemoryManager fakedTaskMemoryManager = new 
FakedTaskMemoryManager();
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+
+    WriteBufferManager wbm =
+        new WriteBufferManager(
+            0,
+            "taskId_spillTest",
+            0,
+            bufferOptions,
+            new KryoSerializer(conf),
+            Maps.newHashMap(),
+            fakedTaskMemoryManager,
+            new ShuffleWriteMetrics(),
+            RssSparkConfig.toRssConf(conf),
+            null);
+
+    List<ShuffleBlockInfo> blockList = new ArrayList<>();
+
+    Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
+        blocks -> {
+          blockList.addAll(blocks);
+          long sum = 0L;
+          List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
+          for (AddBlockEvent event : events) {
+            event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+            sum += event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum();
+          }
+          return Arrays.asList(CompletableFuture.completedFuture(sum));
+        };
+    wbm.setSpillFunc(spillFunc);
+
+    WriteBufferManager spyManager = spy(wbm);
+
+    String testKey = "Key";
+    String testValue = "Value";
+
+    // First time, it request 32 bytes and then insert the record. It will not 
flush buffer.
+    spyManager.addRecord(0, testKey, testValue);
+    assertEquals(0, blockList.size());
+
+    // Second time, the memory manager trigger the spill, so it will flush 
buffer and then insert
+    // the record
+    spyManager.addRecord(1, testKey, testValue);
+    assertEquals(1, blockList.size());
+    assertEquals(32, blockList.stream().mapToLong(x -> 
x.getFreeMemory()).sum());
+
+    // Third time, it will still do above.
+    spyManager.addRecord(2, testKey, testValue);
+    assertEquals(2, blockList.size());
+    assertEquals(64, blockList.stream().mapToLong(x -> 
x.getFreeMemory()).sum());
+
+    assertEquals(3, fakedTaskMemoryManager.getInvokedCnt());
+    assertEquals(2, fakedTaskMemoryManager.getSpilledCnt());
   }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 90b7a8e3..a2c7edbb 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -25,7 +25,6 @@ import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Function;
 
 import scala.Option;
 import scala.Tuple2;
@@ -44,10 +43,8 @@ import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
-import org.apache.spark.shuffle.writer.BufferManagerOptions;
 import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
-import org.apache.spark.shuffle.writer.WriteBufferManager;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManagerId;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -419,33 +416,19 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
-      BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
       ShuffleWriteMetrics writeMetrics = 
context.taskMetrics().shuffleWriteMetrics();
-      WriteBufferManager bufferManager =
-          new WriteBufferManager(
-              shuffleId,
-              taskId,
-              context.taskAttemptId(),
-              bufferOptions,
-              rssHandle.getDependency().serializer(),
-              rssHandle.getPartitionToServers(),
-              context.taskMemoryManager(),
-              writeMetrics,
-              RssSparkConfig.toRssConf(sparkConf),
-              this::sendData);
-
       return new RssShuffleWriter<>(
           rssHandle.getAppId(),
           shuffleId,
           taskId,
           context.taskAttemptId(),
-          bufferManager,
           writeMetrics,
           this,
           sparkConf,
           shuffleWriteClient,
           rssHandle,
-          (Function<String, Boolean>) this::markFailedTask);
+          this::markFailedTask,
+          context);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + 
handle.getClass().getName());
     }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index c0d5a54b..11f2dd3b 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -17,10 +17,13 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -40,6 +43,7 @@ import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -84,6 +88,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private long sendCheckInterval;
   private boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
+  private final Set<Long> blockIds = Sets.newConcurrentHashSet();
 
   public RssShuffleWriter(
       String appId,
@@ -101,21 +106,20 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleId,
         taskId,
         taskAttemptId,
-        bufferManager,
         shuffleWriteMetrics,
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
         rssHandle,
         (tid) -> true);
+    this.bufferManager = bufferManager;
   }
 
-  public RssShuffleWriter(
+  private RssShuffleWriter(
       String appId,
       int shuffleId,
       String taskId,
       long taskAttemptId,
-      WriteBufferManager bufferManager,
       ShuffleWriteMetrics shuffleWriteMetrics,
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
@@ -123,7 +127,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback) {
     this.appId = appId;
-    this.bufferManager = bufferManager;
     this.shuffleId = shuffleId;
     this.taskId = taskId;
     this.taskAttemptId = taskAttemptId;
@@ -145,6 +148,45 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.taskFailureCallback = taskFailureCallback;
   }
 
+  public RssShuffleWriter(
+      String appId,
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssShuffleManager shuffleManager,
+      SparkConf sparkConf,
+      ShuffleWriteClient shuffleWriteClient,
+      RssShuffleHandle<K, V, C> rssHandle,
+      Function<String, Boolean> taskFailureCallback,
+      TaskContext context) {
+    this(
+        appId,
+        shuffleId,
+        taskId,
+        taskAttemptId,
+        shuffleWriteMetrics,
+        shuffleManager,
+        sparkConf,
+        shuffleWriteClient,
+        rssHandle,
+        taskFailureCallback);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+    final WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            shuffleId,
+            taskId,
+            taskAttemptId,
+            bufferOptions,
+            rssHandle.getDependency().serializer(),
+            rssHandle.getPartitionToServers(),
+            context.taskMemoryManager(),
+            shuffleWriteMetrics,
+            RssSparkConfig.toRssConf(sparkConf),
+            this::processShuffleBlockInfos);
+    this.bufferManager = bufferManager;
+  }
+
   private boolean isMemoryShuffleEnabled(String storageType) {
     return StorageType.withMemory(StorageType.valueOf(storageType));
   }
@@ -169,7 +211,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
-    Set<Long> blockIds = Sets.newHashSet();
     while (records.hasNext()) {
       Product2<K, V> record = records.next();
       int partition = getPartition(record._1());
@@ -180,12 +221,12 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       } else {
         shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), 
record._2());
       }
-      processShuffleBlockInfos(shuffleBlockInfos, blockIds);
+      processShuffleBlockInfos(shuffleBlockInfos);
     }
 
     final long start = System.currentTimeMillis();
     shuffleBlockInfos = bufferManager.clear();
-    processShuffleBlockInfos(shuffleBlockInfos, blockIds);
+    processShuffleBlockInfos(shuffleBlockInfos);
     long s = System.currentTimeMillis();
     checkBlockSendResult(blockIds);
     final long checkDuration = System.currentTimeMillis() - s;
@@ -221,10 +262,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
    * and shuffle reader will do the integration check with them
    *
    * @param shuffleBlockInfoList
-   * @param blockIds
    */
-  private void processShuffleBlockInfos(
-      List<ShuffleBlockInfo> shuffleBlockInfoList, Set<Long> blockIds) {
+  private List<CompletableFuture<Long>> processShuffleBlockInfos(
+      List<ShuffleBlockInfo> shuffleBlockInfoList) {
     if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) {
       shuffleBlockInfoList.stream()
           .forEach(
@@ -238,16 +278,20 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                     .computeIfAbsent(partitionId, k -> Sets.newHashSet())
                     .add(blockId);
               });
-      postBlockEvent(shuffleBlockInfoList);
+      return postBlockEvent(shuffleBlockInfoList);
     }
+    return Collections.emptyList();
   }
 
   // don't send huge block to shuffle server, or there will be OOM if shuffle 
sever receives data
   // more than expected
-  protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
+  protected List<CompletableFuture<Long>> postBlockEvent(
+      List<ShuffleBlockInfo> shuffleBlockInfoList) {
+    List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
-      shuffleManager.sendData(event);
+      futures.add(shuffleManager.sendData(event));
     }
+    return futures;
   }
 
   @VisibleForTesting
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 674d6777..e1d535d3 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -27,7 +27,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
@@ -50,10 +49,8 @@ import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
-import org.apache.spark.shuffle.writer.BufferManagerOptions;
 import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
-import org.apache.spark.shuffle.writer.WriteBufferManager;
 import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManagerId;
@@ -449,38 +446,26 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     setPusherAppId(rssHandle);
     int shuffleId = rssHandle.getShuffleId();
     String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
-    BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+
     ShuffleWriteMetrics writeMetrics;
     if (metrics != null) {
       writeMetrics = new WriteMetrics(metrics);
     } else {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
-    WriteBufferManager bufferManager =
-        new WriteBufferManager(
-            shuffleId,
-            taskId,
-            context.taskAttemptId(),
-            bufferOptions,
-            rssHandle.getDependency().serializer(),
-            rssHandle.getPartitionToServers(),
-            context.taskMemoryManager(),
-            writeMetrics,
-            RssSparkConfig.toRssConf(sparkConf),
-            this::sendData);
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), 
rssHandle.getShuffleId());
     return new RssShuffleWriter<>(
         rssHandle.getAppId(),
         shuffleId,
         taskId,
         context.taskAttemptId(),
-        bufferManager,
         writeMetrics,
         this,
         sparkConf,
         shuffleWriteClient,
         rssHandle,
-        (Function<String, Boolean>) this::markFailedTask);
+        this::markFailedTask,
+        context);
   }
 
   public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 428157c3..fb0c7850 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -18,10 +18,13 @@
 package org.apache.spark.shuffle.writer;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -41,6 +44,7 @@ import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.RssShuffleHandle;
@@ -65,7 +69,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private final String appId;
   private final int shuffleId;
-  private final WriteBufferManager bufferManager;
+  private WriteBufferManager bufferManager;
   private final String taskId;
   private final int numMaps;
   private final ShuffleDependency<K, V, C> shuffleDependency;
@@ -80,14 +84,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
   private final Set<ShuffleServerInfo> shuffleServersForData;
   private final long[] partitionLengths;
-  private boolean isMemoryShuffleEnabled;
+  private final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
+  private final Set<Long> blockIds = Sets.newConcurrentHashSet();
 
   /** used by columnar rss shuffle writer implementation */
   protected final long taskAttemptId;
 
   protected final ShuffleWriteMetrics shuffleWriteMetrics;
 
+  // Only for tests
+  @VisibleForTesting
   public RssShuffleWriter(
       String appId,
       int shuffleId,
@@ -104,21 +111,20 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleId,
         taskId,
         taskAttemptId,
-        bufferManager,
         shuffleWriteMetrics,
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
         rssHandle,
         (tid) -> true);
+    this.bufferManager = bufferManager;
   }
 
-  public RssShuffleWriter(
+  private RssShuffleWriter(
       String appId,
       int shuffleId,
       String taskId,
       long taskAttemptId,
-      WriteBufferManager bufferManager,
       ShuffleWriteMetrics shuffleWriteMetrics,
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
@@ -128,7 +134,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     LOG.warn("RssShuffle start write taskAttemptId data" + taskAttemptId);
     this.shuffleManager = shuffleManager;
     this.appId = appId;
-    this.bufferManager = bufferManager;
     this.shuffleId = shuffleId;
     this.taskId = taskId;
     this.taskAttemptId = taskAttemptId;
@@ -151,6 +156,45 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.taskFailureCallback = taskFailureCallback;
   }
 
+  public RssShuffleWriter(
+      String appId,
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssShuffleManager shuffleManager,
+      SparkConf sparkConf,
+      ShuffleWriteClient shuffleWriteClient,
+      RssShuffleHandle<K, V, C> rssHandle,
+      Function<String, Boolean> taskFailureCallback,
+      TaskContext context) {
+    this(
+        appId,
+        shuffleId,
+        taskId,
+        taskAttemptId,
+        shuffleWriteMetrics,
+        shuffleManager,
+        sparkConf,
+        shuffleWriteClient,
+        rssHandle,
+        taskFailureCallback);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+    final WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            shuffleId,
+            taskId,
+            taskAttemptId,
+            bufferOptions,
+            rssHandle.getDependency().serializer(),
+            rssHandle.getPartitionToServers(),
+            context.taskMemoryManager(),
+            shuffleWriteMetrics,
+            RssSparkConfig.toRssConf(sparkConf),
+            this::processShuffleBlockInfos);
+    this.bufferManager = bufferManager;
+  }
+
   private boolean isMemoryShuffleEnabled(String storageType) {
     return StorageType.withMemory(StorageType.valueOf(storageType));
   }
@@ -167,7 +211,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
-    Set<Long> blockIds = Sets.newHashSet();
     boolean isCombine = shuffleDependency.mapSideCombine();
     Function1<V, C> createCombiner = null;
     if (isCombine) {
@@ -187,13 +230,13 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), 
record._2());
       }
       if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
-        processShuffleBlockInfos(shuffleBlockInfos, blockIds);
+        processShuffleBlockInfos(shuffleBlockInfos);
       }
     }
     final long start = System.currentTimeMillis();
     shuffleBlockInfos = bufferManager.clear();
     if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
-      processShuffleBlockInfos(shuffleBlockInfos, blockIds);
+      processShuffleBlockInfos(shuffleBlockInfos);
     }
     long checkStartTs = System.currentTimeMillis();
     checkBlockSendResult(blockIds);
@@ -227,8 +270,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     return new long[0];
   }
 
-  protected void processShuffleBlockInfos(
-      List<ShuffleBlockInfo> shuffleBlockInfoList, Set<Long> blockIds) {
+  @VisibleForTesting
+  protected List<CompletableFuture<Long>> processShuffleBlockInfos(
+      List<ShuffleBlockInfo> shuffleBlockInfoList) {
     if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) {
       shuffleBlockInfoList.forEach(
           sbi -> {
@@ -240,14 +284,18 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             partitionToBlockIds.computeIfAbsent(partitionId, k -> 
Sets.newHashSet()).add(blockId);
             partitionLengths[partitionId] += sbi.getLength();
           });
-      postBlockEvent(shuffleBlockInfoList);
+      return postBlockEvent(shuffleBlockInfoList);
     }
+    return Collections.emptyList();
   }
 
-  protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
+  protected List<CompletableFuture<Long>> postBlockEvent(
+      List<ShuffleBlockInfo> shuffleBlockInfoList) {
+    List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
-      shuffleManager.sendData(event);
+      futures.add(shuffleManager.sendData(event));
     }
+    return futures;
   }
 
   @VisibleForTesting
@@ -377,4 +425,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   Map<Integer, Set<Long>> getPartitionToBlockIds() {
     return partitionToBlockIds;
   }
+
+  @VisibleForTesting
+  public WriteBufferManager getBufferManager() {
+    return bufferManager;
+  }
 }
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 38895246..1be8a2a6 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -18,7 +18,9 @@
 package org.apache.spark.shuffle.writer;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -36,7 +38,6 @@ import com.google.common.collect.Sets;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.memory.TaskMemoryManager;
@@ -79,8 +80,6 @@ public class RssShuffleWriterTest {
         .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
         .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name())
         .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), 
"127.0.0.1:12345,127.0.0.1:12346");
-    // init SparkContext
-    final SparkContext sc = SparkContext.getOrCreate(conf);
     Map<String, Set<Long>> failBlocks = JavaUtils.newConcurrentMap();
     Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
     Serializer kryoSerializer = new KryoSerializer(conf);
@@ -149,8 +148,6 @@ public class RssShuffleWriterTest {
     assertTrue(e3.getMessage().startsWith("Send failed:"));
     successBlocks.clear();
     failBlocks.clear();
-
-    sc.stop();
   }
 
   static class FakedDataPusher extends DataPusher {
@@ -184,6 +181,99 @@ public class RssShuffleWriterTest {
     }
   }
 
+  @Test
+  public void dataConsistencyWhenSpillTriggeredTest() throws Exception {
+    SparkConf conf = new SparkConf();
+    conf.set("spark.rss.client.memory.spill.enabled", "true");
+    conf.setAppName("dataConsistencyWhenSpillTriggeredTest_app")
+        .setMaster("local[2]")
+        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "100000")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY.name())
+        .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), 
"127.0.0.1:12345,127.0.0.1:12346");
+
+    Map<String, Set<Long>> successBlockIds = Maps.newConcurrentMap();
+
+    List<Long> freeMemoryList = new ArrayList<>();
+    FakedDataPusher dataPusher =
+        new FakedDataPusher(
+            event -> {
+              event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+              long sum =
+                  event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum();
+              freeMemoryList.add(sum);
+              successBlockIds.putIfAbsent(event.getTaskId(), new HashSet<>());
+              successBlockIds
+                  .get(event.getTaskId())
+                  .add(event.getShuffleDataInfoList().get(0).getBlockId());
+              return CompletableFuture.completedFuture(sum);
+            });
+
+    final RssShuffleManager manager =
+        TestUtils.createShuffleManager(
+            conf, false, dataPusher, successBlockIds, 
JavaUtils.newConcurrentMap());
+
+    WriteBufferManagerTest.FakedTaskMemoryManager fakedTaskMemoryManager =
+        new WriteBufferManagerTest.FakedTaskMemoryManager();
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            0,
+            "taskId",
+            0,
+            bufferOptions,
+            new KryoSerializer(conf),
+            Maps.newHashMap(),
+            fakedTaskMemoryManager,
+            new ShuffleWriteMetrics(),
+            RssSparkConfig.toRssConf(conf),
+            null);
+
+    Serializer kryoSerializer = new KryoSerializer(conf);
+    Partitioner mockPartitioner = mock(Partitioner.class);
+    final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
+    RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
+    when(mockHandle.getDependency()).thenReturn(mockDependency);
+    when(mockDependency.serializer()).thenReturn(kryoSerializer);
+    when(mockDependency.partitioner()).thenReturn(mockPartitioner);
+    when(mockPartitioner.numPartitions()).thenReturn(1);
+
+    RssShuffleWriter<String, String, String> rssShuffleWriter =
+        new RssShuffleWriter<>(
+            "appId",
+            0,
+            "taskId",
+            1L,
+            bufferManager,
+            new ShuffleWriteMetrics(),
+            manager,
+            conf,
+            mockShuffleWriteClient,
+            mockHandle);
+    
rssShuffleWriter.getBufferManager().setSpillFunc(rssShuffleWriter::processShuffleBlockInfos);
+
+    MutableList<Product2<String, String>> data = new MutableList<>();
+    // One record is 26 bytes
+    data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
+    data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
+    data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
+    data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
+
+    // case1: all blocks are sent and pass the blocks check when spill is 
triggered
+    rssShuffleWriter.write(data.iterator());
+    assertEquals(4, successBlockIds.get("taskId").size());
+    for (int i = 0; i < 4; i++) {
+      assertEquals(32, freeMemoryList.get(i));
+    }
+  }
+
   @Test
   public void writeTest() throws Exception {
     SparkConf conf = new SparkConf();
@@ -198,9 +288,7 @@ public class RssShuffleWriterTest {
         .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
         .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name())
         .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), 
"127.0.0.1:12345,127.0.0.1:12346");
-    // init SparkContext
     List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
-    final SparkContext sc = SparkContext.getOrCreate(conf);
     Map<String, Set<Long>> successBlockIds = Maps.newConcurrentMap();
 
     FakedDataPusher dataPusher =
@@ -331,7 +419,6 @@ public class RssShuffleWriterTest {
     assertEquals(2, partitionToBlockIds.get(0).size());
     assertEquals(2, partitionToBlockIds.get(2).size());
     partitionToBlockIds.clear();
-    sc.stop();
   }
 
   @Test

Reply via email to