vanzin commented on a change in pull request #25304: 
[SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API.
URL: https://github.com/apache/spark/pull/25304#discussion_r322385002
 
 

 ##########
 File path: 
core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
 ##########
 @@ -281,137 +259,161 @@ void forceSorterToSpill() throws IOException {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws 
IOException {
+  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+    long[] partitionLengths;
+    if (spills.length == 0) {
+      final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
+          .createMapOutputWriter(
+              shuffleId,
+              mapId,
+              taskContext.taskAttemptId(),
+              partitioner.numPartitions());
+      return mapWriter.commitAllPartitions();
+    } else if (spills.length == 1) {
+      Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
+          shuffleExecutorComponents.createSingleFileMapOutputWriter(
+              shuffleId, mapId, taskContext.taskAttemptId());
+      if (maybeSingleFileWriter.isPresent()) {
+        // Here, we don't need to perform any metrics updates because the 
bytes written to this
+        // output file would have already been counted as shuffle bytes 
written.
+        partitionLengths = spills[0].partitionLengths;
+        maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, 
partitionLengths);
+      } else {
+        partitionLengths = mergeSpillsUsingStandardWriter(spills);
+      }
+    } else {
+      partitionLengths = mergeSpillsUsingStandardWriter(spills);
+    }
+    return partitionLengths;
+  }
+
+  private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws 
IOException {
+    long[] partitionLengths;
     final boolean compressionEnabled = (boolean) 
sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
     final CompressionCodec compressionCodec = 
CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =
-      (boolean) 
sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE());
+        (boolean) 
sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE());
 
 Review comment:
   You didn't really need to touch these lines, but since you did, there's a 
typo in that constant's name: `SHUFFLE_UNDAFE_FAST_MERGE_ENABLE` should be 
`SHUFFLE_UNSAFE_FAST_MERGE_ENABLE`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to