This is an automated email from the ASF dual-hosted git repository.

ckj pushed a commit to branch branch-0.6
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git

commit 00860784dffa393854b1f6fb445b5003ded7d6e9
Author: Zhen Wang <[email protected]>
AuthorDate: Thu Nov 3 19:31:30 2022 +0800

    Fix NPE in WriteBufferManager.addRecord (#296)
    
    ### What changes were proposed in this pull request?
    To handle the NPE of record value to be consistent with Spark
    
    ### Why are the changes needed?
    Fix NPE:
    
    ```
    22/11/03 03:36:08 ERROR Executor: Exception in task 4.0 in stage 1.0 (TID 5)
    java.lang.NullPointerException
            at org.apache.spark.shuffle.writer.WriteBufferManager.addRecord(WriteBufferManager.java:116)
            at org.apache.spark.shuffle.writer.RssShuffleWriter.write(RssShuffleWriter.java:152)
            at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
            at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
            at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
            at org.apache.spark.scheduler.Task.run(Task.scala:131)
            at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
            at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
            at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
            at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
            at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
            at java.lang.Thread.run(Thread.java:748)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    UTs
---
 .../apache/spark/shuffle/writer/WriteBufferManager.java    |  7 ++++++-
 .../spark/shuffle/writer/WriteBufferManagerTest.java       | 14 ++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index ffb6000b..e7aef914 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -35,6 +35,7 @@ import org.apache.spark.serializer.SerializerInstance;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.reflect.ClassTag$;
+import scala.reflect.ManifestFactory$;
 
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.RssShuffleUtils;
@@ -108,7 +109,11 @@ public class WriteBufferManager extends MemoryConsumer {
     final long start = System.currentTimeMillis();
     arrayOutputStream.reset();
     serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
-    serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
+    if (value != null) {
+      serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
+    } else {
+      serializeStream.writeValue(null, ManifestFactory$.MODULE$.Null());
+    }
     serializeStream.flush();
     serializeTime += System.currentTimeMillis() - start;
     byte[] serializedData = arrayOutputStream.getBuf();
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 665f5d2d..47d2b2a5 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -138,6 +138,20 @@ public class WriteBufferManagerTest {
     assertEquals(1, wbm.getBuffers().size());
   }
 
+  @Test
+  public void addNullValueRecordTest() {
+    SparkConf conf = getConf();
+    WriteBufferManager wbm = createManager(conf);
+    String testKey = "key";
+    String testValue = null;
+    List<ShuffleBlockInfo> result = wbm.addRecord(0, testKey, testValue);
+    assertEquals(0, result.size());
+    assertEquals(512, wbm.getAllocatedBytes());
+    assertEquals(32, wbm.getUsedBytes());
+    assertEquals(0, wbm.getInSendListBytes());
+    assertEquals(1, wbm.getBuffers().size());
+  }
+
   @Test
   public void createBlockIdTest() {
     SparkConf conf = getConf();

Reply via email to