This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 464553e27 [#1497] improvement(spark): flushing buffer if the 
memoryUsed of the first record of `WriterBuffer` larger than bufferSize (#1485)
464553e27 is described below

commit 464553e2736306d39dccab61b755793fd06993a3
Author: xianjingfeng <[email protected]>
AuthorDate: Thu Feb 1 14:16:49 2024 +0800

    [#1497] improvement(spark): flushing buffer if the memoryUsed of the first record of `WriterBuffer` is larger than bufferSize (#1485)
    
    ### What changes were proposed in this pull request?
    
    Flushing buffer if the memoryUsed of the first record of `WriterBuffer` is larger than bufferSize.
    
    ### Why are the changes needed?
    
    More accurate. Fixes #1497.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    UT.
---
 .../spark/shuffle/writer/WriteBufferManager.java   | 55 ++++++++++++----------
 .../shuffle/writer/WriteBufferManagerTest.java     | 13 +++++
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index a76c04af3..5b9a92841 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -201,39 +201,25 @@ public class WriteBufferManager extends MemoryConsumer {
     // this may trigger current WriteBufferManager spill method, which will
     // make the current write buffer discard. So we have to recheck the buffer 
existence.
     boolean hasRequested = false;
-    if (buffers.containsKey(partitionId)) {
-      WriterBuffer wb = buffers.get(partitionId);
+    WriterBuffer wb = buffers.get(partitionId);
+    if (wb != null) {
       if (wb.askForMemory(serializedDataLength)) {
         requestMemory(required);
         hasRequested = true;
       }
     }
 
-    if (buffers.containsKey(partitionId)) {
+    // hasRequested is not true means spill method was not trigger,
+    // and we don't have to recheck the buffer existence in this case.
+    if (hasRequested) {
+      wb = buffers.get(partitionId);
+    }
+
+    if (wb != null) {
       if (hasRequested) {
         usedBytes.addAndGet(required);
       }
-      WriterBuffer wb = buffers.get(partitionId);
       wb.addRecord(serializedData, serializedDataLength);
-      if (wb.getMemoryUsed() > bufferSize) {
-        List<ShuffleBlockInfo> sentBlocks = new ArrayList<>(1);
-        sentBlocks.add(createShuffleBlock(partitionId, wb));
-        copyTime += wb.getCopyTime();
-        buffers.remove(partitionId);
-        if (LOG.isDebugEnabled()) {
-          LOG.debug(
-              "Single buffer is full for shuffleId["
-                  + shuffleId
-                  + "] partition["
-                  + partitionId
-                  + "] with memoryUsed["
-                  + wb.getMemoryUsed()
-                  + "], dataLength["
-                  + wb.getDataLength()
-                  + "]");
-        }
-        return sentBlocks;
-      }
     } else {
       // The true of hasRequested means the former partitioned buffer has been 
flushed, that is
       // triggered by the spill operation caused by asking for memory. So it 
needn't to re-request
@@ -242,11 +228,30 @@ public class WriteBufferManager extends MemoryConsumer {
         requestMemory(required);
       }
       usedBytes.addAndGet(required);
-
-      WriterBuffer wb = new WriterBuffer(bufferSegmentSize);
+      wb = new WriterBuffer(bufferSegmentSize);
       wb.addRecord(serializedData, serializedDataLength);
       buffers.put(partitionId, wb);
     }
+
+    if (wb.getMemoryUsed() > bufferSize) {
+      List<ShuffleBlockInfo> sentBlocks = new ArrayList<>(1);
+      sentBlocks.add(createShuffleBlock(partitionId, wb));
+      copyTime += wb.getCopyTime();
+      buffers.remove(partitionId);
+      if (LOG.isDebugEnabled()) {
+        LOG.debug(
+            "Single buffer is full for shuffleId["
+                + shuffleId
+                + "] partition["
+                + partitionId
+                + "] with memoryUsed["
+                + wb.getMemoryUsed()
+                + "], dataLength["
+                + wb.getDataLength()
+                + "]");
+      }
+      return sentBlocks;
+    }
     return Collections.emptyList();
   }
 
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 0e1f5e3fb..c0fb191dc 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -485,4 +485,17 @@ public class WriteBufferManagerTest {
     assertEquals(3, fakedTaskMemoryManager.getInvokedCnt());
     assertEquals(2, fakedTaskMemoryManager.getSpilledCnt());
   }
+
+  @Test
+  public void addFirstRecordWithLargeSizeTest() {
+    SparkConf conf = getConf();
+    WriteBufferManager wbm = createManager(conf);
+    String testKey = "key";
+    String testValue = "~~~~~~~~~~~~~~~~~~~~This is a long 
text~~~~~~~~~~~~~~~~~~~~";
+    List<ShuffleBlockInfo> shuffleBlockInfos = wbm.addRecord(0, testKey, 
testValue);
+    assertEquals(1, shuffleBlockInfos.size());
+    String testValue2 = "This is a short text";
+    List<ShuffleBlockInfo> shuffleBlockInfos2 = wbm.addRecord(1, testKey, 
testValue2);
+    assertEquals(0, shuffleBlockInfos2.size());
+  }
 }

Reply via email to