This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new e47ec10ce [CELEBORN-783] Revise the conditions for the `SortBasedPusher#insertRecord` method
e47ec10ce is described below

commit e47ec10cefacf4fc5e59630359350bf437a8e020
Author: Fu Chen <[email protected]>
AuthorDate: Tue Jul 11 11:36:29 2023 +0800

    [CELEBORN-783] Revise the conditions for the `SortBasedPusher#insertRecord` method
    
    ### What changes were proposed in this pull request?
    
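    As the title says: revise the push-trigger condition in `SortBasedPusher#insertRecord` so that,
    once the memory threshold is exceeded, it checks whether the current page can hold the actual
    size required by the incoming record, instead of checking against a fixed 8k.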
    
    ### Why are the changes needed?
    
    
    See this review [comment](https://github.com/apache/incubator-celeborn/commit/7adf1fca41cd34eb3e9ddf5cc8cd24674ab585b4?notification_referrer_id=NT_kwDOAIJHFbI3MDEyMzc0NzkwOjg1Mzc4Nzc#r121138008) on commit 7adf1fca.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new unit test, `SortBasedPusherSuiteJ`.
    
    Closes #1699 from cfmcgrady/insert-record.
    
    Authored-by: Fu Chen <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../spark/shuffle/celeborn/SortBasedPusher.java    |  17 ++-
 .../shuffle/celeborn/SortBasedPusherSuiteJ.java    | 139 +++++++++++++++++++++
 2 files changed, 147 insertions(+), 9 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index 92d5555d3..330231387 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -57,7 +57,6 @@ public class SortBasedPusher extends MemoryConsumer {
   private final int pushBufferMaxSize;
   private final long pushSortMemoryThreshold;
   final int uaoSize = UnsafeAlignedOffset.getUaoSize();
-  static final long bytes8K = Utils.byteStringAsBytes("8k");
 
   String appId;
   int shuffleId;
@@ -221,9 +220,16 @@ public class SortBasedPusher extends MemoryConsumer {
   public boolean insertRecord(
       Object recordBase, long recordOffset, int recordSize, int partitionId, 
boolean copySize)
       throws IOException {
+    int required;
+    // Need 4 or 8 bytes to store the record recordSize.
+    if (copySize) {
+      required = recordSize + 4 + uaoSize;
+    } else {
+      required = recordSize + uaoSize;
+    }
 
     if (getUsed() > pushSortMemoryThreshold
-        && pageCursor + bytes8K > currentPage.getBaseOffset() + 
currentPage.size()) {
+        && pageCursor + required > currentPage.getBaseOffset() + 
currentPage.size()) {
       logger.info(
           "Memory used {} exceeds threshold {}, need to trigger push. 
currentPage size: {}",
           Utils.bytesToString(getUsed()),
@@ -232,13 +238,6 @@ public class SortBasedPusher extends MemoryConsumer {
       return false;
     }
 
-    int required;
-    // Need 4 or 8 bytes to store the record recordSize.
-    if (copySize) {
-      required = recordSize + 4 + uaoSize;
-    } else {
-      required = recordSize + uaoSize;
-    }
     allocateMemoryForRecordIfNecessary(required);
 
     assert (currentPage != null);
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
new file mode 100644
index 000000000..203ca46fa
--- /dev/null
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.UUID;
+
+import scala.collection.mutable.ListBuffer;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.UnifiedMemoryManager;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.BinaryType$;
+import org.apache.spark.sql.types.DataType;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.DummyShuffleClient;
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.Utils;
+
+public class SortBasedPusherSuiteJ {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(SortBasedPusherSuiteJ.class);
+  private final SparkConf sparkConf = new 
SparkConf(false).set("spark.buffer.pageSize", "2m");
+
+  private final CelebornConf conf = new CelebornConf();
+
+  private final UnifiedMemoryManager unifiedMemoryManager =
+      UnifiedMemoryManager.apply(sparkConf, 1);
+  private final TaskMemoryManager taskMemoryManager =
+      new TaskMemoryManager(unifiedMemoryManager, 0);
+
+  private final File tempFile = new File(tempDir, 
UUID.randomUUID().toString());
+  private static File tempDir = null;
+
+  @BeforeClass
+  public static void beforeAll() {
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"celeborn_test");
+  }
+
+  @AfterClass
+  public static void afterAll() {
+    try {
+      JavaUtils.deleteRecursively(tempDir);
+    } catch (IOException e) {
+      LOG.error("Failed to delete temporary directory.", e);
+    }
+  }
+
+  @Test
+  public void testMemoryUsage() throws Exception {
+    final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
+    SortBasedPusher pusher =
+        new SortBasedPusher(
+            taskMemoryManager,
+            /*shuffleClient=*/ client,
+            /*appId=*/ null,
+            /*shuffleId=*/ 0,
+            /*mapId=*/ 0,
+            /*attemptNumber=*/ 0,
+            /*taskAttemptId=*/ 0,
+            /*numMappers=*/ 0,
+            /*numPartitions=*/ 0,
+            conf,
+            /*afterPush=*/ null,
+            /*mapStatusLengths=*/ null,
+            /*pushSortMemoryThreshold=*/ Utils.byteStringAsBytes("1m"),
+            /*sharedPushLock=*/ null,
+            /*executorService=*/ null);
+
+    // default page size == 2 MiB
+    assertEquals(unifiedMemoryManager.pageSizeBytes(), 
Utils.byteStringAsBytes("2m"));
+
+    UnsafeRow row9k = genUnsafeRow(1024 * 9); // 9232 B
+    assertEquals(row9k.getSizeInBytes(), 9232);
+
+    // if uao = 4, total write size = (9232 B + 4 + 4) * 226 = 2088240 B = 
2039.3 KiB
+    // if uao = 8, total write size = (9232 B + 4 + 8) * 226 = 2089144 B = 
2040.2 KiB
+    for (int i = 0; i < 226; i++) {
+      assertTrue(
+          pusher.insertRecord(
+              row9k.getBaseObject(), row9k.getBaseOffset(), 
row9k.getSizeInBytes(), 0, true));
+    }
+    // total used memory: sum(pusher.allocatedPages.size()) + 
pusher.inMemSorter = 2m + 1m = 3m
+    assertEquals(pusher.getUsed(), Utils.byteStringAsBytes("3m"));
+    // there is not enough space to write a new 9k row
+    assertTrue(
+        !pusher.insertRecord(
+            row9k.getBaseObject(), row9k.getBaseOffset(), 
row9k.getSizeInBytes(), 0, true));
+
+    UnsafeRow row5k = genUnsafeRow(1024 * 5);
+    assertTrue(
+        pusher.insertRecord(
+            row5k.getBaseObject(), row5k.getBaseOffset(), 
row5k.getSizeInBytes(), 0, true));
+    assertTrue(
+        !pusher.insertRecord(
+            row5k.getBaseObject(), row5k.getBaseOffset(), 
row5k.getSizeInBytes(), 0, true));
+
+    pusher.close();
+  }
+
+  private static UnsafeRow genUnsafeRow(int size) {
+    ListBuffer<Object> values = new ListBuffer<>();
+    byte[] bytes = new byte[size];
+    values.$plus$eq(bytes);
+    InternalRow row = InternalRow.apply(values.toSeq());
+    DataType[] types = new DataType[1];
+    types[0] = BinaryType$.MODULE$;
+    return UnsafeProjection.create(types).apply(row);
+  }
+}

Reply via email to