FelixYBW commented on code in PR #6480:
URL: https://github.com/apache/incubator-gluten/pull/6480#discussion_r1702488583


##########
shims/spark32/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java:
##########
@@ -0,0 +1,928 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.SparkOutOfMemoryError;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TooLargePageException;
+import org.apache.spark.serializer.SerializerManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+/** External sorter based on {@link UnsafeInMemorySorter}. */
+public final class UnsafeExternalSorter extends MemoryConsumer {
+
+  private static final Logger logger = 
LoggerFactory.getLogger(UnsafeExternalSorter.class);
+
+  @Nullable private final PrefixComparator prefixComparator;
+
+  /**
+   * {@link RecordComparator} may probably keep the reference to the records 
they compared last
+   * time, so we should not keep a {@link RecordComparator} instance inside 
{@link
+   * UnsafeExternalSorter}, because {@link UnsafeExternalSorter} is referenced 
by {@link
+   * TaskContext} and thus can not be garbage collected until the end of the 
task.
+   */
+  @Nullable private final Supplier<RecordComparator> recordComparatorSupplier;
+
+  private final TaskMemoryManager taskMemoryManager;
+  private final BlockManager blockManager;
+  private final SerializerManager serializerManager;
+  private final TaskContext taskContext;
+
+  /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+  private final int fileBufferSizeBytes;
+
+  /** Force this sorter to spill when there are this many elements in memory. 
*/
+  private final int numElementsForSpillThreshold;
+
+  /**
+   * Memory pages that hold the records being sorted. The pages in this list 
are freed when
+   * spilling, although in principle we could recycle these pages across 
spills (on the other hand,
+   * this might not be necessary if we maintained a pool of re-usable pages in 
the TaskMemoryManager
+   * itself).
+   */
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
+
+  private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new 
LinkedList<>();
+
+  // These variables are reset after spilling:
+  @Nullable private volatile UnsafeInMemorySorter inMemSorter;
+
+  private MemoryBlock currentPage = null;
+  private long pageCursor = -1;
+  private long peakMemoryUsedBytes = 0;
+  private long totalSpillBytes = 0L;
+  private long totalSortTimeNanos = 0L;
+  private volatile SpillableIterator readingIterator = null;
+
+  public static UnsafeExternalSorter createWithExistingInMemorySorter(
+      TaskMemoryManager taskMemoryManager,
+      BlockManager blockManager,
+      SerializerManager serializerManager,
+      TaskContext taskContext,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      int initialSize,
+      long pageSizeBytes,
+      int numElementsForSpillThreshold,
+      UnsafeInMemorySorter inMemorySorter,
+      long existingMemoryConsumption)
+      throws IOException {
+    UnsafeExternalSorter sorter =
+        new UnsafeExternalSorter(
+            taskMemoryManager,
+            blockManager,
+            serializerManager,
+            taskContext,
+            recordComparatorSupplier,
+            prefixComparator,
+            initialSize,
+            pageSizeBytes,
+            numElementsForSpillThreshold,
+            inMemorySorter,
+            false /* ignored */);
+    sorter.spill(Long.MAX_VALUE, sorter);
+    taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption);
+    sorter.totalSpillBytes += existingMemoryConsumption;
+    // The external sorter will be used to insert records, in-memory sorter is 
not needed.
+    sorter.inMemSorter = null;
+    return sorter;
+  }
+
+  public static UnsafeExternalSorter create(
+      TaskMemoryManager taskMemoryManager,
+      BlockManager blockManager,
+      SerializerManager serializerManager,
+      TaskContext taskContext,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      int initialSize,
+      long pageSizeBytes,
+      int numElementsForSpillThreshold,
+      boolean canUseRadixSort) {
+    return new UnsafeExternalSorter(
+        taskMemoryManager,
+        blockManager,
+        serializerManager,
+        taskContext,
+        recordComparatorSupplier,
+        prefixComparator,
+        initialSize,
+        pageSizeBytes,
+        numElementsForSpillThreshold,
+        null,
+        canUseRadixSort);
+  }
+
+  private UnsafeExternalSorter(
+      TaskMemoryManager taskMemoryManager,
+      BlockManager blockManager,
+      SerializerManager serializerManager,
+      TaskContext taskContext,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      int initialSize,
+      long pageSizeBytes,
+      int numElementsForSpillThreshold,
+      @Nullable UnsafeInMemorySorter existingInMemorySorter,
+      boolean canUseRadixSort) {
+    super(taskMemoryManager, pageSizeBytes, 
taskMemoryManager.getTungstenMemoryMode());
+    this.taskMemoryManager = taskMemoryManager;
+    this.blockManager = blockManager;
+    this.serializerManager = serializerManager;
+    this.taskContext = taskContext;
+    this.recordComparatorSupplier = recordComparatorSupplier;
+    this.prefixComparator = prefixComparator;
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility for 
units
+    // this.fileBufferSizeBytes = (int) 
conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024
+    this.fileBufferSizeBytes = 32 * 1024;
+
+    if (existingInMemorySorter == null) {
+      RecordComparator comparator = null;
+      if (recordComparatorSupplier != null) {
+        comparator = recordComparatorSupplier.get();
+      }
+      this.inMemSorter =
+          new UnsafeInMemorySorter(
+              this, taskMemoryManager, comparator, prefixComparator, 
initialSize, canUseRadixSort);
+    } else {
+      this.inMemSorter = existingInMemorySorter;
+    }
+    this.peakMemoryUsedBytes = getMemoryUsage();
+    this.numElementsForSpillThreshold = numElementsForSpillThreshold;
+
+    // Register a cleanup task with TaskContext to ensure that memory is 
guaranteed to be freed at
+    // the end of the task. This is necessary to avoid memory leaks in when 
the downstream operator
+    // does not fully consume the sorter's output (e.g. sort followed by 
limit).
+    taskContext.addTaskCompletionListener(
+        context -> {
+          cleanupResources();
+        });
+  }
+
+  /**
+   * Marks the current page as no-more-space-available, and as a result, 
either allocate a new page
+   * or spill when we see the next record.
+   */
+  @VisibleForTesting
+  public void closeCurrentPage() {
+    if (currentPage != null) {
+      pageCursor = currentPage.getBaseOffset() + currentPage.size();
+    }
+  }
+
+  @Override
+  public long forceSpill(long size, MemoryConsumer trigger) throws IOException 
{
+    if (trigger != this && readingIterator != null) {
+      return readingIterator.spill();
+    }
+
+    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
+      // There could still be some memory allocated when there are no records 
in the in-memory
+      // sorter. We will not spill it however, to ensure that we can always 
process at least one
+      // record before spilling. See the comments in 
`allocateMemoryForRecordIfNecessary` for why
+      // this is necessary.
+      return 0L;
+    }
+
+    logger.info(
+        "Thread {} force spilling sort data of {} to disk ({} {} so far)",
+        Thread.currentThread().getId(),
+        Utils.bytesToString(getMemoryUsage()),
+        spillWriters.size(),
+        spillWriters.size() > 1 ? " times" : " time");
+
+    ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+
+    final UnsafeSorterSpillWriter spillWriter =
+        new UnsafeSorterSpillWriter(
+            blockManager, fileBufferSizeBytes, writeMetrics, 
inMemSorter.numRecords());
+    spillWriters.add(spillWriter);
+    spillIterator(inMemSorter.getSortedIterator(), spillWriter);
+
+    final long spillSize = freeMemory();
+    // Note that this is more-or-less going to be a multiple of the page size, 
so wasted space in
+    // pages will currently be counted as memory spilled even though that 
space isn't actually
+    // written to disk. This also counts the space needed to store the 
sorter's pointer array.
+    inMemSorter.freeMemory();
+    // Reset the in-memory sorter's pointer array only after freeing up the 
memory pages holding the
+    // records. Otherwise, if the task is over allocated memory, then without 
freeing the memory
+    // pages, we might not be able to get memory for the pointer array.
+
+    taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+    taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
+    totalSpillBytes += spillSize;
+    return spillSize;
+  }
+
+  /** Sort and spill the current records in response to memory pressure. */
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this) {
+      if (readingIterator != null) {
+        return readingIterator.spill();
+      }
+      return 0L; // this should throw exception
+    }
+
+    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
+      // There could still be some memory allocated when there are no records 
in the in-memory
+      // sorter. We will not spill it however, to ensure that we can always 
process at least one
+      // record before spilling. See the comments in 
`allocateMemoryForRecordIfNecessary` for why
+      // this is necessary.
+      return 0L;
+    }
+
+    logger.info(
+        "Thread {} spilling sort data of {} to disk ({} {} so far)",
+        Thread.currentThread().getId(),
+        Utils.bytesToString(getMemoryUsage()),
+        spillWriters.size(),
+        spillWriters.size() > 1 ? " times" : " time");
+
+    ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+
+    final UnsafeSorterSpillWriter spillWriter =
+        new UnsafeSorterSpillWriter(
+            blockManager, fileBufferSizeBytes, writeMetrics, 
inMemSorter.numRecords());
+    logger.warn("inMemSorter numRecords: " + inMemSorter.numRecords());
+    spillWriters.add(spillWriter);
+    UnsafeSorterIterator iterator = inMemSorter.getSortedIterator();

Review Comment:
   @jinchengchenghh The error is caused by a mismatch between 
inMemSorter.numRecords() and iterator.getNumRecords(). What's the reason for that?
   
   ```
   24/08/03 05:22:17 INFO [Thread-14] sort.UnsafeExternalSorter: Thread 83 
force spilling sort data of 12.3 GiB to disk (1  time so far)
   24/08/03 05:22:17 WARN [Thread-14] sort.UnsafeExternalSorter: 
UnsafeSorterSpillWriter numRecordsToWrite : 48823146
   24/08/03 05:23:38 WARN [Executor task launch worker for task 421.0 in stage 
0.0 (TID 421)] sort.UnsafeExternalSorter: inMemSorter numRecords: 44739242
   24/08/03 05:23:38 WARN [Executor task launch worker for task 421.0 in stage 
0.0 (TID 421)] sort.UnsafeExternalSorter: iterator numRecords: 44739242
   24/08/03 05:23:38 WARN [Executor task launch worker for task 421.0 in stage 
0.0 (TID 421)] sort.UnsafeExternalSorter: inMemSorter SortedIterator length 
44739242
   24/08/03 05:23:38 WARN [Executor task launch worker for task 421.0 in stage 
0.0 (TID 421)] sort.UnsafeExternalSorter: inMemIterator size: 44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 423.0 in stage 
0.0 (TID 423)] sort.UnsafeExternalSorter: inMemSorter numRecords: 44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 423.0 in stage 
0.0 (TID 423)] sort.UnsafeExternalSorter: iterator numRecords: 44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 423.0 in stage 
0.0 (TID 423)] sort.UnsafeExternalSorter: inMemSorter SortedIterator length 
44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 423.0 in stage 
0.0 (TID 423)] sort.UnsafeExternalSorter: inMemIterator size: 44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 420.0 in stage 
0.0 (TID 420)] sort.UnsafeExternalSorter: inMemSorter numRecords: 44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 420.0 in stage 
0.0 (TID 420)] sort.UnsafeExternalSorter: iterator numRecords: 44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 420.0 in stage 
0.0 (TID 420)] sort.UnsafeExternalSorter: inMemSorter SortedIterator length 
44739242
   24/08/03 05:23:40 WARN [Executor task launch worker for task 420.0 in stage 
0.0 (TID 420)] sort.UnsafeExternalSorter: inMemIterator size: 44739242
   24/08/03 05:23:47 WARN [Thread-14] sort.UnsafeExternalSorter: inMemIterator 
size: 49009940
   24/08/03 05:24:09 ERROR [Thread-14] listener.ManagedReservationListener: 
Error reserving memory from target
   java.lang.IllegalStateException: Number of records written exceeded 
numRecordsToWrite = 48823146
        at 
org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.write(UnsafeSorterSpillWriter.java:118)
        at 
org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spillIterator(UnsafeExternalSorter.java:643)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to